Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| from models.model import ImageCaptioningModel | |
| from torchvision import transforms | |
| import torch | |
| import torch | |
| from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer | |
| from PIL import Image | |
| from config.config import Config | |
class ImageCaptioningInference:
    """Generate a text caption for one image using a trained captioning model.

    The wrapped ``model`` is only used through these members:
    ``extract_image_features(tensor)``, ``gpt2_model.generate(...)`` and
    ``tokenizer`` (``decode``, ``bos_token_id``, ``eos_token_id``).
    """

    def __init__(self, model):
        """Store the model and build the image preprocessing pipeline.

        Args:
            model: Captioning model object providing the members listed in
                the class docstring.
        """
        self.model = model
        self.device = Config.DEVICE
        # NOTE(review): the model itself is not moved to ``self.device`` here;
        # presumably ``model.load`` / the model constructor handles placement
        # -- confirm.
        # NOTE(review): no mean/std normalization is applied; presumably this
        # matches the transform used at training time -- confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def infer_image(self, image):
        """Preprocess ``image``, extract its features, and return a caption.

        Args:
            image: Already-opened image object accepted by the transform
                pipeline (the script entry point passes a ``PIL.Image``).

        Returns:
            The caption string produced by :meth:`generate_caption`.
        """
        # Grayscale, palette, or RGBA inputs would otherwise reach the feature
        # extractor with a channel count other than three (the backbone
        # presumably expects RGB). Objects without a ``mode`` attribute are
        # passed through untouched, exactly as before.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # Add a leading batch axis and move the tensor to the configured device.
        batch = self.transform(image).unsqueeze(0).to(self.device)

        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            image_features = self.model.extract_image_features(batch)
            caption = self.generate_caption(image_features)
        return caption

    def generate_caption(self, image_features, num_beams=3, max_length=50):
        """Decode a caption from image features using beam search.

        Args:
            image_features: Feature tensor returned by
                ``model.extract_image_features``; assumed to be
                ``(batch_size, hidden_size)`` -- TODO confirm.
            num_beams: Beam count forwarded to ``gpt2_model.generate``.
            max_length: ``max_length`` forwarded to ``gpt2_model.generate``.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens skipped.
        """
        # Insert a length-1 axis at position 1 so the features act as a
        # one-step embedding sequence: [batch_size, 1, hidden_size].
        prefix_embeds = image_features.unsqueeze(1)

        tokenizer = self.model.tokenizer
        output = self.model.gpt2_model.generate(
            inputs_embeds=prefix_embeds,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            # The end-of-sequence id doubles as the padding id.
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Only the first sequence in the returned batch is decoded.
        return tokenizer.decode(output[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Directory passed to ``model.load`` to restore the saved model.
    model_dir = 'model'

    # Build the model, restore its saved state, and wrap it for inference.
    model = ImageCaptioningModel()
    model.load(model_dir)
    inference_model = ImageCaptioningInference(model)

    # Path to the input image.
    image_path = 'test_img.jpg'

    # Open the image as a context manager so the underlying file handle is
    # closed once inference is done; inference runs inside the block because
    # the pixel data must be read while the file is still open.
    with Image.open(image_path) as image:
        caption = inference_model.infer_image(image)

    print("Generated Caption:", caption)