Spaces:
Build error
Build error
File size: 2,731 Bytes
0ab9a32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import json
import pdb
import cv2
import numpy as np
from typing import Union
from .base_captioner import BaseCaptioner
class BLIP2Captioner(BaseCaptioner):
    """Image captioner built on the ``Salesforce/blip2-opt-2.7b`` checkpoint.

    The mode is fixed at construction time:

    * ``dialogue=False`` -- ``inference`` sends one fixed "describe this
      image" prompt and returns the generated caption.
    * ``dialogue=True`` -- ``inference`` runs an interactive question/answer
      loop that reads questions from stdin until the line ``end`` is entered.
    """

    _CHECKPOINT = "Salesforce/blip2-opt-2.7b"
    _MAX_NEW_TOKENS = 50

    def __init__(self, device, dialogue: bool = False, enable_filter: bool = False, cache_dir=None):
        """Load the BLIP-2 processor and 8-bit quantised model.

        Args:
            device: Device the processor outputs are moved to before
                generation (a string such as ``'cuda:0'`` or a
                ``torch.device``).
            dialogue: Select the interactive question/answer mode.
            enable_filter: Forwarded to ``BaseCaptioner.__init__``.
            cache_dir: Optional directory passed to ``from_pretrained`` for
                downloaded weights; ``None`` keeps the library default.
        """
        super().__init__(device, enable_filter)
        self.device = device
        self.dialogue = dialogue
        # str() so that a torch.device works as well as a plain string.
        self.torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32
        self.processor = AutoProcessor.from_pretrained(self._CHECKPOINT, cache_dir=cache_dir)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self._CHECKPOINT,
            device_map='sequential',
            load_in_8bit=True,
            cache_dir=cache_dir,
        )

    def _blip2_answer(self, image, prompt: str) -> str:
        """Run one prompt against ``image`` and return the decoded, stripped text."""
        inputs = self.processor(image, text=prompt, return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs, max_new_tokens=self._MAX_NEW_TOKENS)
        return self.processor.decode(out[0], skip_special_tokens=True).strip()

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
        """Caption ``image`` or hold a stdin dialogue about it.

        Args:
            image: An image object, or a path string which is opened with
                ``Image.open``.
            filter: In caption mode, when both this flag and
                ``self.enable_filter`` are truthy the caption is passed
                through ``self.filter_caption``. Ignored in dialogue mode.

        Returns:
            Caption mode: the generated (optionally filtered) caption.
            Dialogue mode: the answer to the last question asked, or ``''``
            when the dialogue ended before any question was asked.
        """
        if isinstance(image, str):  # input path
            image = Image.open(image)

        if not self.dialogue:
            text_prompt = 'Context: ignore the white part in this image. Question: describe this image. Answer:'
            captions = self._blip2_answer(image, text_prompt)
            if self.enable_filter and filter:
                # filter_caption is not defined in this class; presumably
                # inherited from BaseCaptioner -- verify.
                captions = self.filter_caption(image, captions)
            print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
            return captions

        # Dialogue mode: every earlier question/answer pair is replayed in the
        # prompt so the model sees the conversation so far.
        context = []
        template = "Question: {} Answer: {}."
        # Bound up front so that typing 'end' immediately returns '' instead
        # of raising UnboundLocalError.
        captions = ''
        while True:
            input_texts = input()
            if input_texts == 'end':
                break
            history = " ".join(template.format(question, answer) for question, answer in context)
            prompt = history + " Question: " + input_texts + " Answer:"
            captions = self._blip2_answer(image, prompt)
            context.append((input_texts, captions))
        return captions
if __name__ == '__main__':
    # Manual smoke test: caption the region of a sample image selected by a
    # square segmentation mask.
    dialogue = False
    # Only device and dialogue are passed, so weights are loaded from the
    # default Hugging Face cache rather than a machine-specific directory.
    # NOTE(review): 'cuda:4' assumes a host with at least five GPUs -- adjust
    # for the machine this is run on.
    model = BLIP2Captioner(device='cuda:4', dialogue=dialogue)
    image_path = 'test_img/img2.jpg'
    # 224x224 mask with a 150x150 block of ones starting at (50, 50).
    seg_mask = np.zeros((224, 224))
    seg_mask[50:200, 50:200] = 1
    print(f'process image {image_path}')
    # inference_seg is not defined in BLIP2Captioner; presumably inherited
    # from BaseCaptioner -- verify.
    print(model.inference_seg(image_path, seg_mask))