lxzcpro commited on
Commit
c3f0641
·
1 Parent(s): 9de67ae

Initial deployment of TextEraser

Browse files
Files changed (3) hide show
  1. .gitignore +5 -1
  2. app.py +36 -9
  3. requirements.txt +4 -1
.gitignore CHANGED
@@ -207,4 +207,8 @@ marimo/_lsp/
207
  __marimo__/
208
 
209
  models/yolov8
210
- rubrics.txt
 
 
 
 
 
207
  __marimo__/
208
 
209
  models/yolov8
210
+ rubrics.txt
211
+ yolov8s-worldv2.pt
212
+ yolov8x-seg.pt
213
+ .gradio
214
+ notebook
app.py CHANGED
@@ -1,12 +1,27 @@
1
  import gradio as gr
2
  import numpy as np
 
 
3
  from src.pipeline import ObjectRemovalPipeline
4
  from src.utils import visualize_mask
5
 
6
-
 
 
 
 
 
 
 
 
 
 
 
 
7
  pipeline = ObjectRemovalPipeline()
8
 
9
  def ensure_uint8(image):
 
10
  if image is None: return None
11
  image = np.array(image)
12
  if image.dtype != np.uint8:
@@ -14,11 +29,13 @@ def ensure_uint8(image):
14
  image = np.clip(image, 0, 255).astype(np.uint8)
15
  return image
16
 
 
17
  def step1_detect(image, text_query):
 
18
  if image is None or not text_query:
19
  return [], [], "Please upload image and enter text."
20
 
21
-
22
  candidates, msg = pipeline.get_candidates(image, text_query)
23
 
24
  if not candidates:
@@ -26,36 +43,41 @@ def step1_detect(image, text_query):
26
 
27
  masks = [c['mask'] for c in candidates]
28
 
29
-
30
  gallery_imgs = []
31
  for i, mask in enumerate(masks):
32
  viz = visualize_mask(image, mask)
33
-
34
- label = f"Option {i+1} (Score: {candidates[i].get('weighted_score', 0):.2f})"
35
  gallery_imgs.append((ensure_uint8(viz), label))
36
 
37
  return masks, gallery_imgs, "Select the best match below."
38
 
39
  def on_select(evt: gr.SelectData):
 
40
  return evt.index
41
 
 
42
  def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
 
43
  if not masks or selected_idx is None:
44
  return None, "Please select an object first."
45
 
46
  target_mask = masks[selected_idx]
47
 
48
-
49
  result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
50
 
51
  return ensure_uint8(result), "Success!"
52
 
 
53
  css = """
54
  .gradio-container {min-height: 0px !important}
55
- button.gallery-item {object-fit: contain !important}
56
  """
57
 
58
  with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
 
59
  mask_state = gr.State([])
60
  idx_state = gr.State(0)
61
 
@@ -68,7 +90,7 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
68
  btn_detect = gr.Button("1. Detect Objects", variant="primary")
69
 
70
  with gr.Column(scale=1):
71
-
72
  gallery = gr.Gallery(
73
  label="Candidates (Select One)",
74
  columns=2,
@@ -103,4 +125,9 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
103
  )
104
 
105
  if __name__ == "__main__":
106
- demo.queue().launch(share=True)
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
+ import argparse
4
+ import os
5
  from src.pipeline import ObjectRemovalPipeline
6
  from src.utils import visualize_mask
7
 
8
+ # --- ZeroGPU Compatibility Shim ---
9
+ # Allows code to run on local CPU/GPU without crashing on 'import spaces'
10
+ try:
11
+ import spaces
12
+ except ImportError:
13
+ class spaces:
14
+ @staticmethod
15
+ def GPU(duration=120):
16
+ def decorator(func):
17
+ return func
18
+ return decorator
19
+
20
+ # Initialize pipeline (Models use lazy-loading to save memory)
21
  pipeline = ObjectRemovalPipeline()
22
 
23
  def ensure_uint8(image):
24
+ """Normalize image to uint8 (0-255)"""
25
  if image is None: return None
26
  image = np.array(image)
27
  if image.dtype != np.uint8:
 
29
  image = np.clip(image, 0, 255).astype(np.uint8)
30
  return image
31
 
32
+ @spaces.GPU(duration=120)
33
  def step1_detect(image, text_query):
34
+ """Detect objects and return candidates for user selection"""
35
  if image is None or not text_query:
36
  return [], [], "Please upload image and enter text."
37
 
38
+ # 1. Detect & Rank candidates via Pipeline
39
  candidates, msg = pipeline.get_candidates(image, text_query)
40
 
41
  if not candidates:
 
43
 
44
  masks = [c['mask'] for c in candidates]
45
 
46
+ # 2. Visualize masks for Gallery
47
  gallery_imgs = []
48
  for i, mask in enumerate(masks):
49
  viz = visualize_mask(image, mask)
50
+ score = candidates[i].get('weighted_score', 0)
51
+ label = f"Option {i+1} (Score: {score:.2f})"
52
  gallery_imgs.append((ensure_uint8(viz), label))
53
 
54
  return masks, gallery_imgs, "Select the best match below."
55
 
56
  def on_select(evt: gr.SelectData):
57
+ """Capture user selection from Gallery"""
58
  return evt.index
59
 
60
+ @spaces.GPU(duration=120)
61
  def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
62
+ """Inpaint the selected mask"""
63
  if not masks or selected_idx is None:
64
  return None, "Please select an object first."
65
 
66
  target_mask = masks[selected_idx]
67
 
68
+ # 3. Inpaint with Shadow Fix logic
69
  result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
70
 
71
  return ensure_uint8(result), "Success!"
72
 
73
+ # CSS for better layout and full image visibility in Gallery
74
  css = """
75
  .gradio-container {min-height: 0px !important}
76
+ button.gallery-item {object-fit: contain !important}
77
  """
78
 
79
  with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
80
+ # State to hold masks between steps
81
  mask_state = gr.State([])
82
  idx_state = gr.State(0)
83
 
 
90
  btn_detect = gr.Button("1. Detect Objects", variant="primary")
91
 
92
  with gr.Column(scale=1):
93
+ # Interactive Gallery (Adaptable size)
94
  gallery = gr.Gallery(
95
  label="Candidates (Select One)",
96
  columns=2,
 
125
  )
126
 
127
  if __name__ == "__main__":
128
+ parser = argparse.ArgumentParser()
129
+ parser.add_argument("--share", action="store_true", help="Create a public link (Colab)")
130
+ args = parser.parse_args()
131
+
132
+ # queue() is required for ZeroGPU
133
+ demo.queue().launch(share=args.share)
requirements.txt CHANGED
@@ -16,4 +16,7 @@ gradio
16
  PyYAML
17
  filelock
18
  Pillow
19
- sniffio
 
 
 
 
16
  PyYAML
17
  filelock
18
  Pillow
19
+ sniffio
20
+ spaces
21
+ clip
22
+ git+https://github.com/facebookresearch/sam2.git