# TextEraser / src/segmenter.py
# Latest commit 9de67ae by lxzcpro: "code clean up"
import gc
import re

import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ultralytics import YOLO
class YOLOWorldDetector:
    """Text-prompted bounding-box detector built on an Ultralytics YOLO model.

    The model is moved back to the CPU at the end of every ``detect`` call and
    is only placed on the GPU (when one is available) while predicting.
    """

    # Command verbs users put in front of the thing they want erased
    # ("Remove the dog"). Matched as whole words and in any letter case so
    # that capitalised prompts are cleaned too and words that merely contain
    # "remove"/"delete" are left intact.
    _COMMAND_WORDS = re.compile(r"\b(?:remove|delete)\b", re.IGNORECASE)

    def __init__(self, model_name='yolov8s-worldv2.pt'):
        """Load the detector weights and pick the inference device.

        Args:
            model_name: Weights name/path handed to ``ultralytics.YOLO``.
        """
        self.model = YOLO(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @classmethod
    def _clean_query(cls, text_query):
        """Strip command verbs from a prompt, leaving only the object phrase.

        Returns ``"object"`` when nothing is left (empty, ``None`` or a prompt
        consisting solely of command verbs).
        """
        without_verbs = cls._COMMAND_WORDS.sub(" ", text_query or "")
        # split/join trims the ends and collapses the gaps left by removed verbs.
        return " ".join(without_verbs.split()) or "object"

    def detect(self, image, text_query, *, conf=0.05, iou=0.5):
        """Detect regions of ``image`` matching the object named in ``text_query``.

        Args:
            image: Image passed straight through to ``model.predict``.
            text_query: Free-form prompt such as ``"remove the dog"``; the
                verbs "remove"/"delete" are stripped before it is used as the
                single class name.
            conf: Confidence threshold forwarded to ``model.predict``.
            iou: IoU threshold forwarded to ``model.predict``.

        Returns:
            A list of ``{'bbox': [x1, y1, x2, y2], 'score': float,
            'label': str}`` dicts, highest score first. Coordinates are
            truncated to ints. Any error during detection is printed and an
            empty list is returned instead of raising.
        """
        clean_text = self._clean_query(text_query)
        boxes = []
        try:
            # The class list is set while the model sits on the CPU and the
            # model is moved to the GPU only afterwards.
            # NOTE(review): presumably this avoids a device mismatch inside
            # set_classes - confirm against the ultralytics version in use.
            self.model.to('cpu')
            self.model.set_classes([clean_text])
            if self.device == 'cuda':
                self.model.to('cuda')

            results = self.model.predict(image, conf=conf, iou=iou, verbose=False)[0]
            if results.boxes:
                # One device-to-host transfer for all boxes instead of one per box.
                rows = results.boxes.data.cpu().numpy()
                # Columns 0-3 are the box corners and column 4 the confidence.
                boxes = [
                    {
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'score': float(score),
                        'label': clean_text,
                    }
                    for x1, y1, x2, y2, score in rows[:, :5]
                ]
        except Exception as e:
            # Best-effort: a failed detection yields "no boxes", not a crash.
            print(f"YOLO Error: {e}")
        finally:
            # Always hand the GPU back, whether prediction succeeded or not.
            self.model.to('cpu')

        boxes.sort(key=lambda x: x['score'], reverse=True)
        return boxes
class SAM2Predictor:
    """Thin wrapper around ``SAM2ImagePredictor`` for box-prompted masks."""

    def __init__(self, checkpoint="facebook/sam2.1-hiera-large"):
        """Load the SAM2 predictor, falling back to a CPU load on failure.

        Args:
            checkpoint: Model identifier passed to
                ``SAM2ImagePredictor.from_pretrained``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint)
        except Exception as e:
            # `Exception` rather than a bare `except` so that Ctrl-C and
            # SystemExit still propagate; the reason is reported, not hidden.
            # NOTE(review): presumably the default load targets CUDA and this
            # covers machines without a GPU - confirm against the sam2 version.
            print(f"SAM2 Error: {e} (retrying on CPU)")
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint, device='cpu')

    def set_image(self, image):
        """Move the model to the inference device and register ``image``.

        Args:
            image: Image forwarded unchanged to ``predictor.set_image``.
        """
        self.predictor.model.to(self.device)
        self.predictor.set_image(image)

    def predict_from_box(self, bbox):
        """Predict candidate masks for one box on the current image.

        Args:
            bbox: Sequence of box coordinates; a leading axis is added before
                it is passed as the ``box`` prompt (presumably
                ``[x1, y1, x2, y2]`` - verify against the caller).

        Returns:
            A list of ``(mask, score)`` tuples sorted by descending score,
            with each mask cast to ``uint8``.
        """
        box_input = np.array(bbox)[None, :]
        # The third return value is not needed here.
        masks, scores, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_input,
            multimask_output=True,
        )
        ranked = sorted(zip(masks, scores), key=lambda pair: pair[1], reverse=True)
        return [(mask.astype(np.uint8), score) for mask, score in ranked]

    def clear_memory(self):
        """Release the predictor and any cached GPU memory.

        Safe to call repeatedly, or when the predictor was never loaded. After
        this call the instance no longer has a ``predictor`` attribute.
        """
        # No local alias is kept, so the collector below can reclaim the model.
        if getattr(self, 'predictor', None) is not None:
            self.predictor.reset_predictor()
            self.predictor.model.to('cpu')
            del self.predictor
        # Collect first so that memory freed by the collector is back in the
        # CUDA caching allocator before the cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()