Mirko Trasciatti
committed on
Commit
·
cf42079
1
Parent(s):
a9a341a
REVERT to propagate_in_video_iterator - it DOES support bidirectional propagation!
Browse files
app.py
CHANGED
|
@@ -486,47 +486,25 @@ def segment_video_multi(video_file, objects_json):
|
|
| 486 |
# Skip initial model inference - go straight to propagation
|
| 487 |
# The propagation loop will handle init_frame when it reaches it
|
| 488 |
|
| 489 |
-
#
|
| 490 |
-
#
|
| 491 |
-
#
|
| 492 |
video_segments = {}
|
| 493 |
confidence_scores = []
|
| 494 |
|
| 495 |
-
print(f"
|
| 496 |
-
print(f" Annotation at frame {init_frame} will guide propagation")
|
| 497 |
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
# Call model - it will use annotation if frame_idx == init_frame
|
| 511 |
-
sam2_output = model(
|
| 512 |
-
inference_session=inference_session,
|
| 513 |
-
frame=pixel_values,
|
| 514 |
-
frame_idx=frame_idx
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
# Post-process masks
|
| 518 |
-
H = inference_session.video_height
|
| 519 |
-
W = inference_session.video_width
|
| 520 |
-
pred_masks = sam2_output.pred_masks.detach().cpu()
|
| 521 |
-
video_res_masks = processor.post_process_masks(
|
| 522 |
-
[pred_masks],
|
| 523 |
-
original_sizes=[[H, W]],
|
| 524 |
-
binarize=False
|
| 525 |
-
)[0]
|
| 526 |
-
|
| 527 |
-
video_segments[frame_idx] = video_res_masks
|
| 528 |
-
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
|
| 529 |
-
confidence_scores.append(float(mask_float.mean()))
|
| 530 |
|
| 531 |
print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
|
| 532 |
|
|
|
|
| 486 |
# Skip initial model inference - go straight to propagation
|
| 487 |
# The propagation loop will handle init_frame when it reaches it
|
| 488 |
|
| 489 |
+
# Use propagate_in_video_iterator for BIDIRECTIONAL propagation
|
| 490 |
+
# According to SAM2 docs, this should propagate both forward AND backward
|
| 491 |
+
# from the annotated frame (init_frame)
|
| 492 |
video_segments = {}
|
| 493 |
confidence_scores = []
|
| 494 |
|
| 495 |
+
print(f" Using propagate_in_video_iterator for bidirectional propagation from frame {init_frame}...")
|
|
|
|
| 496 |
|
| 497 |
+
for sam2_output in model.propagate_in_video_iterator(inference_session):
|
| 498 |
+
video_res_masks = processor.post_process_masks(
|
| 499 |
+
[sam2_output.pred_masks],
|
| 500 |
+
original_sizes=[[inference_session.video_height, inference_session.video_width]],
|
| 501 |
+
binarize=False
|
| 502 |
+
)[0]
|
| 503 |
+
video_segments[sam2_output.frame_idx] = video_res_masks
|
| 504 |
+
|
| 505 |
+
# Calculate confidence
|
| 506 |
+
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
|
| 507 |
+
confidence_scores.append(float(mask_float.mean()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
|
| 510 |
|