Mirko Trasciatti
committed on
Commit
·
a9a341a
1
Parent(s):
dc9a383
Match reference Space: process frames 0→N sequentially, SAM2 handles bidirectional propagation internally
Browse files
app.py
CHANGED
|
@@ -492,64 +492,41 @@ def segment_video_multi(video_file, objects_json):
|
|
| 492 |
video_segments = {}
|
| 493 |
confidence_scores = []
|
| 494 |
|
| 495 |
-
print(f" Propagating masks
|
|
|
|
| 496 |
|
| 497 |
with torch.inference_mode():
|
| 498 |
-
#
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
for frame_idx in range(init_frame + 1, len(video_frames)):
|
| 530 |
-
frame_pil = video_frames[frame_idx]
|
| 531 |
-
pixel_values = None
|
| 532 |
-
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
|
| 533 |
-
pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
|
| 534 |
-
|
| 535 |
-
sam2_output = model(
|
| 536 |
-
inference_session=inference_session,
|
| 537 |
-
frame=pixel_values,
|
| 538 |
-
frame_idx=frame_idx
|
| 539 |
-
)
|
| 540 |
-
|
| 541 |
-
H = inference_session.video_height
|
| 542 |
-
W = inference_session.video_width
|
| 543 |
-
pred_masks = sam2_output.pred_masks.detach().cpu()
|
| 544 |
-
video_res_masks = processor.post_process_masks(
|
| 545 |
-
[pred_masks],
|
| 546 |
-
original_sizes=[[H, W]],
|
| 547 |
-
binarize=False
|
| 548 |
-
)[0]
|
| 549 |
-
|
| 550 |
-
video_segments[frame_idx] = video_res_masks
|
| 551 |
-
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
|
| 552 |
-
confidence_scores.append(float(mask_float.mean()))
|
| 553 |
|
| 554 |
print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
|
| 555 |
|
|
|
|
| 492 |
video_segments = {}
|
| 493 |
confidence_scores = []
|
| 494 |
|
| 495 |
+
print(f" Propagating masks through all frames (0 → {len(video_frames)-1})...")
|
| 496 |
+
print(f" Annotation at frame {init_frame} will guide propagation")
|
| 497 |
|
| 498 |
with torch.inference_mode():
|
| 499 |
+
# Process ALL frames in sequential order (0→N)
|
| 500 |
+
# SAM2's temporal model expects sequential processing
|
| 501 |
+
# The annotated frame (init_frame) is already in processed_frames
|
| 502 |
+
for frame_idx in range(len(video_frames)):
|
| 503 |
+
frame_pil = video_frames[frame_idx]
|
| 504 |
+
pixel_values = None
|
| 505 |
+
|
| 506 |
+
# Check if this frame was already processed (e.g., the annotated frame)
|
| 507 |
+
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
|
| 508 |
+
pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
|
| 509 |
+
|
| 510 |
+
# Call model - it will use annotation if frame_idx == init_frame
|
| 511 |
+
sam2_output = model(
|
| 512 |
+
inference_session=inference_session,
|
| 513 |
+
frame=pixel_values,
|
| 514 |
+
frame_idx=frame_idx
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
# Post-process masks
|
| 518 |
+
H = inference_session.video_height
|
| 519 |
+
W = inference_session.video_width
|
| 520 |
+
pred_masks = sam2_output.pred_masks.detach().cpu()
|
| 521 |
+
video_res_masks = processor.post_process_masks(
|
| 522 |
+
[pred_masks],
|
| 523 |
+
original_sizes=[[H, W]],
|
| 524 |
+
binarize=False
|
| 525 |
+
)[0]
|
| 526 |
+
|
| 527 |
+
video_segments[frame_idx] = video_res_masks
|
| 528 |
+
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
|
| 529 |
+
confidence_scores.append(float(mask_float.mean()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
|
| 532 |
|