Spaces:
Running
Running
| import torch | |
| import numpy as np | |
| import gc | |
| from ultralytics import YOLO | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
class YOLOWorldDetector:
    """Open-vocabulary box detector backed by an Ultralytics YOLO-World model.

    ``detect`` turns a free-text instruction into a single class prompt, runs
    the model (on the GPU when one is available) and always moves the model
    back to the CPU afterwards.
    """

    # Command words stripped from user instructions before prompting the model.
    _COMMAND_WORDS = ("remove", "delete")

    def __init__(self, model_name='yolov8s-worldv2.pt'):
        self.model = YOLO(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @staticmethod
    def _clean_query(text_query):
        """Turn a user instruction into a class prompt.

        Removes the whole words "remove"/"delete" (case-insensitively, so
        "Remove" is handled and "removed" is left intact), collapses runs of
        whitespace, and falls back to "object" when nothing is left or the
        query is empty/None.
        """
        import re

        pattern = r"\b(?:%s)\b" % "|".join(
            re.escape(word) for word in YOLOWorldDetector._COMMAND_WORDS
        )
        cleaned = re.sub(pattern, " ", text_query or "", flags=re.IGNORECASE)
        cleaned = " ".join(cleaned.split())
        return cleaned or "object"

    def detect(self, image, text_query, conf=0.05, iou=0.5):
        """Detect instances of the object named in ``text_query``.

        Args:
            image: Passed unchanged to ``self.model.predict``.
            text_query: Free-text instruction such as "remove the red car".
            conf: Confidence threshold forwarded to ``predict`` (default 0.05).
            iou: IoU threshold forwarded to ``predict`` (default 0.5).

        Returns:
            A list of ``{'bbox': [x1, y1, x2, y2], 'score': float,
            'label': str}`` dicts sorted by descending score. Errors during
            prediction are printed, not raised. The list built so far (possibly
            empty) is returned.
        """
        label = self._clean_query(text_query)
        boxes = []
        try:
            # set_classes is run while the model sits on the CPU.
            # NOTE(review): presumably this avoids a device mismatch in the
            # text encoder; confirm against the ultralytics version in use.
            self.model.to('cpu')
            self.model.set_classes([label])
            if self.device == 'cuda':
                self.model.to('cuda')
            result = self.model.predict(image, conf=conf, iou=iou, verbose=False)[0]
            detected = result.boxes
            if detected is not None and len(detected):
                # Named accessors instead of raw column offsets into
                # ``boxes.data``, whose column layout is not fixed.
                coords = detected.xyxy.cpu().tolist()
                scores = detected.conf.cpu().tolist()
                for (x1, y1, x2, y2), score in zip(coords, scores):
                    boxes.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'score': float(score),
                        'label': label,
                    })
        except Exception as e:
            print(f"YOLO Error: {e}")
        finally:
            # Always park the model on the CPU so it holds no GPU memory
            # between calls.
            self.model.to('cpu')
        boxes.sort(key=lambda b: b['score'], reverse=True)
        return boxes
class SAM2Predictor:
    """Box-prompted segmentation wrapper around ``SAM2ImagePredictor``.

    ``set_image`` moves the model to ``self.device`` before encoding.
    ``clear_memory`` moves it back to the CPU and drops the predictor. A later
    ``set_image`` call reloads it from the remembered checkpoint.
    """

    def __init__(self, checkpoint="facebook/sam2.1-hiera-large"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Kept so the predictor can be rebuilt after clear_memory().
        self.checkpoint = checkpoint
        self.predictor = self._load_predictor(checkpoint)

    @staticmethod
    def _load_predictor(checkpoint):
        """Load ``checkpoint`` with the library's default device, else on CPU."""
        try:
            return SAM2ImagePredictor.from_pretrained(checkpoint)
        except Exception:
            # NOTE(review): presumably the default load can fail on CPU-only
            # hosts, so retry with an explicit CPU device. A genuine
            # download/checkpoint error is raised again by this call.
            # (Narrowed from a bare ``except:`` so Ctrl-C is not swallowed.)
            return SAM2ImagePredictor.from_pretrained(checkpoint, device='cpu')

    def set_image(self, image):
        """Encode ``image`` for later box prompts, reloading the model if released."""
        if getattr(self, 'predictor', None) is None:
            self.predictor = self._load_predictor(self.checkpoint)
        self.predictor.model.to(self.device)
        self.predictor.set_image(image)

    def predict_from_box(self, bbox):
        """Segment the region prompted by ``bbox``.

        ``bbox`` is presumably ``[x1, y1, x2, y2]`` pixel coordinates; verify
        against callers. Returns a list of ``(mask, score)`` pairs, with each
        mask cast to uint8, sorted by descending score.

        Raises:
            RuntimeError: if the predictor was released by ``clear_memory`` and
                ``set_image`` has not been called since.
        """
        if getattr(self, 'predictor', None) is None:
            raise RuntimeError("SAM2 predictor was released; call set_image() first.")
        box_input = np.array(bbox)[None, :]
        masks, scores, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_input,
            multimask_output=True,
        )
        ranked = sorted(zip(masks, scores), key=lambda pair: pair[1], reverse=True)
        return [(mask.astype(np.uint8), score) for mask, score in ranked]

    def clear_memory(self):
        """Release the predictor and free cached GPU memory. Safe to call repeatedly."""
        predictor = getattr(self, 'predictor', None)
        if predictor is not None:
            predictor.reset_predictor()
            predictor.model.to('cpu')
        # Drop the reference but keep the attribute, so later calls can detect
        # the released state instead of failing with AttributeError.
        self.predictor = None
        del predictor
        # Collect first so tensors kept alive only by reference cycles are
        # actually freed before the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()