Mirko Trasciatti committed
Commit cf42079 · 1 Parent(s): a9a341a

REVERT to propagate_in_video_iterator - it DOES support bidirectional propagation!

Files changed (1)
  1. app.py +15 -37
app.py CHANGED
@@ -486,47 +486,25 @@ def segment_video_multi(video_file, objects_json):
     # Skip initial model inference - go straight to propagation
     # The propagation loop will handle init_frame when it reaches it
 
-    # Propagate through ALL frames explicitly (frame-by-frame)
-    # This ensures bidirectional propagation from init_frame
-    # Based on: https://huggingface.co/spaces/yonigozlan/Segment-Anything-2-video-tracking
+    # Use propagate_in_video_iterator for BIDIRECTIONAL propagation
+    # According to SAM2 docs, this should propagate both forward AND backward
+    # from the annotated frame (init_frame)
     video_segments = {}
     confidence_scores = []
 
-    print(f"  Propagating masks through all frames (0 → {len(video_frames)-1})...")
-    print(f"  Annotation at frame {init_frame} will guide propagation")
-
-    with torch.inference_mode():
-        # Process ALL frames in sequential order (0 → N)
-        # SAM2's temporal model expects sequential processing
-        # The annotated frame (init_frame) is already in processed_frames
-        for frame_idx in range(len(video_frames)):
-            frame_pil = video_frames[frame_idx]
-            pixel_values = None
-
-            # Check if this frame was already processed (e.g., the annotated frame)
-            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
-                pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
-
-            # Call model - it will use the annotation if frame_idx == init_frame
-            sam2_output = model(
-                inference_session=inference_session,
-                frame=pixel_values,
-                frame_idx=frame_idx
-            )
-
-            # Post-process masks
-            H = inference_session.video_height
-            W = inference_session.video_width
-            pred_masks = sam2_output.pred_masks.detach().cpu()
-            video_res_masks = processor.post_process_masks(
-                [pred_masks],
-                original_sizes=[[H, W]],
-                binarize=False
-            )[0]
-
-            video_segments[frame_idx] = video_res_masks
-            mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
-            confidence_scores.append(float(mask_float.mean()))
+    print(f"  Using propagate_in_video_iterator for bidirectional propagation from frame {init_frame}...")
+
+    for sam2_output in model.propagate_in_video_iterator(inference_session):
+        video_res_masks = processor.post_process_masks(
+            [sam2_output.pred_masks],
+            original_sizes=[[inference_session.video_height, inference_session.video_width]],
+            binarize=False
+        )[0]
+        video_segments[sam2_output.frame_idx] = video_res_masks
+
+        # Calculate confidence
+        mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
+        confidence_scores.append(float(mask_float.mean()))
 
     print(f"  ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
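
For context, here is a minimal sketch of how the loop in the new code is typically driven end to end with the Hugging Face transformers SAM2 video API. Everything outside the diff above is an assumption: the checkpoint name, the frame-loading placeholder, the session setup via init_video_session / add_inputs_to_inference_session, and the click coordinates are based on the transformers documentation, not on this commit.

import glob
import torch
from PIL import Image
from transformers import Sam2VideoModel, Sam2VideoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam2VideoModel.from_pretrained("facebook/sam2.1-hiera-tiny").to(device)  # assumed checkpoint
processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")

# Placeholder frame source and annotated-frame index
video_frames = [Image.open(p) for p in sorted(glob.glob("frames/*.jpg"))]
init_frame = 0

inference_session = processor.init_video_session(video=video_frames, inference_device=device)

# Register a positive click on the target object at the annotated frame
# (placeholder coordinates; nesting is [batch][object][point][x, y])
processor.add_inputs_to_inference_session(
    inference_session=inference_session,
    frame_idx=init_frame,
    obj_ids=1,
    input_points=[[[[320, 240]]]],
    input_labels=[[[1]]],  # 1 = positive click
)

# Same consumption pattern as the new code in the diff
video_segments, confidence_scores = {}, []
for sam2_output in model.propagate_in_video_iterator(inference_session):
    video_res_masks = processor.post_process_masks(
        [sam2_output.pred_masks],
        original_sizes=[[inference_session.video_height, inference_session.video_width]],
        binarize=False,
    )[0]
    video_segments[sam2_output.frame_idx] = video_res_masks
    mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
    confidence_scores.append(float(mask_float.mean()))

Compared to the reverted frame-by-frame loop, the iterator keeps frame preprocessing and memory-bank bookkeeping inside the model, so the caller only consumes (frame_idx, pred_masks) pairs.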