File size: 2,524 Bytes
a8246e3
144afae
03bafc0
a8246e3
03bafc0
a8246e3
03bafc0
 
9de67ae
144afae
03bafc0
144afae
03bafc0
 
 
144afae
03bafc0
 
9de67ae
03bafc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de67ae
03bafc0
 
 
 
a8246e3
03bafc0
 
a8246e3
03bafc0
 
 
 
a8246e3
03bafc0
 
 
a8246e3
03bafc0
 
9de67ae
03bafc0
 
 
 
 
 
 
 
a8246e3
03bafc0
9de67ae
03bafc0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import numpy as np
import gc
from ultralytics import YOLO
from sam2.sam2_image_predictor import SAM2ImagePredictor

class YOLOWorldDetector:
    def __init__(self, model_name='yolov8s-worldv2.pt'):

        self.model = YOLO(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    def detect(self, image, text_query):
        clean_text = text_query.replace("remove", "").replace("delete", "").strip()
        if not clean_text: clean_text = "object"
        
        boxes = []
        try:

            self.model.to('cpu')
            self.model.set_classes([clean_text])
            
            if self.device == 'cuda':
                self.model.to('cuda')
                
            results = self.model.predict(image, conf=0.05, iou=0.5, verbose=False)[0]
            
            if results.boxes:
                for box in results.boxes.data:
                    x1, y1, x2, y2 = box[:4].cpu().numpy()
                    conf = float(box[4])
                    boxes.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'score': conf,
                        'label': clean_text
                    })
        except Exception as e:
            print(f"YOLO Error: {e}")
        finally:

            self.model.to('cpu')
            
        boxes.sort(key=lambda x: x['score'], reverse=True)
        return boxes

class SAM2Predictor:
    def __init__(self, checkpoint="facebook/sam2.1-hiera-large"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint)
        except:
            self.predictor = SAM2ImagePredictor.from_pretrained(checkpoint, device='cpu')

    def set_image(self, image):
        self.predictor.model.to(self.device)
        self.predictor.set_image(image)

    def predict_from_box(self, bbox):
        box_input = np.array(bbox)[None, :]

        masks, scores, logits = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_input,
            multimask_output=True 
        )
        sorted_results = sorted(zip(masks, scores), key=lambda x: x[1], reverse=True)
        return [(m.astype(np.uint8), s) for m, s in sorted_results]
        
    def clear_memory(self):

        self.predictor.reset_predictor()
        self.predictor.model.to('cpu')
        del self.predictor
        torch.cuda.empty_cache()
        gc.collect()