File size: 3,555 Bytes
144afae
a8246e3
03bafc0
 
 
 
144afae
a8246e3
144afae
 
 
03bafc0
 
 
144afae
03bafc0
 
 
 
 
 
144afae
03bafc0
 
144afae
03bafc0
 
144afae
03bafc0
 
 
 
 
 
 
 
a8246e3
03bafc0
 
a8246e3
03bafc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8246e3
03bafc0
a8246e3
03bafc0
 
 
 
 
 
 
 
 
 
a8246e3
03bafc0
 
144afae
03bafc0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import cv2
import torch
import gc
# Note: We import classes but DO NOT instantiate them globally
from .segmenter import YOLOWorldDetector, SAM2Predictor
from .matcher import CLIPMatcher
from .painter import SDXLInpainter

class ObjectRemovalPipeline:
    """Detect, segment, rank and inpaint objects, holding one model at a time.

    Every phase constructs its own model, uses it, then deletes it and calls
    ``_clear_ram`` before the next phase constructs the next model.
    """

    # Number of top detection boxes handed to the segmenter.
    _MAX_BOXES = 3
    # Side length of the square kernel that always grows the final mask.
    _EDGE_KERNEL_SIZE = 10

    def __init__(self):
        print("Initializing Pipeline in LOW MEMORY mode...")
        # No models are loaded at startup; each phase loads its own on demand.

    def _clear_ram(self):
        """Run the garbage collector and release PyTorch's cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _shadow_kernel(shadow_expansion):
        """Build the tall, narrow dilation kernel used for shadow expansion.

        The kernel is 1.5x ``shadow_expansion`` high and 0.5x wide, truncated
        to integers. Each side is clamped to at least 1: without the clamp a
        small expansion (e.g. 1 -> width ``int(0.5) == 0``) would produce a
        zero-sized array instead of a usable structuring element.

        Args:
            shadow_expansion: Positive number controlling the kernel size.

        Returns:
            A ``uint8`` array of ones with shape ``(height, width)``.
        """
        kernel_h = max(1, int(shadow_expansion * 1.5))
        kernel_w = max(1, int(shadow_expansion * 0.5))
        return np.ones((kernel_h, kernel_w), np.uint8)

    def get_candidates(self, image, text_query):
        """Step 1: detect, segment and rank candidates for ``text_query``.

        Models are used strictly one after another (detector, segmenter,
        matcher); each is deleted and memory is cleared in a ``finally``
        block, so cleanup also happens when a phase raises.

        Args:
            image: Image passed unchanged to the detector, segmenter and
                matcher.
            text_query: Text describing the target; passed to the detector
                and the matcher.

        Returns:
            A ``(candidates, message)`` tuple. ``candidates`` is the result
            of ``CLIPMatcher.get_top_k_segments`` over all produced masks, or
            an empty list when nothing was detected or segmented.
            ``message`` is a short human-readable status string.
        """
        # --- PHASE 1: YOLO (Detect) ---
        print("Loading YOLO...")
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
            del detector  # Drop the model before the next one is loaded
            self._clear_ram()

        if not box_candidates:
            return [], "No objects detected."

        # --- PHASE 2: SAM2 (Segment) ---
        print("Loading SAM2...")
        segmenter = SAM2Predictor()
        segments_to_score = []
        try:
            segmenter.set_image(image)
            # Only the first few boxes are segmented; each box may yield
            # several mask variations (the predictor's per-mask score is
            # not used here).
            for cand in box_candidates[:self._MAX_BOXES]:
                bbox = cand['bbox']
                mask_variations = segmenter.predict_from_box(bbox)
                for i, (mask, _) in enumerate(mask_variations):
                    segments_to_score.append({
                        'mask': mask,
                        'bbox': bbox,
                        'area': mask.sum(),
                        'label': f"{cand['label']} (Var {i+1})"
                    })
        finally:
            # Critical cleanup for SAM2
            if hasattr(segmenter, 'clear_memory'):
                segmenter.clear_memory()
            del segmenter
            self._clear_ram()

        if not segments_to_score:
            # Nothing to rank: skip loading CLIP entirely. The message is the
            # one the ranking path produces for an empty result.
            return [], "Found 0 options."

        # --- PHASE 3: CLIP (Rank) ---
        print("Loading CLIP...")
        matcher = CLIPMatcher()
        try:
            # k equals the number of segments, so every segment is ranked.
            ranked_candidates = matcher.get_top_k_segments(
                image,
                segments_to_score,
                text_query,
                k=len(segments_to_score)
            )
        finally:
            del matcher
            self._clear_ram()

        return ranked_candidates, f"Found {len(ranked_candidates)} options."

    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
        """Step 2: grow the selected mask and inpaint the masked region.

        Args:
            image: Image passed unchanged to the inpainter.
            selected_mask: Array-like mask supporting ``astype``; it is
                converted to ``uint8`` before dilation.
            inpaint_prompt: Prompt forwarded to the inpainter.
            shadow_expansion: When greater than 0, the mask is first dilated
                with a tall, narrow kernel (see ``_shadow_kernel``).

        Returns:
            Whatever ``SDXLInpainter.inpaint`` returns for the grown mask.
        """
        # Shadow / edge logic (CPU ops)
        mask = selected_mask.astype(np.uint8)
        if shadow_expansion > 0:
            mask = cv2.dilate(mask, self._shadow_kernel(shadow_expansion), iterations=1)

        # Always grow the mask a little more with a square kernel.
        edge_kernel = np.ones(
            (self._EDGE_KERNEL_SIZE, self._EDGE_KERNEL_SIZE), np.uint8
        )
        final_mask = cv2.dilate(mask, edge_kernel, iterations=1)

        # --- PHASE 4: SDXL (Inpaint) ---
        print("Loading SDXL...")
        inpainter = SDXLInpainter()
        try:
            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
        finally:
            del inpainter
            self._clear_ram()

        return result