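# Header residue from the web file viewer this source was copied from
# (last commit by ttengwang: "fix bugs of example images and api keys", 5c74464;
# viewer links: raw, history, blame; file size: 3.02 kB).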
from transformers import GitProcessor, AutoProcessor
from .modeling_git import GitForCausalLM
from PIL import Image
import torch
from .base_captioner import BaseCaptioner
import numpy as np
from typing import Union
import torchvision.transforms.functional as F
class GITCaptioner(BaseCaptioner):
    """Image captioner backed by the ``microsoft/git-large`` GIT model.

    Provides whole-image captioning (``inference``) and region-focused
    captioning (``inference_with_reduced_tokens``), where a segmentation mask
    is passed to the model's ``generate`` as ``pixel_masks``.
    """

    def __init__(self, device, enable_filter=False):
        """Load the GIT-large processor and model onto ``device``.

        Args:
            device: Target device, e.g. ``'cpu'`` or ``'cuda:0'``; a
                ``torch.device`` is also accepted.
            enable_filter: Passed to ``BaseCaptioner``. Caption filtering runs
                only when this (read back as ``self.enable_filter``) and the
                per-call ``filter`` flag are both true.
        """
        super().__init__(device, enable_filter)
        self.device = device
        # Half precision only on CUDA devices; str() so torch.device works too.
        self.torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32
        self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
        self.model = GitForCausalLM.from_pretrained(
            "microsoft/git-large", torch_dtype=self.torch_dtype
        ).to(self.device)

    @staticmethod
    def _load_image(image):
        """Open ``image`` with PIL if it is a file path; otherwise return it unchanged."""
        if isinstance(image, str):  # input path
            return Image.open(image)
        return image

    @staticmethod
    def _mask_to_pixel_masks(seg_mask, height, width):
        """Resize a 2-D mask to ``(height, width)`` and binarize it at 0.5.

        Returns:
            A float tensor of shape ``(1, 1, height, width)`` holding 0.0/1.0.
        """
        mask_img = Image.fromarray(np.asarray(seg_mask).astype(float))
        # PIL's Image.resize takes (width, height), not (height, width).
        mask_img = mask_img.resize((width, height))
        binary = F.pil_to_tensor(mask_img) > 0.5  # (1, height, width) bool
        return binary.float().unsqueeze(0)

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
        """Caption a whole image.

        Args:
            image: A file path, a PIL image, or a numpy array.
            filter: If true and the captioner was built with
                ``enable_filter=True``, post-process the caption with
                ``self.filter_caption``.

        Returns:
            The generated caption string.
        """
        image = self._load_image(image)
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device, self.torch_dtype)
        generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
        generated_caption = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        if self.enable_filter and filter:
            # Filter the caption that was just generated.
            generated_caption = self.filter_caption(image, generated_caption)
        print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
        return generated_caption

    @torch.no_grad()
    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
        """Caption the region of ``image`` selected by ``seg_mask``.

        Args:
            image: A file path, a PIL image, or a numpy array.
            seg_mask: 2-D mask array (assumed 0/1 valued — TODO confirm with
                callers). It is resized to the model input resolution and
                thresholded at 0.5 to build ``pixel_masks``.
            crop_mode: Forwarded to ``self.generate_seg_cropped_image``.
            filter: If true and ``enable_filter`` was set, post-process the
                caption with ``self.filter_caption``.
            disable_regular_box: Forwarded to ``self.generate_seg_cropped_image``.

        Returns:
            ``(caption, crop_save_path)``, where ``crop_save_path`` is whatever
            ``generate_seg_cropped_image`` returned.
        """
        crop_save_path = self.generate_seg_cropped_image(
            image=image,
            seg_mask=seg_mask,
            crop_mode=crop_mode,
            disable_regular_box=disable_regular_box,
        )
        image = self._load_image(image)
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device, self.torch_dtype)
        _, _, H, W = pixel_values.shape
        pixel_masks = self._mask_to_pixel_masks(seg_mask, H, W).to(self.device)
        out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
        if self.enable_filter and filter:
            captions = self.filter_caption(image, captions)
        print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {captions}")
        return captions, crop_save_path
if __name__ == '__main__':
    # Demo: caption a square region of a test image.
    # Pick a device that exists on this machine rather than a fixed GPU index.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = GITCaptioner(device=device, enable_filter=False)
    image_path = 'test_img/img2.jpg'
    # Toy 224x224 mask: rows/cols 50..199 mark the region of interest.
    seg_mask = np.zeros((224, 224))
    seg_mask[50:200, 50:200] = 1
    print(f'process image {image_path}')
    print(model.inference_with_reduced_tokens(image_path, seg_mask))