lxzcpro committed
Commit a8246e3 · Parent: 144afae

Implement SAM2 and better inpainting

Files changed (4)
  1. src/matcher.py +44 -36
  2. src/painter.py +70 -1
  3. src/pipeline.py +43 -30
  4. src/segmenter.py +58 -1
src/matcher.py CHANGED
@@ -3,62 +3,70 @@ from PIL import Image
 from transformers import CLIPProcessor, CLIPModel
 
 class CLIPMatcher:
-    def __init__(self, model_name='openai/clip-vit-base-patch32'):
+    def __init__(self, model_name='openai/clip-vit-large-patch14'):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = CLIPModel.from_pretrained(model_name).to(self.device)
         self.processor = CLIPProcessor.from_pretrained(model_name)
-        self.model.eval()
 
-    def match_segments(self, image, segments, text_query):
-        if not segments:
-            return None
-
-        ignore_words = ['remove', 'delete', 'erase', 'the', 'a', 'an']
-        query_words = text_query.lower().split()
-        clean_query = " ".join([w for w in query_words if w not in ignore_words])
-
-        # If query becomes empty (e.g. user just typed "remove"), fallback to original
-        target_text = clean_query if clean_query else text_query
-
-        print(f"Debug: CLIP searching for object: '{target_text}'")
-
+    def get_top_k_segments(self, image, segments, text_query, k=5):
+        """
+        Returns top K segments based on CLIP score + Area Weight.
+        """
+        if not segments: return []
+
+        # 1. Clean Text
+        ignore = ['remove', 'delete', 'erase', 'the', 'a', 'an']
+        words = [w for w in text_query.lower().split() if w not in ignore]
+        clean_text = " ".join(words) if words else text_query
+
         pil_image = Image.fromarray(image)
-        best_score = -float('inf')
-        best_segment = None
-
         crops = []
         valid_segments = []
-
+
+        # Prepare crops
+        h, w = image.shape[:2]
+        total_img_area = h * w
+
         for seg in segments:
             x1, y1, x2, y2 = seg['bbox'].astype(int)
-            # Check bounds to prevent crash
-            h, w = image.shape[:2]
-            x1, y1 = max(0, x1), max(0, y1)
-            x2, y2 = min(w, x2), min(h, y2)
-
-            if x2 - x1 < 5 or y2 - y1 < 5: continue # Skip tiny/invalid boxes
+            # Pad slightly
+            pad = 10
+            x1, y1 = max(0, x1-pad), max(0, y1-pad)
+            x2, y2 = min(w, x2+pad), min(h, y2+pad)
 
             crops.append(pil_image.crop((x1, y1, x2, y2)))
             valid_segments.append(seg)
 
-        if not crops: return None
+        if not crops: return []
 
-        # Batch inference
+        # 2. Inference
         inputs = self.processor(
-            text=[target_text],
-            images=crops,
-            return_tensors="pt",
-            padding=True
+            text=[clean_text], images=crops, return_tensors="pt", padding=True
         ).to(self.device)
 
         with torch.no_grad():
             outputs = self.model(**inputs)
-            # logits_per_image: [num_crops, 1]
-            probs = outputs.logits_per_image.softmax(dim=0)
+            # Standardize scores
+            probs = outputs.logits_per_image.softmax(dim=0).cpu().numpy().flatten()
+
+        # 3. Re-Scoring with Area Weight
+        final_results = []
+        for i, score in enumerate(probs):
+            seg = valid_segments[i]
+            area_ratio = seg['area'] / total_img_area
+
+            # HEURISTIC: Boost score for larger objects.
+            # If searching for general terms (bus, car, cat), bigger is usually better.
+            # We add 20% of the area_ratio to the score.
+            weighted_score = score + (area_ratio * 0.2)
 
-        # Get the index of the highest match
-        best_idx = probs.argmax().item()
-        best_score = probs[best_idx].item()
-        best_segment = valid_segments[best_idx]
+            final_results.append({
+                'mask': seg['mask'],
+                'bbox': seg['bbox'],
+                'original_score': float(score),
+                'weighted_score': float(weighted_score)
+            })
 
-        return best_segment
+        # 4. Sort and take Top K
+        final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
+        return final_results[:k]
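For context, a minimal sketch of how the new matcher is expected to be driven (the import path, the dummy image, and the single hand-built segment are illustrative assumptions; real segments come from the segmenter and must carry 'mask', 'bbox' and 'area' keys):

    import numpy as np
    from src.matcher import CLIPMatcher   # assumed import path for this repo layout

    matcher = CLIPMatcher()
    image = np.zeros((480, 640, 3), dtype=np.uint8)      # placeholder RGB frame
    segments = [{
        'mask': np.zeros((480, 640), dtype=np.uint8),    # binary mask from the segmenter
        'bbox': np.array([100, 100, 300, 400]),          # [x1, y1, x2, y2]
        'area': 200 * 300,                               # pixel area of the mask
    }]

    candidates = matcher.get_top_k_segments(image, segments, "remove the cat", k=5)
    for cand in candidates:
        print(cand['original_score'], cand['weighted_score'])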
src/painter.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from PIL import Image
-from diffusers import StableDiffusionInpaintPipeline
+from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
 
 class SDInpainter:
     def __init__(self, model_id="runwayml/stable-diffusion-inpainting"):
@@ -48,6 +48,75 @@ class SDInpainter:
 
     def _dilate_mask(self, mask, kernel_size=9):
         # Increased kernel size slightly for better blending
+        import cv2
+        kernel = np.ones((kernel_size, kernel_size), np.uint8)
+        return cv2.dilate(mask, kernel, iterations=1)
+
+
+class SDXLInpainter:
+    def __init__(self, model_id="diffusers/stable-diffusion-xl-1.0-inpainting-0.1"):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # Use float16
+        self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+            variant="fp16",  # Add variant for faster loading if available
+            use_safetensors=True
+        ).to(self.device)
+
+        if self.device == "cuda":
+            self.pipe.enable_model_cpu_offload()  # Saves VRAM effectively
+
+    def inpaint(self, image, mask, prompt=""):  # Default prompt changed to empty
+        pil_image = Image.fromarray(image).convert('RGB')
+
+        # Increase kernel size to 15 or 20 to ensure no edge artifacts remain
+        mask = self._dilate_mask(mask, kernel_size=15)
+
+        # Blur the mask slightly to make the transition smoother
+        import cv2
+        mask = cv2.GaussianBlur(mask, (5, 5), 0)
+
+        pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
+
+        w, h = pil_image.size
+        target_size = 1024
+        scale = target_size / max(w, h)
+        new_w = int(w * scale) - (int(w * scale) % 8)
+        new_h = int(h * scale) - (int(h * scale) % 8)
+
+        resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
+        resized_mask = pil_mask.resize((new_w, new_h), Image.NEAREST)
+
+        if not prompt or prompt == "background":
+            final_prompt = "clean background, empty space, seamless texture, high quality"
+            # Lower guidance scale for background filling to rely more on image context
+            guidance_scale = 4.5
+        else:
+            final_prompt = prompt
+            guidance_scale = 7.5
+
+        neg_prompt = (
+            "object, subject, person, animal, cat, dog, "
+            "glass, transparent, crystal, bottle, cup, reflection, "
+            "complex, 3d render, artifacts, shadow, distortion, blur, watermark"
+        )
+
+        output = self.pipe(
+            prompt=final_prompt,
+            negative_prompt=neg_prompt,
+            image=resized_image,
+            mask_image=resized_mask,
+            num_inference_steps=40,
+            guidance_scale=guidance_scale,  # Dynamic guidance
+            strength=0.99,  # High strength to ensure removal
+        ).images[0]
+
+        result = output.resize((w, h), Image.LANCZOS)
+
+        return np.array(result)
+
+    def _dilate_mask(self, mask, kernel_size=15):
         import cv2
         kernel = np.ones((kernel_size, kernel_size), np.uint8)
         return cv2.dilate(mask, kernel, iterations=1)
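The SDXL path always works at a fixed internal resolution: the longer side is scaled to 1024 and both sides are snapped down to a multiple of 8 before diffusion, then the output is resized back to the original size. A small standalone sketch of that rounding (not part of the committed file, just the same arithmetic pulled out for illustration):

    def sdxl_working_size(w, h, target=1024):
        # Scale the longer side to `target`, then snap both sides down to a multiple of 8,
        # mirroring the arithmetic in SDXLInpainter.inpaint.
        scale = target / max(w, h)
        new_w = int(w * scale) - (int(w * scale) % 8)
        new_h = int(h * scale) - (int(h * scale) % 8)
        return new_w, new_h

    print(sdxl_working_size(2048, 1536))   # (1024, 768)
    print(sdxl_working_size(2048, 1080))   # (1024, 536) -- 540 rounded down to a multiple of 8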
src/pipeline.py CHANGED
@@ -1,45 +1,58 @@
 import numpy as np
-from .segmenter import YOLOSegmenter
+import cv2
+from .segmenter import SAM2Segmenter
 from .matcher import CLIPMatcher
-from .painter import SDInpainter
-from .utils import resize_image
+from .painter import SDXLInpainter
+from .utils import visualize_mask
 
 class ObjectRemovalPipeline:
     def __init__(self):
         print("Initializing models...")
-        self.segmenter = YOLOSegmenter()
+        self.segmenter = SAM2Segmenter()
         self.matcher = CLIPMatcher()
-        self.inpainter = SDInpainter()
-        print("Models loaded successfully!")
+        self.inpainter = SDXLInpainter()
+        print("Pipeline ready.")
 
-    def process(self, image, text_query, inpaint_prompt="background"):
+    def process(self, image, text_query, inpaint_prompt=""):
         """
-        Main pipeline for object removal
-        Args:
-            image: numpy array (H, W, 3)
-            text_query: str, e.g., "remove the bottle"
-            inpaint_prompt: str, prompt for inpainting
+        Main processing function for object removal.
         """
-        # Resize for processing
-        original_shape = image.shape[:2]
-        image = resize_image(image, max_size=1024)
-
-        # Step 1: Segment objects
+        # 1. Segment
        segments = self.segmenter.segment(image)
         if not segments:
-            return image, None, "No objects detected"
-
-        # Step 2: Match text query to segment
-        matched_segment = self.matcher.match_segments(image, segments, text_query)
-        if matched_segment is None:
-            return image, None, "No matching object found"
+            return image, None, "No segments found"
 
-        # Step 3: Inpaint to remove object
-        result = self.inpainter.inpaint(image, matched_segment['mask'], inpaint_prompt)
+        # 2. Match with Top-K Strategy
+        # We get top 5 candidates to handle "Part-Whole" ambiguity (e.g. tire vs car)
+        candidates = self.matcher.get_top_k_segments(image, segments, text_query, k=5)
+        if not candidates:
+            return image, None, "No match found"
+
+        # 3. Merge Masks (The "Cat Tail" Fix)
+        best_candidate = candidates[0]
+        final_mask = best_candidate['mask'].copy()
 
-        # Resize back if needed
-        if result.shape[:2] != original_shape:
-            import cv2
-            result = cv2.resize(result, (original_shape[1], original_shape[0]))
+        print(f"Top Match Score: {best_candidate['weighted_score']:.3f}")
+
+        # Merge other candidates if they are close in score or physically overlap
+        for i in range(1, len(candidates)):
+            cand = candidates[i]
+            score_ratio = cand['weighted_score'] / best_candidate['weighted_score']
+
+            # Check intersection
+            intersection = np.logical_and(final_mask, cand['mask']).sum()
+
+            # Rule: Merge if score is similar (>85%) OR if they overlap pixels
+            if score_ratio > 0.85 or intersection > 0:
+                print(f"Merging Rank {i+1} (Score ratio: {score_ratio:.2f}, Overlap: {intersection > 0})")
+                final_mask = np.logical_or(final_mask, cand['mask'])
+
+        # 4. Dilate Final Mask
+        # Expands mask slightly to cover edges/seams
+        kernel = np.ones((15, 15), np.uint8)
+        final_mask = cv2.dilate(final_mask.astype(np.uint8), kernel, iterations=1)
+
+        # 5. Inpaint
+        result = self.inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
 
-        return result, matched_segment['mask'], f"Removed: {matched_segment['class_name']}"
+        return result, final_mask, "Success"
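A minimal end-to-end usage sketch of the rewritten pipeline (the input path and the BGR-to-RGB handling are assumptions for illustration; the pipeline itself only expects an RGB numpy array and a text query):

    import cv2
    from src.pipeline import ObjectRemovalPipeline   # assumed import path

    pipeline = ObjectRemovalPipeline()               # loads SAM2, CLIP and SDXL inpainting
    bgr = cv2.imread("input.jpg")                    # hypothetical input file
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    result, mask, status = pipeline.process(rgb, "remove the cat")
    print(status)                                    # "Success", "No segments found" or "No match found"
    if mask is not None:
        cv2.imwrite("output.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))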
src/segmenter.py CHANGED
@@ -1,7 +1,10 @@
+import torch
 import numpy as np
-from ultralytics import YOLO
 import cv2
 
+from ultralytics import YOLO
+from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+
 class YOLOSegmenter:
     def __init__(self, model_name='yolov8x-seg.pt'):
         self.model = YOLO(model_name)
@@ -24,4 +27,58 @@ class YOLOSegmenter:
             'class_name': self.model.names[class_id]
         })
 
+        return segments
+
+class SAM2Segmenter:
+    def __init__(self, model_cfg='sam2.1_hiera_l.yaml', checkpoint=''):
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        # Load the Automatic Generator
+        self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
+            "facebook/sam2.1-hiera-large",
+            points_per_side=32,
+            pred_iou_thresh=0.80,
+            stability_score_thresh=0.92,
+            crop_n_layers=1,
+            crop_n_points_downscale_factor=2,
+            device=self.device
+        )
+
+    def segment(self, image):
+        """
+        Generates masks and filters out background-like huge segments.
+        """
+        if hasattr(self.mask_generator, 'generate'):
+            masks = self.mask_generator.generate(image)
+        else:
+            masks = self.mask_generator.predict(image)
+
+        segments = []
+        img_h, img_w = image.shape[:2]
+        total_area = img_h * img_w
+
+        for m in masks:
+            # SAM returns [x, y, w, h]
+            x, y, w, h = m['bbox']
+
+            # Convert to [x1, y1, x2, y2]
+            x1, y1, x2, y2 = x, y, x + w, y + h
+
+            # Ignore masks that are too large (> 75% of image)
+            if m['area'] > total_area * 0.75:
+                continue
+
+            # Ignore masks that are too small (< 0.5% of image)
+            if m['area'] < total_area * 0.005:
+                continue
+
+            segments.append({
+                'mask': m['segmentation'].astype(np.uint8),
+                'bbox': np.array([x1, y1, x2, y2]),
+                'score': m.get('predicted_iou', 1.0),
+                'area': m['area']
+            })
+
+        # Sort by area (smallest to largest) to prefer specific objects over containers
+        segments.sort(key=lambda s: s['area'])
+
         return segments
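To make the size filters concrete, the same thresholds worked out for a 1024×768 frame (illustrative arithmetic only, not part of the committed file):

    img_w, img_h = 1024, 768
    total_area = img_w * img_h        # 786,432 px

    max_area = total_area * 0.75      # 589,824 px: larger masks are treated as background and skipped
    min_area = total_area * 0.005     # 3,932.16 px: smaller masks are treated as noise and skipped

    print(total_area, max_area, min_area)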