implement second version
- .gitignore +2 -1
- app.py +74 -94
- src/__init__.py +3 -3
- src/matcher.py +51 -30
- src/painter.py +1 -1
- src/pipeline.py +90 -41
- src/segmenter.py +64 -72
.gitignore
CHANGED
@@ -206,4 +206,5 @@ marimo/_static/
 marimo/_lsp/
 __marimo__/
 
-models/yolov8
+models/yolov8
+rubrics.txt
app.py
CHANGED
@@ -1,129 +1,109 @@

import gradio as gr
import numpy as np
from src.pipeline import ObjectRemovalPipeline
from src.utils import visualize_mask

# Initialize pipeline once
pipeline = ObjectRemovalPipeline()

def ensure_uint8(image):
    if image is None: return None
    image = np.array(image)
    if image.dtype != np.uint8:
        if image.max() <= 1.0: image = image * 255.0
        image = np.clip(image, 0, 255).astype(np.uint8)
    return image

def step1_detect(image, text_query):
    if image is None or not text_query:
        return [], [], "Please upload image and enter text."

    # Calls the new method in pipeline.py
    candidates, msg = pipeline.get_candidates(image, text_query)

    if not candidates:
        return [], [], f"Error: {msg}"

    masks = [c['mask'] for c in candidates]

    # Generate visualization for gallery
    gallery_imgs = []
    for i, mask in enumerate(masks):
        viz = visualize_mask(image, mask)
        # Label with rank and score if available
        label = f"Option {i+1} (Score: {candidates[i].get('weighted_score', 0):.2f})"
        gallery_imgs.append((ensure_uint8(viz), label))

    return masks, gallery_imgs, "Select the best match below."

def on_select(evt: gr.SelectData):
    return evt.index

def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
    if not masks or selected_idx is None:
        return None, "Please select an object first."

    target_mask = masks[selected_idx]

    # Calls the pipeline method
    result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)

    return ensure_uint8(result), "Success!"

# CSS for cleaner UI
css = """
.gradio-container {min-height: 0px !important}
/* Ensure images in gallery don't get cropped strictly */
button.gallery-item {object-fit: contain !important}
"""

with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
    mask_state = gr.State([])
    idx_state = gr.State(0)

    gr.Markdown("## TextEraser: Interactive Object Removal")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy", height=400)
            text_query = gr.Textbox(label="What to remove?", placeholder="e.g. 'bottle', 'shadow'")
            btn_detect = gr.Button("1. Detect Objects", variant="primary")

        with gr.Column(scale=1):
            # FIXED: object_fit="contain" prevents cropping
            # allow_preview=True lets you click to zoom
            gallery = gr.Gallery(
                label="Candidates (Select One)",
                columns=2,
                height=400,
                allow_preview=True,
                object_fit="contain"
            )
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        with gr.Column(scale=1):
            shadow_slider = gr.Slider(0, 40, value=10, label="Shadow Fix (Expand Mask Downwards)")
            inpaint_prompt = gr.Textbox(label="Background Description", value="background")
            btn_remove = gr.Button("2. Remove Selected", variant="stop")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Final Result", height=400)

    # Event Wiring
    btn_detect.click(
        fn=step1_detect,
        inputs=[input_image, text_query],
        outputs=[mask_state, gallery, status]
    )

    gallery.select(fn=on_select, inputs=None, outputs=idx_state)

    btn_remove.click(
        fn=step2_remove,
        inputs=[input_image, mask_state, idx_state, inpaint_prompt, shadow_slider],
        outputs=[output_image, status]
    )

if __name__ == "__main__":
    demo.queue().launch(share=True)
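As an illustrative aside (not part of the commit), a minimal, self-contained sketch of the gallery-selection pattern the new app.py relies on, assuming only Gradio is installed and using placeholder image paths: the index delivered by gr.SelectData is parked in a gr.State, and a later button click reads it back.

import gradio as gr

with gr.Blocks() as demo:
    idx_state = gr.State(0)  # holds the most recently clicked gallery index
    gallery = gr.Gallery(value=["a.png", "b.png"], label="Pick one")  # placeholder paths
    picked = gr.Textbox(label="Picked index")
    use_btn = gr.Button("Use selection")

    def on_select(evt: gr.SelectData):
        return evt.index  # index of the clicked tile

    gallery.select(fn=on_select, inputs=None, outputs=idx_state)
    use_btn.click(fn=lambda i: str(i), inputs=idx_state, outputs=picked)

if __name__ == "__main__":
    demo.launch()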
src/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from .pipeline import ObjectRemovalPipeline
-from .segmenter import
+from .segmenter import SAM2Predictor
 from .matcher import CLIPMatcher
-from .painter import
+from .painter import SDXLInpainter
 
-__all__ = ['ObjectRemovalPipeline', '
+__all__ = ['ObjectRemovalPipeline', 'CLIPMatcher', 'SDXLInpainter', 'SAM2Predictor']
src/matcher.py
CHANGED
@@ -1,17 +1,17 @@

import torch
import numpy as np
import gc
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

class CLIPMatcher:
    def __init__(self, model_name='openai/clip-vit-large-patch14'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load directly to CPU first
        self.model = CLIPModel.from_pretrained(model_name).to("cpu")
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def get_top_k_segments(self, image, segments, text_query, k=5):
        if not segments: return []

        # 1. Clean Text

@@ -19,54 +19,75 @@ class CLIPMatcher:

        words = [w for w in text_query.lower().split() if w not in ignore]
        clean_text = " ".join(words) if words else text_query

        # 2. Crop (CPU)
        pil_image = Image.fromarray(image)
        crops = []
        valid_segments = []

        h_img, w_img = image.shape[:2]
        total_img_area = h_img * w_img

        for seg in segments:
            if 'bbox' not in seg: continue

            # Safe numpy cast
            bbox = np.array(seg['bbox']).astype(int)
            x1, y1, x2, y2 = bbox

            # Adaptive Context Padding (30%)
            w_box, h_box = x2 - x1, y2 - y1
            pad_x = int(w_box * 0.3)
            pad_y = int(h_box * 0.3)

            crop_x1 = max(0, x1 - pad_x)
            crop_y1 = max(0, y1 - pad_y)
            crop_x2 = min(w_img, x2 + pad_x)
            crop_y2 = min(h_img, y2 + pad_y)

            crops.append(pil_image.crop((crop_x1, crop_y1, crop_x2, crop_y2)))
            valid_segments.append(seg)

        if not crops: return []

        # 3. Inference (Brief GPU usage)
        try:
            self.model.to(self.device)
            inputs = self.processor(
                text=[clean_text], images=crops, return_tensors="pt", padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                # FIX: Use raw logits for meaningful scores.
                # (Softmax forces sum=1, concealing bad matches)
                probs = outputs.logits_per_image.cpu().numpy().flatten()
        except Exception as e:
            print(f"CLIP Error: {e}")
            return []
        finally:
            # Move back to CPU immediately
            self.model.to("cpu")

        # 4. Score & Sort
        final_results = []
        for i, score in enumerate(probs):
            seg = valid_segments[i]
            if 'area' in seg:
                area_ratio = seg['area'] / total_img_area
            else:
                w, h = seg['bbox'][2]-seg['bbox'][0], seg['bbox'][3]-seg['bbox'][1]
                area_ratio = (w*h) / total_img_area

            # Logits are roughly 15-30 range. Add small boost for area.
            weighted_score = float(score) + (area_ratio * 2.0)

            final_results.append({
                'mask': seg.get('mask', None),
                'bbox': seg['bbox'],
                'original_score': float(score),
                'weighted_score': weighted_score,
                'label': seg.get('label', 'object')
            })

        final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
        return final_results[:k]
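A quick illustration (numbers invented, not model output) of why the matcher now reads raw logits_per_image instead of softmax probabilities: with several mediocre crops, softmax across the candidates still pushes one of them toward 1.0, while the raw logits stay on CLIP's absolute scale (roughly 15-30, per the comment above), where a low value is visibly a weak match.

import torch

# Three candidate crops scored against one text query (invented values).
# All are weak matches on CLIP's absolute scale, but softmax across the
# candidates makes the first one look near-certain.
logits = torch.tensor([16.0, 12.0, 11.5])

print(logits.softmax(dim=0))  # ~[0.97, 0.02, 0.01] -> misleadingly confident
print(logits)                 # raw scores keep their absolute meaning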
src/painter.py
CHANGED
@@ -75,7 +75,7 @@ class SDXLInpainter:
 
         # Blur the mask slightly to make the transition smoother
         import cv2
-        mask = cv2.GaussianBlur(mask, (
+        mask = cv2.GaussianBlur(mask, (21, 21), 0)
 
         pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
 
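For context on the fixed line above, a standalone sketch of what the (21, 21) Gaussian blur does to a binary mask (synthetic data; only cv2 and numpy assumed). Kernel dimensions for cv2.GaussianBlur must be odd, and the blur feathers the hard 0/1 edge so the inpainted region blends into its surroundings.

import cv2
import numpy as np

mask = np.zeros((100, 100), dtype=np.float32)
mask[30:70, 30:70] = 1.0                    # hard-edged square mask

soft = cv2.GaussianBlur(mask, (21, 21), 0)  # kernel sizes must be odd

print(np.unique(mask).size)  # 2 distinct values: a hard edge
print(np.unique(soft).size)  # many intermediate values: a feathered edge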
src/pipeline.py
CHANGED
@@ -1,58 +1,107 @@

import numpy as np
import cv2
import torch
import gc
# Note: We import classes but DO NOT instantiate them globally
from .segmenter import YOLOWorldDetector, SAM2Predictor
from .matcher import CLIPMatcher
from .painter import SDXLInpainter

class ObjectRemovalPipeline:
    def __init__(self):
        print("Initializing Pipeline in LOW MEMORY mode...")
        # No models loaded at startup!
        pass

    def _clear_ram(self):
        """Helper to force clear RAM & VRAM"""
        gc.collect()
        torch.cuda.empty_cache()

    def get_candidates(self, image, text_query):
        """
        Step 1: Detect & Segment & Rank
        Strategy: Load one model at a time, use it, then delete it.
        """
        candidates = []
        box_candidates = []

        # --- PHASE 1: YOLO (Detect) ---
        print("Loading YOLO...")
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
            del detector  # Delete model immediately
            self._clear_ram()

        if not box_candidates:
            return [], "No objects detected."

        # --- PHASE 2: SAM2 (Segment) ---
        print("Loading SAM2...")
        segmenter = SAM2Predictor()
        segments_to_score = []
        try:
            segmenter.set_image(image)
            # Process top 3 boxes -> up to 9 masks
            for cand in box_candidates[:3]:
                bbox = cand['bbox']
                mask_variations = segmenter.predict_from_box(bbox)
                for i, (mask, sam_score) in enumerate(mask_variations):
                    segments_to_score.append({
                        'mask': mask,
                        'bbox': bbox,
                        'area': mask.sum(),
                        'label': f"{cand['label']} (Var {i+1})"
                    })
        finally:
            # Critical cleanup for SAM2
            if hasattr(segmenter, 'clear_memory'):
                segmenter.clear_memory()
            del segmenter
            self._clear_ram()

        # --- PHASE 3: CLIP (Rank) ---
        print("Loading CLIP...")
        matcher = CLIPMatcher()
        ranked_candidates = []
        try:
            ranked_candidates = matcher.get_top_k_segments(
                image,
                segments_to_score,
                text_query,
                k=len(segments_to_score)
            )
        finally:
            del matcher
            self._clear_ram()

        return ranked_candidates, f"Found {len(ranked_candidates)} options."

    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
        """
        Step 2: Inpaint
        """
        # Shadow / Edge Logic (CPU ops)
        if shadow_expansion > 0:
            kernel_h = int(shadow_expansion * 1.5)
            kernel_w = int(shadow_expansion * 0.5)
            kernel = np.ones((kernel_h, kernel_w), np.uint8)
            selected_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)

        kernel = np.ones((10, 10), np.uint8)
        final_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)

        result = None

        # --- PHASE 4: SDXL (Inpaint) ---
        print("Loading SDXL...")
        inpainter = SDXLInpainter()
        try:
            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
        finally:
            del inpainter
            self._clear_ram()

        return result
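A minimal usage sketch of the two new entry points outside Gradio (the image path is hypothetical; this mirrors what app.py does in step1_detect and step2_remove):

import numpy as np
from PIL import Image
from src.pipeline import ObjectRemovalPipeline

image = np.array(Image.open("examples/lab_bench.jpg").convert("RGB"))  # hypothetical path

pipeline = ObjectRemovalPipeline()

# Step 1: detect, segment and rank candidate masks for the text query
candidates, msg = pipeline.get_candidates(image, "bottle")
print(msg)

if candidates:
    # Step 2: inpaint the top-ranked mask
    result = pipeline.inpaint_selected(
        image,
        candidates[0]['mask'],
        inpaint_prompt="background",
        shadow_expansion=10,
    )
    # app.py passes the result through ensure_uint8() before display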
src/segmenter.py
CHANGED
@@ -1,84 +1,76 @@

import torch
import numpy as np
import gc
from ultralytics import YOLO
from sam2.sam2_image_predictor import SAM2ImagePredictor

class YOLOWorldDetector:
    def __init__(self, model_name='yolov8s-worldv2.pt'):
        # Initialize, but manage device carefully
        self.model = YOLO(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def detect(self, image, text_query):
        clean_text = text_query.replace("remove", "").replace("delete", "").strip()
        if not clean_text: clean_text = "object"

        boxes = []
        try:
            # FIX: Force CPU for text encoding to prevent RuntimeError
            self.model.to('cpu')
            self.model.set_classes([clean_text])

            if self.device == 'cuda':
                self.model.to('cuda')

            results = self.model.predict(image, conf=0.05, iou=0.5, verbose=False)[0]

            if results.boxes:
                for box in results.boxes.data:
                    x1, y1, x2, y2 = box[:4].cpu().numpy()
                    conf = float(box[4])
                    boxes.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'score': conf,
                        'label': clean_text
                    })
        except Exception as e:
            print(f"YOLO Error: {e}")
        finally:
            # Always offload after use
            self.model.to('cpu')

        boxes.sort(key=lambda x: x['score'], reverse=True)
        return boxes

class SAM2Predictor:
    def __init__(self, checkpoint="facebook/sam2.1-hiera-large"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint)
        except:
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint, device='cpu')

    def set_image(self, image):
        self.predictor.model.to(self.device)
        self.predictor.set_image(image)

    def predict_from_box(self, bbox):
        box_input = np.array(bbox)[None, :]
        # Multimask = True for variety
        masks, scores, logits = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_input,
            multimask_output=True
        )
        sorted_results = sorted(zip(masks, scores), key=lambda x: x[1], reverse=True)
        return [(m.astype(np.uint8), s) for m, s in sorted_results]

    def clear_memory(self):
        # Critical for preventing memory leaks
        self.predictor.reset_predictor()
        self.predictor.model.to('cpu')
        del self.predictor
        torch.cuda.empty_cache()
        gc.collect()