Spaces:
Running
Running
Commit
·
ce088ab
1
Parent(s):
c7ab35b
improve app design
Browse files- app.py +10 -5
- predict.py +13 -8
app.py
CHANGED
|
@@ -16,11 +16,16 @@ with demo:
|
|
| 16 |
|
| 17 |
with gr.Box():
|
| 18 |
|
|
|
|
| 19 |
with gr.Row():
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
output_mask = gr.Image(label="Predicted Masks", show_label=True)
|
| 25 |
|
| 26 |
gr.Markdown("**Predict**")
|
|
@@ -32,10 +37,10 @@ with demo:
|
|
| 32 |
gr.Markdown("**Examples:**")
|
| 33 |
|
| 34 |
with gr.Column():
|
| 35 |
-
gr.Examples(example_list, [input_image, segmentation_task], output_mask, predict_masks)
|
| 36 |
|
| 37 |
|
| 38 |
-
submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=output_mask)
|
| 39 |
|
| 40 |
gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
|
| 41 |
|
|
|
|
| 16 |
|
| 17 |
with gr.Box():
|
| 18 |
|
| 19 |
+
|
| 20 |
with gr.Row():
|
| 21 |
+
with gr.Column():
|
| 22 |
+
gr.Markdown("**Inputs**")
|
| 23 |
+
segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
|
| 24 |
input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
|
| 25 |
+
|
| 26 |
+
with gr.Column():
|
| 27 |
+
gr.Markdown("**Outputs**")
|
| 28 |
+
output_heading = gr.Textbox(label="Output Type", show_label=True)
|
| 29 |
output_mask = gr.Image(label="Predicted Masks", show_label=True)
|
| 30 |
|
| 31 |
gr.Markdown("**Predict**")
|
|
|
|
| 37 |
gr.Markdown("**Examples:**")
|
| 38 |
|
| 39 |
with gr.Column():
|
| 40 |
+
gr.Examples(example_list, [input_image, segmentation_task], [output_mask,output_heading], predict_masks)
|
| 41 |
|
| 42 |
|
| 43 |
+
submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=[output_mask,output_heading])
|
| 44 |
|
| 45 |
gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
|
| 46 |
|
predict.py
CHANGED
|
@@ -4,7 +4,8 @@ import numpy as np
|
|
| 4 |
from PIL import Image
|
| 5 |
from collections import defaultdict
|
| 6 |
import os
|
| 7 |
-
#
|
|
|
|
| 8 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
| 9 |
|
| 10 |
from detectron2.data import MetadataCatalog
|
|
@@ -21,12 +22,12 @@ def load_model_and_processor(model_ckpt: str):
|
|
| 21 |
|
| 22 |
def load_default_ckpt(segmentation_task: str):
|
| 23 |
if segmentation_task == "semantic":
|
| 24 |
-
|
| 25 |
elif segmentation_task == "instance":
|
| 26 |
-
|
| 27 |
else:
|
| 28 |
-
|
| 29 |
-
return
|
| 30 |
|
| 31 |
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
|
| 32 |
metadata = MetadataCatalog.get("coco_2017_val_panoptic")
|
|
@@ -73,8 +74,8 @@ def visualize_instance_seg_mask(mask, input_image):
|
|
| 73 |
def predict_masks(input_img_path: str, segmentation_task: str):
|
| 74 |
|
| 75 |
#load model and image processor
|
| 76 |
-
|
| 77 |
-
model, image_processor = load_model_and_processor(
|
| 78 |
|
| 79 |
## pass input image through image processor
|
| 80 |
image = Image.open(input_img_path)
|
|
@@ -90,16 +91,20 @@ def predict_masks(input_img_path: str, segmentation_task: str):
|
|
| 90 |
predicted_segmentation_map = result.cpu().numpy()
|
| 91 |
palette = ade_palette()
|
| 92 |
output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
|
|
|
|
| 93 |
|
| 94 |
elif segmentation_task == "instance":
|
| 95 |
result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
| 96 |
predicted_instance_map = result["segmentation"].cpu().detach().numpy()
|
| 97 |
output_result = visualize_instance_seg_mask(predicted_instance_map, image)
|
|
|
|
| 98 |
|
| 99 |
else:
|
| 100 |
result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
| 101 |
predicted_segmentation_map = result["segmentation"]
|
| 102 |
seg_info = result['segments_info']
|
| 103 |
output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
|
|
|
|
| 104 |
|
| 105 |
-
|
|
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
from collections import defaultdict
|
| 6 |
import os
|
| 7 |
+
# Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
|
| 8 |
+
# Hence, installing detectron2 this way when using Gradio HF spaces.
|
| 9 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
| 10 |
|
| 11 |
from detectron2.data import MetadataCatalog
|
|
|
|
| 22 |
|
| 23 |
def load_default_ckpt(segmentation_task: str):
|
| 24 |
if segmentation_task == "semantic":
|
| 25 |
+
default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
|
| 26 |
elif segmentation_task == "instance":
|
| 27 |
+
default_ckpt = "facebook/mask2former-swin-small-coco-instance"
|
| 28 |
else:
|
| 29 |
+
default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
|
| 30 |
+
return default_ckpt
|
| 31 |
|
| 32 |
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
|
| 33 |
metadata = MetadataCatalog.get("coco_2017_val_panoptic")
|
|
|
|
| 74 |
def predict_masks(input_img_path: str, segmentation_task: str):
|
| 75 |
|
| 76 |
#load model and image processor
|
| 77 |
+
default_ckpt = load_default_ckpt(segmentation_task)
|
| 78 |
+
model, image_processor = load_model_and_processor(default_ckpt)
|
| 79 |
|
| 80 |
## pass input image through image processor
|
| 81 |
image = Image.open(input_img_path)
|
|
|
|
| 91 |
predicted_segmentation_map = result.cpu().numpy()
|
| 92 |
palette = ade_palette()
|
| 93 |
output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
|
| 94 |
+
output_heading = "Semantic Segmentation Output"
|
| 95 |
|
| 96 |
elif segmentation_task == "instance":
|
| 97 |
result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
| 98 |
predicted_instance_map = result["segmentation"].cpu().detach().numpy()
|
| 99 |
output_result = visualize_instance_seg_mask(predicted_instance_map, image)
|
| 100 |
+
output_heading = "Instance Segmentation Output"
|
| 101 |
|
| 102 |
else:
|
| 103 |
result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
| 104 |
predicted_segmentation_map = result["segmentation"]
|
| 105 |
seg_info = result['segments_info']
|
| 106 |
output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
|
| 107 |
+
output_heading = "Panoptic Segmentation Output"
|
| 108 |
|
| 109 |
+
|
| 110 |
+
return output_result, output_heading
|