import torch
import numpy as np
import gc
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


class CLIPMatcher:
    """Ranks image segments against a text query with CLIP and returns the top-k matches."""

    def __init__(self, model_name='openai/clip-vit-large-patch14'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load to CPU first; the model is moved to the GPU only for inference.
        self.model = CLIPModel.from_pretrained(model_name).to("cpu")
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def get_top_k_segments(self, image, segments, text_query, k=5):
        if not segments:
            return []

        # 1. Clean text: drop filler words so CLIP only sees the object description.
        ignore = ['remove', 'delete', 'erase', 'the', 'a', 'an']
        words = [w for w in text_query.lower().split() if w not in ignore]
        clean_text = " ".join(words) if words else text_query

        # 2. Crop segment regions (CPU)
        pil_image = Image.fromarray(image)
        crops = []
        valid_segments = []
        h_img, w_img = image.shape[:2]
        total_img_area = h_img * w_img

        for seg in segments:
            if 'bbox' not in seg:
                continue
            # Safe numpy cast
            bbox = np.array(seg['bbox']).astype(int)
            x1, y1, x2, y2 = bbox
            # Adaptive context padding (30%), clamped to the image bounds
            w_box, h_box = x2 - x1, y2 - y1
            pad_x = int(w_box * 0.3)
            pad_y = int(h_box * 0.3)
            crop_x1 = max(0, x1 - pad_x)
            crop_y1 = max(0, y1 - pad_y)
            crop_x2 = min(w_img, x2 + pad_x)
            crop_y2 = min(h_img, y2 + pad_y)
            crops.append(pil_image.crop((crop_x1, crop_y1, crop_x2, crop_y2)))
            valid_segments.append(seg)

        if not crops:
            return []

        # 3. Inference (brief GPU usage)
        try:
            self.model.to(self.device)
            inputs = self.processor(
                text=[clean_text], images=crops, return_tensors="pt", padding=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # FIX: Use raw logits for meaningful scores.
            # (Softmax forces the scores to sum to 1, concealing bad matches.)
            scores = outputs.logits_per_image.cpu().numpy().flatten()
        except Exception as e:
            print(f"CLIP Error: {e}")
            return []
        finally:
            # Move the model back to the CPU immediately to free GPU memory
            self.model.to("cpu")

        # 4. Score & sort
        final_results = []
        for i, score in enumerate(scores):
            seg = valid_segments[i]
            if 'area' in seg:
                area_ratio = seg['area'] / total_img_area
            else:
                w = seg['bbox'][2] - seg['bbox'][0]
                h = seg['bbox'][3] - seg['bbox'][1]
                area_ratio = (w * h) / total_img_area
            # Logits typically land in the 15-30 range; add a small boost for larger segments.
            weighted_score = float(score) + (area_ratio * 2.0)
            final_results.append({
                'mask': seg.get('mask', None),
                'bbox': seg['bbox'],
                'original_score': float(score),
                'weighted_score': weighted_score,
                'label': seg.get('label', 'object')
            })

        final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
        return final_results[:k]
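

# Usage sketch: a minimal, hypothetical example of how CLIPMatcher might be called.
# The image path, the segment dictionaries, and the query are placeholders; in
# practice `segments` would come from a segmentation model and carry a 'bbox'
# as [x1, y1, x2, y2], plus optional 'mask', 'area', and 'label' keys.
if __name__ == "__main__":
    matcher = CLIPMatcher()

    # Hypothetical inputs for illustration only.
    image = np.array(Image.open("example.jpg").convert("RGB"))
    segments = [
        {'bbox': [50, 40, 210, 180], 'area': 160 * 140, 'label': 'object'},
        {'bbox': [300, 120, 420, 260], 'label': 'object'},
    ]

    results = matcher.get_top_k_segments(image, segments, "remove the red car", k=3)
    for r in results:
        print(r['label'], round(r['original_score'], 2), round(r['weighted_score'], 2))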