MoMA_demo / dataset_lib /dataset_eval_MoMA.py
Kunpeng Song
bg
ef3a17c
raw
history blame contribute delete
No virus
1.55 kB
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from rembg import remove
def create_binary_mask(image):
    """Return a 1-bit PIL mask marking the foreground pixels of ``image``.

    If ``image`` has an alpha band ("A" in ``image.getbands()``), the mask is
    derived from that band: a pixel is foreground when its alpha is > 1.
    Otherwise the image is converted to grayscale ("L") and a pixel is
    foreground when its luminance is > 1.

    Why alpha first: for background-removed cut-outs, transparency is the
    actual foreground signal. Transparent pixels are commonly stored with
    black RGB (NOTE(review): confirm for the cut-out producer in use). A
    luminance-only threshold would therefore also discard dark/near-black
    subject pixels, punching holes in the mask.

    Args:
        image: A PIL image of any mode.

    Returns:
        A PIL image in mode "1" (255 = foreground, 0 = background).
    """
    if "A" in image.getbands():
        # getchannel() returns a single-band "L" image holding the alpha values.
        channel = image.getchannel("A")
    else:
        channel = image.convert("L")
    # Image.point may only emit mode "1" from an "L"/"P" source; both
    # branches above produce an "L" image.
    return channel.point(lambda value: 255 if value > 1 else 0, "1")
def Dataset_evaluate_MoMA(image_pil, prompt, subject, moMA_main_modal):
    """Build a single-sample MoMA evaluation batch from one PIL image.

    A foreground mask is obtained by passing the input through ``remove``
    and thresholding the result with ``create_binary_mask``. The image and
    mask are both resized to 512x512. The image is also whitened outside the
    mask and passed to ``process_images`` for LLaVA.

    Args:
        image_pil: Input PIL image. Any mode is accepted. The mask is computed
            from the image as given, while the image tensors are built from
            an RGB conversion.
        prompt: Text prompt, stored under ``'text'`` as a one-element list.
        subject: Subject label, stored under ``'label'`` as a one-element list.
        moMA_main_modal: Object exposing ``image_processor_llava`` and
            ``model_llava.config``. Both are forwarded to ``process_images``.

    Returns:
        dict with keys:
            'image': float tensor of shape (1, 3, 512, 512), scaled to [-1, 1].
            'mask': float tensor of shape (1, 1, 512, 512) holding 0/1.
            'text': [prompt].
            'llava_processed': result of ``process_images`` on the
                white-background image.
            'label': [subject].
    """
    llava_processor = moMA_main_modal.image_processor_llava
    llava_config = moMA_main_modal.model_llava.config
    resize = transforms.Resize((512, 512))

    # The mask is derived from the caller's image exactly as given.
    mask_pil = resize(create_binary_mask(remove(image_pil)))
    # Force three channels. np.array() of an "L"/"P" image is 2-D, so the
    # permute below would fail, and an RGBA image would yield a 4-channel tensor.
    image_pil = resize(image_pil.convert("RGB"))

    mask_np = np.array(mask_pil)
    if mask_np.ndim == 3:
        mask_np = mask_np[:, :, 0]
    image = torch.from_numpy(np.array(image_pil)).permute(2, 0, 1)  # (3, H, W) uint8
    mask = (torch.from_numpy(mask_np).float() > 0).float().unsqueeze(0)  # (1, H, W) of 0/1

    res = {
        'image': (image / 127.5 - 1).unsqueeze(0),
        'mask': mask.unsqueeze(0),
        'text': [prompt],
    }

    # Paint everything outside the mask white before handing the image to LLaVA.
    image_wb = image * mask + torch.ones_like(image) * (1 - mask) * 255
    image_wb_pil = Image.fromarray(image_wb.permute(1, 2, 0).numpy().astype(np.uint8))
    res['llava_processed'] = process_images([image_wb_pil], llava_processor, llava_config)
    res['label'] = [subject]
    return res