{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"19syVpDlkaAbNpOGldD0K2XTQ9I9qnf6t","authorship_tag":"ABX9TyM1i0p1DD9twFxMn+AGLqUI"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["# install requirements\n","import sys\n","if 'google.colab' in sys.modules:\n","    print('Running in Colab.')\n","    !pip install transformers timm fairscale\n","    !git clone https://github.com/salesforce/BLIP\n","    %cd BLIP\n","\n","from PIL import Image\n","import requests\n","import torch\n","from torchvision import transforms\n","from torchvision.transforms.functional import InterpolationMode\n","\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","def load_demo_image(image_size,device):\n","    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' \n","    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')   \n","\n","    w,h = raw_image.size\n","    display(raw_image.resize((w//5,h//5)))\n","    \n","    transform = transforms.Compose([\n","        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n","        transforms.ToTensor(),\n","        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n","        ]) \n","    image = transform(raw_image).unsqueeze(0).to(device)   \n","    return image\n","\n","from models.blip import blip_decoder\n","\n","image_size = 384\n","image = load_demo_image(image_size=image_size, device=device)\n","\n","model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'\n","    \n","model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')\n","model.eval()\n","model = model.to(device)\n","\n","with torch.no_grad():\n","\n","    # beam search\n","  #captions = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5, num_return_sequences=3) \n","    # nucleus sampling\n","  num_captions = 3\n","  captions = []\n","  for i in range(num_captions):\n","      caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)\n","      captions.append(caption[0])\n","  for i, caption in enumerate(captions):\n","      print(f'caption {i+1}: {caption}') "],"metadata":{"id":"rLSiGT3z7mJs"},"execution_count":null,"outputs":[]}]}