File size: 2,474 Bytes
1f31241 |
1 |
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "mount_file_id": "19syVpDlkaAbNpOGldD0K2XTQ9I9qnf6t",
   "authorship_tag": "ABX9TyM1i0p1DD9twFxMn+AGLqUI"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BLIP image captioning demo\n",
    "\n",
    "Downloads BLIP's demo image and captions it with the pretrained `model_base_capfilt_large` checkpoint using nucleus sampling.\n",
    "In Colab, the setup cell installs the dependencies and clones the [BLIP repository](https://github.com/salesforce/BLIP); elsewhere, start the notebook from a BLIP checkout so that `models.blip` is importable."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Colab setup: install dependencies and clone BLIP, which provides the `models` package.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    # NOTE(review): unpinned; newer releases may differ from the versions pinned in\n",
    "    # BLIP's requirements.txt. Pin versions verified on the current runtime.\n",
    "    %pip install -q transformers timm fairscale\n",
    "    # Guarded so re-running this cell neither re-clones nor tries to `cd` into BLIP/BLIP.\n",
    "    if os.path.basename(os.getcwd()) != 'BLIP':\n",
    "        if not os.path.isdir('BLIP'):\n",
    "            !git clone https://github.com/salesforce/BLIP\n",
    "        %cd BLIP"
   ],
   "metadata": {"id": "rLSiGT3z7mJs"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from io import BytesIO\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "\n",
    "# Resolved relative to the working directory, i.e. the BLIP checkout.\n",
    "from models.blip import blip_decoder"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "SEED = 0  # sampling below is random; a fixed seed makes re-runs repeatable\n",
    "\n",
    "IMAGE_SIZE = 384  # side length (px) of the square model input\n",
    "DEMO_IMAGE_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'\n",
    "MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'\n",
    "\n",
    "# Caption generation (nucleus sampling)\n",
    "NUM_CAPTIONS = 3\n",
    "TOP_P = 0.9\n",
    "MAX_LENGTH = 20\n",
    "MIN_LENGTH = 5"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the image and the model"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "def load_demo_image(image_size, device, img_url=DEMO_IMAGE_URL, preview_scale=5):\n",
    "    \"\"\"Download an image, display a preview, and return it preprocessed for BLIP.\n",
    "\n",
    "    Args:\n",
    "        image_size: side length in pixels of the square the image is resized to.\n",
    "        device: torch device the returned tensor is moved to.\n",
    "        img_url: URL of the image to load; defaults to BLIP's demo image.\n",
    "        preview_scale: factor by which the displayed preview is downscaled.\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of shape (1, 3, image_size, image_size) on `device`.\n",
    "\n",
    "    Raises:\n",
    "        requests.RequestException: if the download fails, times out, or\n",
    "            returns an HTTP error status.\n",
    "    \"\"\"\n",
    "    # Timeout + status check so a bad download fails loudly here instead of\n",
    "    # hanging or surfacing later as an opaque PIL decoding error.\n",
    "    response = requests.get(img_url, timeout=30)\n",
    "    response.raise_for_status()\n",
    "    raw_image = Image.open(BytesIO(response.content)).convert('RGB')\n",
    "\n",
    "    w, h = raw_image.size\n",
    "    display(raw_image.resize((w // preview_scale, h // preview_scale)))\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        transforms.ToTensor(),\n",
    "        # Per-channel RGB normalization (mean, std), same values as BLIP's demo preprocessing.\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
    "    ])\n",
    "    # unsqueeze(0) adds a leading batch dimension of size 1.\n",
    "    return transform(raw_image).unsqueeze(0).to(device)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "image = load_demo_image(image_size=IMAGE_SIZE, device=DEVICE)\n",
    "\n",
    "model = blip_decoder(pretrained=MODEL_URL, image_size=IMAGE_SIZE, vit='base')\n",
    "model = model.eval().to(DEVICE)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate captions\n",
    "\n",
    "Nucleus sampling (`sample=True`, `top_p`) draws a different caption on each call. For a deterministic caption, BLIP's `generate` can be called with `sample=False, num_beams=3` for beam search instead."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Reseed on every run so re-executing this cell on the same setup reproduces the captions.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # One sampled caption per call; [0] selects it for the single image in the batch.\n",
    "    captions = [\n",
    "        model.generate(image, sample=True, top_p=TOP_P, max_length=MAX_LENGTH, min_length=MIN_LENGTH)[0]\n",
    "        for _ in range(NUM_CAPTIONS)\n",
    "    ]\n",
    "\n",
    "for i, caption in enumerate(captions, start=1):\n",
    "    print(f'caption {i}: {caption}')"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ]
}