Mirko Trasciatti committed on
Commit
a9a341a
·
1 Parent(s): dc9a383

Match reference Space: process frames 0→N sequentially, SAM2 handles bidirectional propagation internally

Browse files
Files changed (1) hide show
  1. app.py +33 -56
app.py CHANGED
@@ -492,64 +492,41 @@ def segment_video_multi(video_file, objects_json):
492
  video_segments = {}
493
  confidence_scores = []
494
 
495
- print(f" Propagating masks bidirectionally from frame {init_frame}...")
 
496
 
497
  with torch.inference_mode():
498
- # STEP 1: Process BACKWARD from init_frame to 0
499
- if init_frame > 0:
500
- print(f" Backward: frames {init_frame} → 0")
501
- for frame_idx in range(init_frame, -1, -1):
502
- frame_pil = video_frames[frame_idx]
503
- pixel_values = None
504
- if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
505
- pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
506
-
507
- sam2_output = model(
508
- inference_session=inference_session,
509
- frame=pixel_values,
510
- frame_idx=frame_idx
511
- )
512
-
513
- H = inference_session.video_height
514
- W = inference_session.video_width
515
- pred_masks = sam2_output.pred_masks.detach().cpu()
516
- video_res_masks = processor.post_process_masks(
517
- [pred_masks],
518
- original_sizes=[[H, W]],
519
- binarize=False
520
- )[0]
521
-
522
- video_segments[frame_idx] = video_res_masks
523
- mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
524
- confidence_scores.append(float(mask_float.mean()))
525
-
526
- # STEP 2: Process FORWARD from init_frame+1 to end
527
- if init_frame < len(video_frames) - 1:
528
- print(f" Forward: frames {init_frame+1} → {len(video_frames)-1}")
529
- for frame_idx in range(init_frame + 1, len(video_frames)):
530
- frame_pil = video_frames[frame_idx]
531
- pixel_values = None
532
- if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
533
- pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
534
-
535
- sam2_output = model(
536
- inference_session=inference_session,
537
- frame=pixel_values,
538
- frame_idx=frame_idx
539
- )
540
-
541
- H = inference_session.video_height
542
- W = inference_session.video_width
543
- pred_masks = sam2_output.pred_masks.detach().cpu()
544
- video_res_masks = processor.post_process_masks(
545
- [pred_masks],
546
- original_sizes=[[H, W]],
547
- binarize=False
548
- )[0]
549
-
550
- video_segments[frame_idx] = video_res_masks
551
- mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
552
- confidence_scores.append(float(mask_float.mean()))
553
 
554
  print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
555
 
 
492
  video_segments = {}
493
  confidence_scores = []
494
 
495
+ print(f" Propagating masks through all frames (0 → {len(video_frames)-1})...")
496
+ print(f" Annotation at frame {init_frame} will guide propagation")
497
 
498
  with torch.inference_mode():
499
+ # Process ALL frames in sequential order (0→N)
500
+ # SAM2's temporal model expects sequential processing
501
+ # The annotated frame (init_frame) is already in processed_frames
502
+ for frame_idx in range(len(video_frames)):
503
+ frame_pil = video_frames[frame_idx]
504
+ pixel_values = None
505
+
506
+ # Check if this frame was already processed (e.g., the annotated frame)
507
+ if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
508
+ pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
509
+
510
+ # Call model - it will use annotation if frame_idx == init_frame
511
+ sam2_output = model(
512
+ inference_session=inference_session,
513
+ frame=pixel_values,
514
+ frame_idx=frame_idx
515
+ )
516
+
517
+ # Post-process masks
518
+ H = inference_session.video_height
519
+ W = inference_session.video_width
520
+ pred_masks = sam2_output.pred_masks.detach().cpu()
521
+ video_res_masks = processor.post_process_masks(
522
+ [pred_masks],
523
+ original_sizes=[[H, W]],
524
+ binarize=False
525
+ )[0]
526
+
527
+ video_segments[frame_idx] = video_res_masks
528
+ mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
529
+ confidence_scores.append(float(mask_float.mean()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
  print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
532