import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import gradio as gr

# blip_decoder is defined in the BLIP repo (https://github.com/salesforce/BLIP);
# clone it and run this script from the repo root so models.blip is importable.
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384

def load_image(image):
    # Resize, normalize with the CLIP mean/std used by BLIP, and add a batch dimension.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711))
    ])
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image).unsqueeze(0).to(device)
    return image

# Load the model once at startup instead of re-downloading it on every request.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

def generate_caption(image):
    image = load_image(image)
    with torch.no_grad():
        num_captions = 3
        captions = []
        for i in range(num_captions):
            # Nucleus sampling (top_p) yields a different caption on each pass.
            caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
            captions.append(caption[0])
    captions_list = [caption.strip() for caption in captions]
    return "\n".join(captions_list)

iface = gr.Interface(
    generate_caption,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(lines=3),
    title="Image Caption Generator",
    description="Generate captions for images using BLIP"
)
iface.launch()
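
# Usage notes (a sketch; exact setup may vary, and the script filename
# caption_app.py below is only an assumed name for this file):
#   git clone https://github.com/salesforce/BLIP
#   cd BLIP
#   pip install -r requirements.txt gradio
#   python caption_app.py
# The pretrained checkpoint downloads from model_url on first run; a GPU is
# optional, but captioning on CPU will be noticeably slower.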