import os
import torch
import requests
from PIL import Image
from utils.config import Config
from src.model.init import download_model
from transformers import AutoImageProcessor, ViTModel, AutoTokenizer, T5EncoderModel
class CommentGenerator:
    def __init__(self) -> None:
        self.config = Config("./config/comment_generator.yaml").__get_config__()
        download_model(self.config["model"]["url"], self.config["model"]["dir"])

        # Generation model: fine-tuned checkpoint loaded from disk, plus the ViT5 tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
        self.model = torch.load(
            self.config["model"]["dir"],
            map_location=torch.device(self.config["model"]["device"]),
        )

        # Image encoder: ViT backbone that maps an image to patch embeddings
        self.vit_image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model.to(self.config["model"]["device"])

        # Text encoder: ViT5 encoder that maps the post content to token embeddings
        self.vit5_model = T5EncoderModel.from_pretrained("VietAI/vit5-base")
        self.vit5_model.to(self.config["model"]["device"])
    def get_text_feature(self, content):
        """Encode the post text and return (last hidden states, attention mask)."""
        inputs = self.tokenizer(
            content,
            padding="max_length",
            truncation=True,
            max_length=self.config["model"]["input_maxlen"],
            return_tensors="pt",
        ).to(self.config["model"]["device"])
        with torch.no_grad():
            outputs = self.vit5_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states.to(self.config["model"]["device"]), inputs.attention_mask.to(self.config["model"]["device"])
    def get_image_feature_from_url(self, image_url, is_local=False):
        """Encode an image from a URL or local path and return (last hidden states, attention mask).

        Falls back to zero tensors with the ViT output shape (1, 197, 768) when the
        URL is empty or the image cannot be read.
        """
        if not image_url:
            print(f"WARNING: missing image url {image_url}")
            return (
                torch.zeros((1, 197, 768)).to(self.config["model"]["device"]),
                torch.zeros((1, 197)).to(self.config["model"]["device"]),
            )
        if not is_local:
            try:
                images = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
            except Exception:
                print(f"READ IMAGE ERR: {image_url}")
                return (
                    torch.zeros((1, 197, 768)).to(self.config["model"]["device"]),
                    torch.zeros((1, 197)).to(self.config["model"]["device"]),
                )
        else:
            images = Image.open(image_url).convert("RGB")
        inputs = self.vit_image_processor(images, return_tensors="pt").to(self.config["model"]["device"])
        with torch.no_grad():
            outputs = self.vit_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        attention_mask = torch.ones((last_hidden_states.shape[0], last_hidden_states.shape[1]))
        return last_hidden_states.to(self.config["model"]["device"]), attention_mask.to(self.config["model"]["device"])
    def inference(self, content_feature, content_mask, image_feature, image_mask):
        """Concatenate image and text features along the sequence axis and generate comments."""
        inputs_embeds = torch.cat((image_feature[0], content_feature[0]), 0)
        inputs_embeds = torch.unsqueeze(inputs_embeds, 0)
        attention_mask = torch.cat((image_mask[0], content_mask[0]), 0)
        attention_mask = torch.unsqueeze(attention_mask, 0)
        with torch.no_grad():
            generated_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                num_beams=2,
                max_length=self.config["model"]["output_maxlen"],
                # num_return_sequences=2
            )
        comments = [self.tokenizer.decode(generated_id, skip_special_tokens=True) for generated_id in generated_ids]
        return comments
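

# Hedged usage sketch (not part of the original file): wires the three steps together
# in the order inference() expects. The post text and image URL below are placeholder
# values, and config/comment_generator.yaml is assumed to supply the model
# url/dir/device/input_maxlen/output_maxlen keys referenced above.
if __name__ == "__main__":
    generator = CommentGenerator()
    content_feature, content_mask = generator.get_text_feature("Sample post content")
    image_feature, image_mask = generator.get_image_feature_from_url("https://example.com/image.jpg")
    comments = generator.inference(content_feature, content_mask, image_feature, image_mask)
    print(comments)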