lxzcpro committed on
Commit 03bafc0 · 1 Parent(s): cd002fc

implement second version

Files changed (7)
  1. .gitignore +2 -1
  2. app.py +74 -94
  3. src/__init__.py +3 -3
  4. src/matcher.py +51 -30
  5. src/painter.py +1 -1
  6. src/pipeline.py +90 -41
  7. src/segmenter.py +64 -72
.gitignore CHANGED
@@ -206,4 +206,5 @@ marimo/_static/
 marimo/_lsp/
 __marimo__/
 
-models/yolov8
+models/yolov8
+rubrics.txt
app.py CHANGED
@@ -1,129 +1,109 @@
 import gradio as gr
 import numpy as np
-import torch
 from src.pipeline import ObjectRemovalPipeline
 from src.utils import visualize_mask
 
-# Initialize pipeline globally to load models only once
-print("Loading pipeline...")
+# Initialize pipeline once
 pipeline = ObjectRemovalPipeline()
 
 def ensure_uint8(image):
-    """
-    Ensures the image is in valid uint8 format (0-255) for Gradio display.
-    """
-    if image is None:
-        return None
-
+    if image is None: return None
     image = np.array(image)
-
-    # 1. Handle NaN/Inf (Exploding gradients often cause this)
-    if not np.isfinite(image).all():
-        print("Warning: Image contains NaN or Inf. Replacing with black.")
-        image = np.nan_to_num(image, nan=0.0, posinf=255.0, neginf=0.0)
-
-    # 2. Normalize Float (0.0-1.0) to Int (0-255)
     if image.dtype != np.uint8:
-        # If image is in 0-1 range (common in torch/diffusers)
-        if image.max() <= 1.0:
-            image = (image * 255.0)
-
-        # Clip to safe range and cast
+        if image.max() <= 1.0: image = image * 255.0
         image = np.clip(image, 0, 255).astype(np.uint8)
-
     return image
 
-def remove_object(image, text_query, inpaint_prompt, progress=gr.Progress()):
-    """
-    Gradio wrapper with progress tracking and error handling.
-    """
-    if image is None:
-        return None, None, "Error: Please upload an image first."
+def step1_detect(image, text_query):
+    if image is None or not text_query:
+        return [], [], "Please upload image and enter text."
 
-    if not text_query:
-        return image, None, "Error: Please specify what to remove."
-
-    try:
-        # 1. Segmentation Phase
-        progress(0.2, desc="Segmenting & Matching Object...")
-
-        # Note: We call the pipeline.
-        # Ideally, you would break the pipeline.process method apart to update progress
-        # between segmentation and inpainting, but this works for now.
-        result, mask, message = pipeline.process(
-            image,
-            text_query,
-            inpaint_prompt if inpaint_prompt else "background"
-        )
+    # Calls the new method in pipeline.py
+    candidates, msg = pipeline.get_candidates(image, text_query)
+
+    if not candidates:
+        return [], [], f"Error: {msg}"
+
+    masks = [c['mask'] for c in candidates]
+
+    # Generate visualization for gallery
+    gallery_imgs = []
+    for i, mask in enumerate(masks):
+        viz = visualize_mask(image, mask)
+        # Label with rank and score if available
+        label = f"Option {i+1} (Score: {candidates[i].get('weighted_score', 0):.2f})"
+        gallery_imgs.append((ensure_uint8(viz), label))
 
-        # 2. Visualization Phase
-        progress(0.9, desc="Post-processing...")
-        mask_viz = None
-        if mask is not None:
-            mask_viz = visualize_mask(image, mask)
-        else:
-            # If no mask found, return original image as preview
-            mask_viz = image
-
-        mask_viz = ensure_uint8(mask_viz)
-        result = ensure_uint8(result)
+    return masks, gallery_imgs, "Select the best match below."
 
-        return result, mask_viz, message
+def on_select(evt: gr.SelectData):
+    return evt.index
 
-    except torch.cuda.OutOfMemoryError:
-        return None, None, "Error: GPU Out of Memory. Try a smaller image."
-    except Exception as e:
-        return None, None, f"Error: {str(e)}"
+def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
+    if not masks or selected_idx is None:
+        return None, "Please select an object first."
+
+    target_mask = masks[selected_idx]
+
+    # Calls the pipeline method
+    result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
+
+    return ensure_uint8(result), "Success!"
 
-# Define Custom CSS for a cleaner look (Optional)
+# CSS for cleaner UI
 css = """
-footer {visibility: hidden}
 .gradio-container {min-height: 0px !important}
+/* Ensure images in gallery don't get cropped strictly */
+button.gallery-item {object-fit: contain !important}
 """
 
-with gr.Blocks(title="Object Removal", css=css, theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## Text-Guided Object Removal Pipeline")
-    gr.Markdown("Identify objects via CLIP and remove them using Stable Diffusion.")
+with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
+    mask_state = gr.State([])
+    idx_state = gr.State(0)
+
+    gr.Markdown("## TextEraser: Interactive Object Removal")
 
     with gr.Row():
         with gr.Column(scale=1):
             input_image = gr.Image(label="Input Image", type="numpy", height=400)
-            text_query = gr.Textbox(
-                label="Target Object",
-                placeholder="e.g., 'bottle', 'cell', 'petri dish'",
-                info="What should be removed?"
-            )
-            inpaint_prompt = gr.Textbox(
-                label="Inpaint Prompt (Context)",
-                placeholder="background",
-                value="background",
-                info="What should fill the empty space?"
-            )
-            submit_btn = gr.Button("Run Pipeline", variant="primary")
+            text_query = gr.Textbox(label="What to remove?", placeholder="e.g. 'bottle', 'shadow'")
+            btn_detect = gr.Button("1. Detect Objects", variant="primary")
 
         with gr.Column(scale=1):
-            # Result tabs to switch between final result and debug mask
-            with gr.Tabs():
-                with gr.TabItem("Final Result"):
-                    output_image = gr.Image(label="Inpainted Result", height=400)
-                with gr.TabItem("Segmentation Debug"):
-                    mask_preview = gr.Image(label="Detected Mask Overlay", height=400)
+            # FIXED: object_fit="contain" prevents cropping
+            # allow_preview=True lets you click to zoom
+            gallery = gr.Gallery(
+                label="Candidates (Select One)",
+                columns=2,
+                height=400,
+                allow_preview=True,
+                object_fit="contain"
+            )
+    status = gr.Textbox(label="Status", interactive=False)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            shadow_slider = gr.Slider(0, 40, value=10, label="Shadow Fix (Expand Mask Downwards)")
+            inpaint_prompt = gr.Textbox(label="Background Description", value="background")
+            btn_remove = gr.Button("2. Remove Selected", variant="stop")
 
-    status_text = gr.Textbox(label="Pipeline Logs", interactive=False)
+        with gr.Column(scale=1):
+            output_image = gr.Image(label="Final Result", height=400)
+
+    # Event Wiring
+    btn_detect.click(
+        fn=step1_detect,
+        inputs=[input_image, text_query],
+        outputs=[mask_state, gallery, status]
+    )
 
-    # Examples allow users to test without uploading
-    # Ensure these files actually exist in your folder, or comment this out
-    # gr.Examples(
-    #     examples=[["examples/lab_bench.jpg", "remove the pipette", "table surface"]],
-    #     inputs=[input_image, text_query, inpaint_prompt],
-    # )
+    gallery.select(fn=on_select, inputs=None, outputs=idx_state)
 
-    submit_btn.click(
-        fn=remove_object,
-        inputs=[input_image, text_query, inpaint_prompt],
-        outputs=[output_image, mask_preview, status_text]
+    btn_remove.click(
+        fn=step2_remove,
+        inputs=[input_image, mask_state, idx_state, inpaint_prompt, shadow_slider],
+        outputs=[output_image, status]
     )
 
 if __name__ == "__main__":
-    # queue() is essential for handling GPU workloads and preventing timeouts
     demo.queue().launch(share=True)
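Note: the new app.py replaces the single "run" button with a two-step flow: step1_detect fills a gr.Gallery and stashes the raw masks in gr.State, gallery.select records the clicked tile via gr.SelectData.index, and step2_remove reads both back. A minimal, self-contained sketch of that wiring, using toy data and hypothetical names (detect, apply, items) rather than the repo's code, looks like this:

```python
# Minimal sketch of the detect -> select -> apply pattern used in app.py above.
# All names here are illustrative placeholders, not the repo's functions.
import gradio as gr
import numpy as np

def detect(n):
    items = list(range(int(n)))  # stand-ins for the real mask candidates
    tiles = [(np.full((64, 64, 3), 40 * (i + 1), dtype=np.uint8), f"Option {i + 1}")
             for i in items]
    return items, tiles, "Select one option below."

def on_select(evt: gr.SelectData):
    # SelectData.index is the position of the clicked gallery tile
    return evt.index

def apply(items, idx):
    return f"You picked option {items[idx] + 1} of {len(items)}."

with gr.Blocks() as demo:
    items_state = gr.State([])   # survives between the two button clicks
    idx_state = gr.State(0)
    n = gr.Slider(1, 5, value=3, step=1, label="How many candidates?")
    btn_detect = gr.Button("1. Detect")
    gallery = gr.Gallery(columns=2)
    status = gr.Textbox(label="Status")
    btn_apply = gr.Button("2. Apply")

    btn_detect.click(detect, inputs=n, outputs=[items_state, gallery, status])
    gallery.select(on_select, inputs=None, outputs=idx_state)
    btn_apply.click(apply, inputs=[items_state, idx_state], outputs=status)

if __name__ == "__main__":
    demo.queue().launch()
```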
src/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from .pipeline import ObjectRemovalPipeline
-from .segmenter import YOLOSegmenter
+from .segmenter import SAM2Predictor
 from .matcher import CLIPMatcher
-from .painter import SDInpainter
+from .painter import SDXLInpainter
 
-__all__ = ['ObjectRemovalPipeline', 'YOLOSegmenter', 'CLIPMatcher', 'SDInpainter']
+__all__ = ['ObjectRemovalPipeline', 'CLIPMatcher', 'SDXLInpainter', 'SAM2Predictor']
src/matcher.py CHANGED
@@ -1,17 +1,17 @@
 import torch
+import numpy as np
+import gc
 from PIL import Image
 from transformers import CLIPProcessor, CLIPModel
 
 class CLIPMatcher:
     def __init__(self, model_name='openai/clip-vit-large-patch14'):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
+        # Load directly to CPU first
+        self.model = CLIPModel.from_pretrained(model_name).to("cpu")
         self.processor = CLIPProcessor.from_pretrained(model_name)
 
     def get_top_k_segments(self, image, segments, text_query, k=5):
-        """
-        Returns top K segments based on CLIP score + Area Weight.
-        """
         if not segments: return []
 
         # 1. Clean Text
@@ -19,54 +19,75 @@ class CLIPMatcher:
         words = [w for w in text_query.lower().split() if w not in ignore]
         clean_text = " ".join(words) if words else text_query
 
+        # 2. Crop (CPU)
         pil_image = Image.fromarray(image)
         crops = []
        valid_segments = []
 
-        # Prepare crops
-        h, w = image.shape[:2]
-        total_img_area = h * w
+        h_img, w_img = image.shape[:2]
+        total_img_area = h_img * w_img
 
         for seg in segments:
-            x1, y1, x2, y2 = seg['bbox'].astype(int)
-            # Pad slightly
-            pad = 10
-            x1, y1 = max(0, x1-pad), max(0, y1-pad)
-            x2, y2 = min(w, x2+pad), min(h, y2+pad)
+            if 'bbox' not in seg: continue
 
-            crops.append(pil_image.crop((x1, y1, x2, y2)))
+            # Safe numpy cast
+            bbox = np.array(seg['bbox']).astype(int)
+            x1, y1, x2, y2 = bbox
+
+            # Adaptive Context Padding (30%)
+            w_box, h_box = x2 - x1, y2 - y1
+            pad_x = int(w_box * 0.3)
+            pad_y = int(h_box * 0.3)
+
+            crop_x1 = max(0, x1 - pad_x)
+            crop_y1 = max(0, y1 - pad_y)
+            crop_x2 = min(w_img, x2 + pad_x)
+            crop_y2 = min(h_img, y2 + pad_y)
+
+            crops.append(pil_image.crop((crop_x1, crop_y1, crop_x2, crop_y2)))
             valid_segments.append(seg)
 
         if not crops: return []
 
-        # 2. Inference
-        inputs = self.processor(
-            text=[clean_text], images=crops, return_tensors="pt", padding=True
-        ).to(self.device)
+        # 3. Inference (Brief GPU usage)
+        try:
+            self.model.to(self.device)
+            inputs = self.processor(
+                text=[clean_text], images=crops, return_tensors="pt", padding=True
+            ).to(self.device)
 
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-            # Standardize scores
-            probs = outputs.logits_per_image.softmax(dim=0).cpu().numpy().flatten()
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                # FIX: Use raw logits for meaningful scores.
+                # (Softmax forces sum=1, concealing bad matches)
+                probs = outputs.logits_per_image.cpu().numpy().flatten()
+        except Exception as e:
+            print(f"CLIP Error: {e}")
+            return []
+        finally:
+            # Move back to CPU immediately
+            self.model.to("cpu")
 
-        # 3. Re-Scoring with Area Weight
+        # 4. Score & Sort
         final_results = []
         for i, score in enumerate(probs):
             seg = valid_segments[i]
-            area_ratio = seg['area'] / total_img_area
+            if 'area' in seg:
+                area_ratio = seg['area'] / total_img_area
+            else:
+                w, h = seg['bbox'][2]-seg['bbox'][0], seg['bbox'][3]-seg['bbox'][1]
+                area_ratio = (w*h) / total_img_area
 
-            # HEURISTIC: Boost score for larger objects.
-            # If searching for general terms (bus, car, cat), bigger is usually better.
-            # We add 20% of the area_ratio to the score.
-            weighted_score = score + (area_ratio * 0.2)
+            # Logits are roughly 15-30 range. Add small boost for area.
+            weighted_score = float(score) + (area_ratio * 2.0)
 
             final_results.append({
-                'mask': seg['mask'],
+                'mask': seg.get('mask', None),
                 'bbox': seg['bbox'],
                 'original_score': float(score),
-                'weighted_score': float(weighted_score)
+                'weighted_score': weighted_score,
+                'label': seg.get('label', 'object')
            })
 
-        # 4. Sort and take Top K
         final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
         return final_results[:k]
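Note: the scoring change in get_top_k_segments (raw logits_per_image instead of a softmax over the candidate crops) is easiest to see with toy numbers. The sketch below uses made-up logits only; no CLIP model is loaded.

```python
# Toy numbers showing why softmax over candidates can hide a bad match.
# logits_per_image for one text prompt and N crops is an (N, 1) tensor of
# scaled cosine similarities; real CLIP matches tend to land roughly in the
# 15-30 range, which is what the area bonus above is calibrated against.
import torch

cases = {
    "one crop clearly matches": torch.tensor([[28.0], [14.0], [13.0]]),
    "no crop really matches":   torch.tensor([[16.0], [12.0], [11.0]]),
}

for name, logits in cases.items():
    probs = logits.softmax(dim=0).flatten()
    print(f"{name}: softmax={[round(float(p), 2) for p in probs]}, "
          f"raw={logits.flatten().tolist()}")

# Softmax gives the top crop ~1.0 in the first case and ~0.98 in the second,
# so "best of a bad lot" looks almost as confident as a genuine match.
# The raw logits (28 vs 16) keep the absolute difference.
```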
src/painter.py CHANGED
@@ -75,7 +75,7 @@ class SDXLInpainter:
 
         # Blur the mask slightly to make the transition smoother
         import cv2
-        mask = cv2.GaussianBlur(mask, (5, 5), 0)
+        mask = cv2.GaussianBlur(mask, (21, 21), 0)
 
         pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
 
src/pipeline.py CHANGED
@@ -1,58 +1,107 @@
 import numpy as np
 import cv2
-from .segmenter import SAM2Segmenter
+import torch
+import gc
+# Note: We import classes but DO NOT instantiate them globally
+from .segmenter import YOLOWorldDetector, SAM2Predictor
 from .matcher import CLIPMatcher
 from .painter import SDXLInpainter
-from .utils import visualize_mask
 
 class ObjectRemovalPipeline:
     def __init__(self):
-        print("Initializing models...")
-        self.segmenter = SAM2Segmenter()
-        self.matcher = CLIPMatcher()
-        self.inpainter = SDXLInpainter()
-        print("Pipeline ready.")
+        print("Initializing Pipeline in LOW MEMORY mode...")
+        # No models loaded at startup!
+        pass
 
-    def process(self, image, text_query, inpaint_prompt=""):
+    def _clear_ram(self):
+        """Helper to force clear RAM & VRAM"""
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_candidates(self, image, text_query):
         """
-        Main processing function for object removal.
+        Step 1: Detect & Segment & Rank
+        Strategy: Load one model at a time, use it, then delete it.
         """
-        # 1. Segment
-        segments = self.segmenter.segment(image)
-        if not segments:
-            return image, None, "No segments found"
+        candidates = []
+        box_candidates = []
 
-        # 2. Match with Top-K Strategy
-        # We get top 5 candidates to handle "Part-Whole" ambiguity (e.g. tire vs car)
-        candidates = self.matcher.get_top_k_segments(image, segments, text_query, k=5)
-        if not candidates:
-            return image, None, "No match found"
+        # --- PHASE 1: YOLO (Detect) ---
+        print("Loading YOLO...")
+        detector = YOLOWorldDetector()
+        try:
+            box_candidates = detector.detect(image, text_query)
+        finally:
+            del detector  # Delete model immediately
+            self._clear_ram()
 
-        # 3. Merge Masks (The "Cat Tail" Fix)
-        best_candidate = candidates[0]
-        final_mask = best_candidate['mask'].copy()
-
-        print(f"Top Match Score: {best_candidate['weighted_score']:.3f}")
+        if not box_candidates:
+            return [], "No objects detected."
 
-        # Merge other candidates if they are close in score or physically overlap
-        for i in range(1, len(candidates)):
-            cand = candidates[i]
-            score_ratio = cand['weighted_score'] / best_candidate['weighted_score']
-
-            # Check intersection
-            intersection = np.logical_and(final_mask, cand['mask']).sum()
+        # --- PHASE 2: SAM2 (Segment) ---
+        print("Loading SAM2...")
+        segmenter = SAM2Predictor()
+        segments_to_score = []
+        try:
+            segmenter.set_image(image)
+            # Process top 3 boxes -> up to 9 masks
+            for cand in box_candidates[:3]:
+                bbox = cand['bbox']
+                mask_variations = segmenter.predict_from_box(bbox)
+                for i, (mask, sam_score) in enumerate(mask_variations):
+                    segments_to_score.append({
+                        'mask': mask,
+                        'bbox': bbox,
+                        'area': mask.sum(),
+                        'label': f"{cand['label']} (Var {i+1})"
+                    })
+        finally:
+            # Critical cleanup for SAM2
+            if hasattr(segmenter, 'clear_memory'):
+                segmenter.clear_memory()
+            del segmenter
+            self._clear_ram()
+
+        # --- PHASE 3: CLIP (Rank) ---
+        print("Loading CLIP...")
+        matcher = CLIPMatcher()
+        ranked_candidates = []
+        try:
+            ranked_candidates = matcher.get_top_k_segments(
+                image,
+                segments_to_score,
+                text_query,
+                k=len(segments_to_score)
+            )
+        finally:
+            del matcher
+            self._clear_ram()
 
-            # Rule: Merge if score is similar (>85%) OR if they overlap pixels
-            if score_ratio > 0.85 or intersection > 0:
-                print(f"Merging Rank {i+1} (Score ratio: {score_ratio:.2f}, Overlap: {intersection > 0})")
-                final_mask = np.logical_or(final_mask, cand['mask'])
+        return ranked_candidates, f"Found {len(ranked_candidates)} options."
 
-        # 4. Dilate Final Mask
-        # Expands mask slightly to cover edges/seams
-        kernel = np.ones((15, 15), np.uint8)
-        final_mask = cv2.dilate(final_mask.astype(np.uint8), kernel, iterations=1)
+    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
+        """
+        Step 2: Inpaint
+        """
+        # Shadow / Edge Logic (CPU ops)
+        if shadow_expansion > 0:
+            kernel_h = int(shadow_expansion * 1.5)
+            kernel_w = int(shadow_expansion * 0.5)
+            kernel = np.ones((kernel_h, kernel_w), np.uint8)
+            selected_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)
 
-        # 5. Inpaint
-        result = self.inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
+        kernel = np.ones((10, 10), np.uint8)
+        final_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)
 
-        return result, final_mask, "Success"
+        result = None
+
+        # --- PHASE 4: SDXL (Inpaint) ---
+        print("Loading SDXL...")
+        inpainter = SDXLInpainter()
+        try:
+            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
+        finally:
+            del inpainter
+            self._clear_ram()
+
+        return result
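Note: every phase in the new pipeline follows the same load -> use -> delete -> empty_cache pattern via try/finally. The same idea can be written once as a context manager; the sketch below is only an illustration (not code from this repo), and build_model stands in for constructors like YOLOWorldDetector or SAM2Predictor.

```python
# Generic "one model in memory at a time" helper, sketching the pattern that
# get_candidates() and inpaint_selected() spell out with explicit try/finally.
import gc
from contextlib import contextmanager

import torch

@contextmanager
def scoped_model(build_model):
    """Instantiate a model for one phase and guarantee it is freed afterwards."""
    model = build_model()
    try:
        yield model
    finally:
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Usage sketch (names refer to the classes in this repo):
# with scoped_model(YOLOWorldDetector) as detector:
#     boxes = detector.detect(image, text_query)
# # detector is freed and VRAM released before the SAM2 phase starts
```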
src/segmenter.py CHANGED
@@ -1,84 +1,76 @@
 import torch
 import numpy as np
-import cv2
-
+import gc
 from ultralytics import YOLO
-from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-class YOLOSegmenter:
-    def __init__(self, model_name='yolov8x-seg.pt'):
+class YOLOWorldDetector:
+    def __init__(self, model_name='yolov8s-worldv2.pt'):
+        # Initialize, but manage device carefully
         self.model = YOLO(model_name)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    def segment(self, image):
-        """Return list of (mask, bbox, class_id) tuples"""
-        results = self.model(image)[0]
-        segments = []
+    def detect(self, image, text_query):
+        clean_text = text_query.replace("remove", "").replace("delete", "").strip()
+        if not clean_text: clean_text = "object"
 
-        if results.masks is not None:
-            for i, mask in enumerate(results.masks.data):
-                mask_np = mask.cpu().numpy()
-                mask_resized = cv2.resize(mask_np, (image.shape[1], image.shape[0]))
-                bbox = results.boxes.xyxy[i].cpu().numpy()
-                class_id = int(results.boxes.cls[i])
-                segments.append({
-                    'mask': (mask_resized > 0.5).astype(np.uint8),
-                    'bbox': bbox,
-                    'class_id': class_id,
-                    'class_name': self.model.names[class_id]
-                })
-
-        return segments
+        boxes = []
+        try:
+            # FIX: Force CPU for text encoding to prevent RuntimeError
+            self.model.to('cpu')
+            self.model.set_classes([clean_text])
+
+            if self.device == 'cuda':
+                self.model.to('cuda')
+
+            results = self.model.predict(image, conf=0.05, iou=0.5, verbose=False)[0]
+
+            if results.boxes:
+                for box in results.boxes.data:
+                    x1, y1, x2, y2 = box[:4].cpu().numpy()
+                    conf = float(box[4])
+                    boxes.append({
+                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                        'score': conf,
+                        'label': clean_text
+                    })
+        except Exception as e:
+            print(f"YOLO Error: {e}")
+        finally:
+            # Always offload after use
+            self.model.to('cpu')
+
+        boxes.sort(key=lambda x: x['score'], reverse=True)
+        return boxes
 
-class SAM2Segmenter:
-    def __init__(self, model_cfg='sam2.1_hiera_l.yaml', checkpoint=''):
+class SAM2Predictor:
+    def __init__(self, checkpoint="facebook/sam2.1-hiera-large"):
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        # Load the Automatic Generator
-        self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
-            "facebook/sam2.1-hiera-large",
-            points_per_side=32,
-            pred_iou_thresh=0.80,
-            stability_score_thresh=0.92,
-            crop_n_layers=1,
-            crop_n_points_downscale_factor=2,
-            device=self.device
-        )
+        try:
+            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint)
+        except:
+            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint, device='cpu')
 
-    def segment(self, image):
-        """
-        Generates masks and filters out background-like huge segments.
-        """
-        if hasattr(self.mask_generator, 'generate'):
-            masks = self.mask_generator.generate(image)
-        else:
-            masks = self.mask_generator.predict(image)
+    def set_image(self, image):
+        self.predictor.model.to(self.device)
+        self.predictor.set_image(image)
 
-        segments = []
-        img_h, img_w = image.shape[:2]
-        total_area = img_h * img_w
-
-        for m in masks:
-            # SAM returns [x, y, w, h]
-            x, y, w, h = m['bbox']
-
-            # Convert to [x1, y1, x2, y2]
-            x1, y1, x2, y2 = x, y, x + w, y + h
-
-            # Ignore masks that are too large (> 75% of image)
-            if m['area'] > total_area * 0.75:
-                continue
-
-            # Ignore masks that are too small (< 0.5% of image)
-            if m['area'] < total_area * 0.005:
-                continue
-
-            segments.append({
-                'mask': m['segmentation'].astype(np.uint8),
-                'bbox': np.array([x1, y1, x2, y2]),
-                'score': m.get('predicted_iou', 1.0),
-                'area': m['area']
-            })
-
-        # Sort by area (smallest to largest) to prefer specific objects over containers
-        segments.sort(key=lambda s: s['area'])
+    def predict_from_box(self, bbox):
+        box_input = np.array(bbox)[None, :]
+        # Multimask = True for variety
+        masks, scores, logits = self.predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=box_input,
+            multimask_output=True
+        )
+        sorted_results = sorted(zip(masks, scores), key=lambda x: x[1], reverse=True)
+        return [(m.astype(np.uint8), s) for m, s in sorted_results]
 
-        return segments
+    def clear_memory(self):
+        # Critical for preventing memory leaks
+        self.predictor.reset_predictor()
+        self.predictor.model.to('cpu')
+        del self.predictor
+        torch.cuda.empty_cache()
+        gc.collect()
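Note: for reference, a small usage sketch of the two new classes together: YOLO-World proposes boxes for a free-text query and SAM2 turns the top box into mask variants. It assumes the classes above are importable from src.segmenter and that a local test.jpg exists (a placeholder path); model weights are downloaded on first use.

```python
# Usage sketch only: open-vocabulary detection followed by box-prompted SAM2.
import cv2

from src.segmenter import YOLOWorldDetector, SAM2Predictor

image = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)  # RGB, as the pipeline expects

detector = YOLOWorldDetector()
boxes = detector.detect(image, "bottle")
print(f"{len(boxes)} boxes found")

if boxes:
    sam = SAM2Predictor()
    sam.set_image(image)
    variants = sam.predict_from_box(boxes[0]['bbox'])  # [(mask, score), ...] best first
    for i, (mask, score) in enumerate(variants):
        print(f"mask {i}: area={int(mask.sum())} px, sam score={float(score):.2f}")
    sam.clear_memory()
```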