Implement SAM2 and better inpainting
- src/matcher.py +44 -36
- src/painter.py +70 -1
- src/pipeline.py +43 -30
- src/segmenter.py +58 -1
src/matcher.py
CHANGED
@@ -3,62 +3,70 @@ from PIL import Image

from transformers import CLIPProcessor, CLIPModel

class CLIPMatcher:
    def __init__(self, model_name='openai/clip-vit-large-patch14'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def get_top_k_segments(self, image, segments, text_query, k=5):
        """
        Returns top K segments based on CLIP score + Area Weight.
        """
        if not segments: return []

        # 1. Clean Text
        ignore = ['remove', 'delete', 'erase', 'the', 'a', 'an']
        words = [w for w in text_query.lower().split() if w not in ignore]
        clean_text = " ".join(words) if words else text_query

        pil_image = Image.fromarray(image)
        crops = []
        valid_segments = []

        # Prepare crops
        h, w = image.shape[:2]
        total_img_area = h * w

        for seg in segments:
            x1, y1, x2, y2 = seg['bbox'].astype(int)
            # Pad slightly
            pad = 10
            x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
            x2, y2 = min(w, x2 + pad), min(h, y2 + pad)

            crops.append(pil_image.crop((x1, y1, x2, y2)))
            valid_segments.append(seg)

        if not crops: return []

        # 2. Inference
        inputs = self.processor(
            text=[clean_text], images=crops, return_tensors="pt", padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Standardize scores
            probs = outputs.logits_per_image.softmax(dim=0).cpu().numpy().flatten()

        # 3. Re-Scoring with Area Weight
        final_results = []
        for i, score in enumerate(probs):
            seg = valid_segments[i]
            area_ratio = seg['area'] / total_img_area

            # HEURISTIC: Boost score for larger objects.
            # If searching for general terms (bus, car, cat), bigger is usually better.
            # We add 20% of the area_ratio to the score.
            weighted_score = score + (area_ratio * 0.2)

            final_results.append({
                'mask': seg['mask'],
                'bbox': seg['bbox'],
                'original_score': float(score),
                'weighted_score': float(weighted_score)
            })

        # 4. Sort and take Top K
        final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
        return final_results[:k]
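For reference, a minimal sketch of how the new get_top_k_segments method is driven. The segment dict below is a synthetic stand-in for segmenter output (it only carries the keys the matcher reads: 'mask', 'bbox', 'area'); the import path and image file name are assumptions, not part of this commit.

# Sketch only: ranking candidate segments with CLIPMatcher.
import numpy as np
from PIL import Image

from src.matcher import CLIPMatcher  # assumed module path for this repo layout

image = np.array(Image.open("example.jpg").convert("RGB"))  # placeholder input
h, w = image.shape[:2]

# One fake segment with the expected keys; real ones come from the segmenter.
mask = np.zeros((h, w), dtype=np.uint8)
mask[50:250, 50:250] = 1
segments = [{'mask': mask, 'bbox': np.array([50, 50, 250, 250]), 'area': int(mask.sum())}]

matcher = CLIPMatcher()
top = matcher.get_top_k_segments(image, segments, "remove the red car", k=5)
for rank, cand in enumerate(top, start=1):
    # weighted_score = CLIP softmax score + 0.2 * (segment area / image area)
    print(rank, round(cand['original_score'], 3), round(cand['weighted_score'], 3))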
src/painter.py
CHANGED
@@ -1,7 +1,7 @@

import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline

class SDInpainter:
    def __init__(self, model_id="runwayml/stable-diffusion-inpainting"):

@@ -48,6 +48,75 @@

    def _dilate_mask(self, mask, kernel_size=9):
        # Increased kernel size slightly for better blending
        import cv2
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.dilate(mask, kernel, iterations=1)


class SDXLInpainter:
    def __init__(self, model_id="diffusers/stable-diffusion-xl-1.0-inpainting-0.1"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float16
        self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            variant="fp16",  # Add variant for faster loading if available
            use_safetensors=True
        ).to(self.device)

        if self.device == "cuda":
            self.pipe.enable_model_cpu_offload()  # Saves VRAM effectively

    def inpaint(self, image, mask, prompt=""):  # Default prompt changed to empty
        pil_image = Image.fromarray(image).convert('RGB')

        # Increase kernel size to 15 or 20 to ensure no edge artifacts remain
        mask = self._dilate_mask(mask, kernel_size=15)

        # Blur the mask slightly to make the transition smoother
        import cv2
        mask = cv2.GaussianBlur(mask, (5, 5), 0)

        pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')

        w, h = pil_image.size
        target_size = 1024
        scale = target_size / max(w, h)
        new_w = int(w * scale) - (int(w * scale) % 8)
        new_h = int(h * scale) - (int(h * scale) % 8)

        resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
        resized_mask = pil_mask.resize((new_w, new_h), Image.NEAREST)

        if not prompt or prompt == "background":
            final_prompt = "clean background, empty space, seamless texture, high quality"
            # Lower guidance scale for background filling to rely more on image context
            guidance_scale = 4.5
        else:
            final_prompt = prompt
            guidance_scale = 7.5

        neg_prompt = (
            "object, subject, person, animal, cat, dog, "
            "glass, transparent, crystal, bottle, cup, reflection, "
            "complex, 3d render, artifacts, shadow, distortion, blur, watermark"
        )

        output = self.pipe(
            prompt=final_prompt,
            negative_prompt=neg_prompt,
            image=resized_image,
            mask_image=resized_mask,
            num_inference_steps=40,
            guidance_scale=guidance_scale,  # Dynamic guidance
            strength=0.99,  # High strength to ensure removal
        ).images[0]

        result = output.resize((w, h), Image.LANCZOS)

        return np.array(result)

    def _dilate_mask(self, mask, kernel_size=15):
        import cv2
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.dilate(mask, kernel, iterations=1)
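A minimal sketch of using SDXLInpainter on its own, assuming an RGB uint8 image and a same-size binary mask (1 marks the region to remove), which is what the pipeline hands it; the import path and file names are placeholders.

# Sketch only: standalone SDXL inpainting call.
import numpy as np
from PIL import Image

from src.painter import SDXLInpainter  # assumed module path

image = np.array(Image.open("street.jpg").convert("RGB"))  # placeholder input
mask = np.zeros(image.shape[:2], dtype=np.uint8)
mask[100:300, 200:400] = 1                                  # region to erase

painter = SDXLInpainter()
result = painter.inpaint(image, mask)                        # empty prompt -> generic background fill
Image.fromarray(result).save("street_clean.jpg")

Internally the method dilates and blurs the mask, resizes the longer image side to roughly 1024 px (rounded down to a multiple of 8), and drops the guidance scale to 4.5 when no prompt is given so the fill leans on surrounding image context rather than the text prompt.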
src/pipeline.py
CHANGED
@@ -1,45 +1,58 @@

import numpy as np
import cv2
from .segmenter import SAM2Segmenter
from .matcher import CLIPMatcher
from .painter import SDXLInpainter
from .utils import visualize_mask

class ObjectRemovalPipeline:
    def __init__(self):
        print("Initializing models...")
        self.segmenter = SAM2Segmenter()
        self.matcher = CLIPMatcher()
        self.inpainter = SDXLInpainter()
        print("Pipeline ready.")

    def process(self, image, text_query, inpaint_prompt=""):
        """
        Main processing function for object removal.
        """
        # 1. Segment
        segments = self.segmenter.segment(image)
        if not segments:
            return image, None, "No segments found"

        # 2. Match with Top-K Strategy
        # We get top 5 candidates to handle "Part-Whole" ambiguity (e.g. tire vs car)
        candidates = self.matcher.get_top_k_segments(image, segments, text_query, k=5)
        if not candidates:
            return image, None, "No match found"

        # 3. Merge Masks (The "Cat Tail" Fix)
        best_candidate = candidates[0]
        final_mask = best_candidate['mask'].copy()

        print(f"Top Match Score: {best_candidate['weighted_score']:.3f}")

        # Merge other candidates if they are close in score or physically overlap
        for i in range(1, len(candidates)):
            cand = candidates[i]
            score_ratio = cand['weighted_score'] / best_candidate['weighted_score']

            # Check intersection
            intersection = np.logical_and(final_mask, cand['mask']).sum()

            # Rule: Merge if score is similar (>85%) OR if they overlap pixels
            if score_ratio > 0.85 or intersection > 0:
                print(f"Merging Rank {i+1} (Score ratio: {score_ratio:.2f}, Overlap: {intersection > 0})")
                final_mask = np.logical_or(final_mask, cand['mask'])

        # 4. Dilate Final Mask
        # Expands mask slightly to cover edges/seams
        kernel = np.ones((15, 15), np.uint8)
        final_mask = cv2.dilate(final_mask.astype(np.uint8), kernel, iterations=1)

        # 5. Inpaint
        result = self.inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)

        return result, final_mask, "Success"
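A sketch of the end-to-end flow the pipeline now implements (segment with SAM2, rank with CLIP, merge masks, inpaint with SDXL); file names and the import path are placeholders.

# Sketch only: text-driven object removal, end to end.
import numpy as np
from PIL import Image

from src.pipeline import ObjectRemovalPipeline  # assumed module path

pipeline = ObjectRemovalPipeline()              # loads SAM2, CLIP and SDXL once

image = np.array(Image.open("living_room.jpg").convert("RGB"))
result, mask, status = pipeline.process(image, "remove the cat")

print(status)                                    # "Success", "No segments found", or "No match found"
if mask is not None:
    Image.fromarray(result).save("living_room_clean.jpg")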
src/segmenter.py
CHANGED
@@ -1,7 +1,10 @@

import torch
import numpy as np
import cv2

from ultralytics import YOLO
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

class YOLOSegmenter:
    def __init__(self, model_name='yolov8x-seg.pt'):
        self.model = YOLO(model_name)

@@ -24,4 +27,58 @@

                'class_name': self.model.names[class_id]
            })

        return segments

class SAM2Segmenter:
    def __init__(self, model_cfg='sam2.1_hiera_l.yaml', checkpoint=''):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Load the Automatic Generator
        self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
            "facebook/sam2.1-hiera-large",
            points_per_side=32,
            pred_iou_thresh=0.80,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            device=self.device
        )

    def segment(self, image):
        """
        Generates masks and filters out background-like huge segments.
        """
        if hasattr(self.mask_generator, 'generate'):
            masks = self.mask_generator.generate(image)
        else:
            masks = self.mask_generator.predict(image)

        segments = []
        img_h, img_w = image.shape[:2]
        total_area = img_h * img_w

        for m in masks:
            # SAM returns [x, y, w, h]
            x, y, w, h = m['bbox']

            # Convert to [x1, y1, x2, y2]
            x1, y1, x2, y2 = x, y, x + w, y + h

            # Ignore masks that are too large (> 75% of image)
            if m['area'] > total_area * 0.75:
                continue

            # Ignore masks that are too small (< 0.5% of image)
            if m['area'] < total_area * 0.005:
                continue

            segments.append({
                'mask': m['segmentation'].astype(np.uint8),
                'bbox': np.array([x1, y1, x2, y2]),
                'score': m.get('predicted_iou', 1.0),
                'area': m['area']
            })

        # Sort by area (smallest to largest) to prefer specific objects over containers
        segments.sort(key=lambda s: s['area'])

        return segments
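For a sense of what the area filters keep: on a 1024x768 input (786,432 px), only masks between roughly 3,932 px (0.5%) and 589,824 px (75%) survive. Below is a small sketch for inspecting that on a real image; the import path and file name are placeholders.

# Sketch only: inspecting SAM2Segmenter output after the size filters.
import numpy as np
from PIL import Image

from src.segmenter import SAM2Segmenter  # assumed module path

image = np.array(Image.open("kitchen.jpg").convert("RGB"))  # placeholder input
segments = SAM2Segmenter().segment(image)

total = image.shape[0] * image.shape[1]
for seg in segments[:10]:                # smallest first, since segments are sorted by area
    print(f"area={seg['area']} ({seg['area'] / total:.1%}), bbox={seg['bbox']}")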