yasserDahou committed on
Commit
9abbbe0
·
verified ·
1 Parent(s): 9b053e0

Update modeling_falcon_perception.py

Browse files
Files changed (1) hide show
  1. modeling_falcon_perception.py +30 -5
modeling_falcon_perception.py CHANGED
@@ -733,13 +733,38 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
733
  tokens_B1[should_stop_B, :] = self._pad_token_id
734
  padded_tokens[:, pos] = tokens_B1[:, -1]
735
 
736
- # Decode coords
737
  coord_logits = self.decode_coords(h_BSD[:, -1:], tokens_B1)
738
- xy_b2 = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
739
- coord_preds = [{"x": xy[0].item(), "y": xy[1].item()} for xy in xy_b2]
740
  sample_w_coord = torch.where(tokens_B1 == self.config.coord_token_id)[0]
 
 
 
 
 
 
741
  for i, b in enumerate(sample_w_coord.tolist()):
742
- aux_output_B[b].append(coord_preds[i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
 
744
  # Decode sizes
745
  size_logits = self.decode_sizes(h_BSD[:, -1:], tokens_B1)
@@ -847,4 +872,4 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
847
  "mask_rle": mask_rle,
848
  })
849
 
850
- return detections
 
733
  tokens_B1[should_stop_B, :] = self._pad_token_id
734
  padded_tokens[:, pos] = tokens_B1[:, -1]
735
 
736
+ # Decode coords (with deduplication to avoid repeating the same location)
737
  coord_logits = self.decode_coords(h_BSD[:, -1:], tokens_B1)
 
 
738
  sample_w_coord = torch.where(tokens_B1 == self.config.coord_token_id)[0]
739
+
740
+ num_bins = coord_logits.size(-1)
741
+ coord_repeat_threshold = 0.01 # coords within 1% of image size are considered duplicates
742
+ max_coord_attempts = 100
743
+ xy_b2 = torch.zeros(B, 2, device=device, dtype=self.dtype)
744
+
745
  for i, b in enumerate(sample_w_coord.tolist()):
746
+ logits_b = coord_logits[i].clone() # (2, num_bins)
747
+ existing_coords = [
748
+ item for item in aux_output_B[b]
749
+ if isinstance(item, dict) and "x" in item and "y" in item
750
+ ]
751
+ pred_x, pred_y = 0.0, 0.0
752
+ for _ in range(max_coord_attempts):
753
+ pred_bins = torch.argmax(logits_b, dim=-1) # (2,)
754
+ pred_x = pred_bins[0].item() / (num_bins - 1)
755
+ pred_y = pred_bins[1].item() / (num_bins - 1)
756
+ is_repeat = any(
757
+ abs(ec["x"] - pred_x) < coord_repeat_threshold
758
+ and abs(ec["y"] - pred_y) < coord_repeat_threshold
759
+ for ec in existing_coords
760
+ )
761
+ if not is_repeat:
762
+ break
763
+ logits_b[0, pred_bins[0]] = float("-inf")
764
+ logits_b[1, pred_bins[1]] = float("-inf")
765
+ xy_b2[b, 0] = pred_x
766
+ xy_b2[b, 1] = pred_y
767
+ aux_output_B[b].append({"x": pred_x, "y": pred_y})
768
 
769
  # Decode sizes
770
  size_logits = self.decode_sizes(h_BSD[:, -1:], tokens_B1)
 
872
  "mask_rle": mask_rle,
873
  })
874
 
875
+ return detections