Spaces:
Sleeping
Sleeping
import functools
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import ResNet18_Weights, resnet18
# Bundled fallback image, resolved next to this file so the default does not
# depend on the process's current working directory.
_DEFAULT_IMAGE_PATH = Path(__file__).resolve().parent / "examples" / "steak.jpeg"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the pretrained ResNet-18, its preprocessing transform and labels once.

    Cached so every request reuses the same model instead of rebuilding it
    (and reloading its weights) on each call.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()  # inference only; set once rather than on every request
    return model, weights.transforms(), weights.meta["categories"]


def _to_rgb_image(img_path):
    """Normalize a supported input (None, ndarray, PIL image, path) to an RGB PIL image."""
    if img_path is None:
        # Context manager closes the file handle that Image.open keeps lazily.
        with Image.open(_DEFAULT_IMAGE_PATH) as fallback:
            return fallback.convert("RGB")
    if isinstance(img_path, np.ndarray):
        arr = img_path.astype(np.uint8, copy=False)
        if arr.ndim == 3 and arr.shape[2] == 1:
            # Pillow cannot build an image from an (H, W, 1) array; drop the channel axis.
            arr = arr[..., 0]
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA) and
        # convert afterwards; forcing mode "RGB" breaks on 2-D or 4-channel arrays.
        return Image.fromarray(arr).convert("RGB")
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    if isinstance(img_path, (str, Path)):
        with Image.open(img_path) as opened:
            return opened.convert("RGB")
    raise TypeError(
        "predict() expects None, a numpy array, a PIL image or a file path, "
        f"got {type(img_path).__name__}"
    )


def predict(img_path=None) -> str:
    """Classify an image with torchvision's default pretrained ResNet-18.

    Args:
        img_path: The image to classify. May be an ``np.ndarray`` (e.g. from a
            Gradio ``Image`` component), a ``PIL.Image.Image``, a filesystem
            path (``str`` or ``Path``), or ``None`` to classify the bundled
            ``examples/steak.jpeg``.

    Returns:
        The category name (from the weights' ``meta["categories"]``) with the
        highest score.

    Raises:
        TypeError: If ``img_path`` is of an unsupported type.
    """
    model, transform, categories = _load_model()
    batch = transform(_to_rgb_image(img_path)).unsqueeze(0)  # add a batch dimension
    with torch.inference_mode():
        logits = model(batch)
    # softmax is monotonic, so the arg-max over raw logits selects the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = categories[pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label
| import numpy as np | |
| import gradio as gr | |
# Gradio UI: an image upload wired to `predict`, whose result is shown in a Label.
# NOTE(review): the "π" in the title looks like a mis-encoded emoji — confirm the intended character.
# NOTE(review): the description promises probabilities, but `predict` returns only
# the top label string — confirm whether a confidence dict was intended.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs="label",
    title="ResNet-18_1K π",
    description="Upload an image to see classification probabilities based on ResNet-18 with 1K classes",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()