sam2 large video support (#335)
split the video into overlapping segments, process each segment independently, and merge the per-segment tracking results
hrnn authored Jan 7, 2025
1 parent 7daa162 commit df3fff5
Showing 2 changed files with 378 additions and 86 deletions.
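
At a glance, the commit replaces the old whole-video detection loop with a split/process/merge pipeline: the frame list is cut into overlapping segments (SEGMENT_SIZE = 50, OVERLAP = 1 in the diff below), each segment is detected and tracked on its own, and the per-segment results are stitched back into a single timeline, which bounds peak memory for long videos. The new helpers come from vision_agent/utils/video_tracking.py, whose diff did not render on this page; the following is only a minimal sketch of what the splitting step could look like given how it is called, not the actual implementation.

```python
from typing import List

import numpy as np


def split_frames_into_segments(
    frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
) -> List[List[np.ndarray]]:
    # Each segment after the first starts `overlap` frames before the end of
    # the previous one, so adjacent segments share frames that the merge step
    # can use to match track IDs across the boundary.
    segments: List[List[np.ndarray]] = []
    start = 0
    while start < len(frames):
        end = min(start + segment_size, len(frames))
        segments.append(frames[start:end])
        if end == len(frames):
            break
        start = end - overlap
    return segments
```

With 120 frames and these defaults, the sketch yields segments covering frames 0–49, 49–98, and 98–119; the one-frame overlap gives the merge step a shared frame on which to reconcile track identities.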
159 changes: 73 additions & 86 deletions vision_agent/tools/tools.py
@@ -6,7 +6,6 @@
 import urllib.request
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from enum import Enum
 from importlib import resources
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -54,6 +53,13 @@
     frames_to_bytes,
     video_writer,
 )
+from vision_agent.utils.video_tracking import (
+    ODModels,
+    merge_segments,
+    post_process,
+    process_segment,
+    split_frames_into_segments,
+)
 
 register_heif_opener()
 
@@ -224,118 +230,99 @@ def sam2(
     return ret["return_data"]  # type: ignore
 
 
-class ODModels(str, Enum):
-    COUNTGD = "countgd"
-    FLORENCE2 = "florence2"
-    OWLV2 = "owlv2"
-
-
 def od_sam2_video_tracking(
     od_model: ODModels,
     prompt: str,
     frames: List[np.ndarray],
     chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> Dict[str, Any]:
-    results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)
+    SEGMENT_SIZE = 50
+    OVERLAP = 1  # Number of overlapping frames between segments
 
-    if chunk_length is None:
-        step = 1  # Process every frame
-    elif chunk_length <= 0:
-        raise ValueError("chunk_length must be a positive integer or None.")
-    else:
-        step = chunk_length  # Process frames with the specified step size
+    image_size = frames[0].shape[:2]
+
+    # Split frames into segments with overlap
+    segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
+
+    def _apply_object_detection(  # inner method to avoid circular importing issues.
+        od_model: ODModels,
+        prompt: str,
+        segment_index: int,
+        frame_number: int,
+        fine_tune_id: str,
+        segment_frames: list,
+    ) -> tuple:
+        """
+        Applies the specified object detection model to the given image.
+        Args:
+            od_model: The object detection model to use.
+            prompt: The prompt for the object detection model.
+            segment_index: The index of the current segment.
+            frame_number: The number of the current frame.
+            fine_tune_id: Optional fine-tune ID for the model.
+            segment_frames: List of frames for the current segment.
+        Returns:
+            A tuple containing the object detection results and the name of the function used.
+        """
 
-    for idx in range(0, len(frames), step):
         if od_model == ODModels.COUNTGD:
-            results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
+            segment_results = countgd_object_detection(
+                prompt=prompt, image=segment_frames[frame_number]
+            )
             function_name = "countgd_object_detection"
 
         elif od_model == ODModels.OWLV2:
-            results[idx] = owlv2_object_detection(
-                prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
+            segment_results = owlv2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "owlv2_object_detection"
 
         elif od_model == ODModels.FLORENCE2:
-            results[idx] = florence2_object_detection(
-                prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
+            segment_results = florence2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "florence2_object_detection"
 
         else:
             raise NotImplementedError(
                 f"Object detection model '{od_model}' is not implemented."
             )
 
-    image_size = frames[0].shape[:2]
-
-    def _transform_detections(
-        input_list: List[Optional[List[Dict[str, Any]]]],
-    ) -> List[Optional[Dict[str, Any]]]:
-        output_list: List[Optional[Dict[str, Any]]] = []
-
-        for _, frame in enumerate(input_list):
-            if frame is not None:
-                labels = [detection["label"] for detection in frame]
-                bboxes = [
-                    denormalize_bbox(detection["bbox"], image_size)
-                    for detection in frame
-                ]
-
-                output_list.append(
-                    {
-                        "labels": labels,
-                        "bboxes": bboxes,
-                    }
-                )
-            else:
-                output_list.append(None)
-
-        return output_list
+        return segment_results, function_name
 
+    # Process each segment and collect detections
+    detections_per_segment: List[Any] = []
+    for segment_index, segment in enumerate(segments):
+        segment_detections = process_segment(
+            segment_frames=segment,
+            od_model=od_model,
+            prompt=prompt,
+            fine_tune_id=fine_tune_id,
+            chunk_length=chunk_length,
+            image_size=image_size,
+            segment_index=segment_index,
+            object_detection_tool=_apply_object_detection,
+        )
+        detections_per_segment.append(segment_detections)
 
-    output = _transform_detections(results)
+    merged_detections = merge_segments(detections_per_segment)
+    post_processed = post_process(merged_detections, image_size)
 
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
-    payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
-    metadata = {"function_name": function_name}
-
-    detections = send_task_inference_request(
-        payload,
-        "sam2",
-        files=files,
-        metadata=metadata,
-    )
-
-    return_data = []
-    for frame in detections:
-        return_frame_data = []
-        for detection in frame:
-            mask = rle_decode_array(detection["mask"])
-            label = str(detection["id"]) + ": " + detection["label"]
-            return_frame_data.append(
-                {"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
-            )
-        return_data.append(return_frame_data)
-    return_data = add_bboxes_from_masks(return_data)
-    return_data = nms(return_data, iou_threshold=0.95)
-
-    # We save the RLE for display purposes, re-calculating RLE can get very expensive.
-    # Deleted here because we are returning the numpy masks instead
-    display_data = []
-    for frame in return_data:
-        display_frame_data = []
-        for obj in frame:
-            display_frame_data.append(
-                {
-                    "label": obj["label"],
-                    "score": obj["score"],
-                    "bbox": denormalize_bbox(obj["bbox"], image_size),
-                    "mask": obj["rle"],
-                }
-            )
-            del obj["rle"]
-        display_data.append(display_frame_data)
 
-    return {"files": files, "return_data": return_data, "display_data": detections}
+    return {
+        "files": files,
+        "return_data": post_processed["return_data"],
+        "display_data": post_processed["display_data"],
+    }
 
 
 # Owl V2 Tools
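
The per-segment work that the commit removes from od_sam2_video_tracking above (run the detector every chunk_length frames, denormalize the boxes, then send the frames plus boxes to the hosted "sam2" task) presumably now lives behind process_segment in the new utils module. A sketch of its likely contract, reconstructed from the removed code and the call site; the import paths flagged below are assumptions, not verified against the repository:

```python
import json
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

# Assumed import locations, mirroring the surrounding file (not verified):
from vision_agent.utils.image_utils import denormalize_bbox  # assumption
from vision_agent.utils.video import frames_to_bytes  # assumption
from vision_agent.tools.tool_utils import send_task_inference_request  # assumption


def process_segment(
    segment_frames: List[np.ndarray],
    od_model: Any,  # ODModels value
    prompt: str,
    fine_tune_id: Optional[str],
    chunk_length: Optional[int],
    image_size: Tuple[int, ...],
    segment_index: int,
    object_detection_tool: Callable[..., Tuple[List[Dict[str, Any]], str]],
) -> List[Any]:
    # Run the detector every `chunk_length` frames, exactly as the removed
    # inline loop did for the whole video.
    if chunk_length is None:
        step = 1  # process every frame
    elif chunk_length <= 0:
        raise ValueError("chunk_length must be a positive integer or None.")
    else:
        step = chunk_length

    bboxes_per_frame: List[Optional[Dict[str, Any]]] = [None] * len(segment_frames)
    function_name = ""
    for frame_number in range(0, len(segment_frames), step):
        detections, function_name = object_detection_tool(
            od_model, prompt, segment_index, frame_number, fine_tune_id, segment_frames
        )
        bboxes_per_frame[frame_number] = {
            "labels": [d["label"] for d in detections],
            "bboxes": [denormalize_bbox(d["bbox"], image_size) for d in detections],
        }

    # Hand the sparse detections plus the raw frames to the hosted SAM2
    # tracker, as the pre-split code did for the entire video at once.
    payload = {
        "bboxes": json.dumps(bboxes_per_frame),
        "chunk_length_frames": chunk_length,
    }
    files = [("video", frames_to_bytes(segment_frames))]
    return send_task_inference_request(
        payload, "sam2", files=files, metadata={"function_name": function_name}
    )
```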
305 changes: 305 additions & 0 deletions vision_agent/utils/video_tracking.py (diff did not load on this page)
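
Since that file's diff did not load, here is one plausible shape for the merge step. A real merge_segments very likely also uses the shared overlap frame to map IDs belonging to the same object across segments (e.g. by mask IoU); this sketch only drops the duplicated frame and keeps IDs globally unique, so treat it as illustrative:

```python
from typing import Any, Dict, List


def merge_segments(
    detections_per_segment: List[List[List[Dict[str, Any]]]],
) -> List[List[Dict[str, Any]]]:
    # detections_per_segment[s][f] holds the tracked objects for frame f of
    # segment s. Consecutive segments share one overlap frame, so every
    # segment after the first contributes all frames except its first one.
    merged: List[List[Dict[str, Any]]] = []
    id_offset = 0
    for s, segment in enumerate(detections_per_segment):
        next_offset = id_offset
        for frame in segment:
            for det in frame:
                det["id"] += id_offset  # assumes integer IDs; keep them unique
                next_offset = max(next_offset, det["id"] + 1)
        merged.extend(segment if s == 0 else segment[1:])
        id_offset = next_offset
    return merged
```

post_process would then decode the RLE masks, attach bounding boxes, and run NMS, mirroring the post-detection half of the removed inline code.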

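For completeness, a call into the updated entry point might look like the following; the module paths are inferred from the file paths in this commit, and the frames are synthetic stand-ins for decoded video frames:

```python
import numpy as np

from vision_agent.tools.tools import od_sam2_video_tracking  # path per this commit
from vision_agent.utils.video_tracking import ODModels  # path per this commit

# 120 dummy RGB frames; in practice these come from a decoded video.
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(120)]

result = od_sam2_video_tracking(
    od_model=ODModels.COUNTGD,
    prompt="person",
    frames=frames,
    chunk_length=10,  # re-run detection every 10 frames within a segment
)

# One list of {"label", "mask", "score", ...} dicts per input frame.
per_frame_tracks = result["return_data"]
```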