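# Page-header residue from the web file view this source was copied from:
#   uploader avatar: truong-xuan-linh's picture
#   commit: "init" (4589b64)
#   viewer links: raw | history | blame | contribute | delete
#   scan badge: No virus
#   file size: 3.83 kB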
import os
import torch
import requests
from PIL import Image
from utils.config import Config
from src.model.init import download_model
from transformers import AutoImageProcessor, ViTModel, AutoTokenizer, T5EncoderModel
class CommentGenerator():
    """Generate comments from a piece of text content plus an optional image.

    The text is encoded with the ``VietAI/vit5-base`` T5 encoder and the image
    with the ``google/vit-base-patch16-224-in21k`` ViT model. Both feature
    sequences are concatenated (image first) and passed as ``inputs_embeds``
    to a generator model loaded from the path in ``config["model"]["dir"]``.

    Configuration keys read from ``./config/comment_generator.yaml``:
    ``model.url``, ``model.dir``, ``model.device``, ``model.input_maxlen``
    and ``model.output_maxlen``.
    """

    # Shape of the all-zero placeholder returned when no image is available.
    # Presumably (batch, ViT sequence length, ViT hidden size) for one
    # 224x224 image -- TODO confirm against the ViT checkpoint in use.
    _EMPTY_IMAGE_FEATURE_SHAPE = (1, 197, 768)

    def __init__(self) -> None:
        """Load the configuration, the generator checkpoint and both encoders."""
        self.config = Config("./config/comment_generator.yaml").__get_config__()
        download_model(self.config['model']['url'], self.config['model']['dir'])
        device = self._device

        # Generator model and its tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. The checkpoint path is handed to download_model
        # together with config["model"]["url"], so that URL must be trusted.
        self.model = torch.load(self.config["model"]["dir"], map_location=torch.device(device))
        # A pickled module keeps whatever train/eval flag it was saved with;
        # this class only runs inference, so make sure dropout is disabled.
        self.model.eval()

        # Image encoder.
        self.vit_image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model.to(device)

        # Text content encoder.
        self.vit5_model = T5EncoderModel.from_pretrained("VietAI/vit5-base")
        self.vit5_model.to(device)

    @property
    def _device(self):
        """Return the device setting stored under ``config["model"]["device"]``."""
        return self.config["model"]["device"]

    def _empty_image_feature(self):
        """Return the all-zero ``(feature, mask)`` pair used when no image is usable."""
        device = self._device
        feature = torch.zeros(self._EMPTY_IMAGE_FEATURE_SHAPE, device=device)
        # An all-zero mask so the placeholder positions are masked out.
        mask = torch.zeros(self._EMPTY_IMAGE_FEATURE_SHAPE[:2], device=device)
        return feature, mask

    def get_text_feature(self, content):
        """Encode ``content`` with the ViT5 encoder.

        The text is padded/truncated to ``config["model"]["input_maxlen"]``
        tokens.

        Args:
            content: Text passed straight to the tokenizer.

        Returns:
            A ``(last_hidden_state, attention_mask)`` tuple of tensors on the
            configured device.
        """
        device = self._device
        inputs = self.tokenizer(
            content,
            padding="max_length",
            truncation=True,
            max_length=self.config["model"]["input_maxlen"],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = self.vit5_model(**inputs)
        return outputs.last_hidden_state.to(device), inputs.attention_mask.to(device)

    def get_image_feature_from_url(self, image_url: str, is_local: bool = False, *, timeout: float = 10.0):
        """Encode the image at ``image_url`` with the ViT model.

        Args:
            image_url: URL of the image, or a filesystem path when
                ``is_local`` is true. A falsy value yields the zero
                placeholder.
            is_local: Open ``image_url`` from disk instead of downloading it.
            timeout: Seconds to wait for the HTTP request before giving up
                (remote images only).

        Returns:
            A ``(last_hidden_state, attention_mask)`` tuple of tensors on the
            configured device. When the URL is empty, or a remote image cannot
            be fetched or decoded, an all-zero feature with an all-zero mask
            is returned instead.

        Raises:
            Whatever ``PIL.Image.open`` raises for an unreadable local file;
            only remote failures are converted into the placeholder.
        """
        if not image_url:
            print(f"WARNING not image url {image_url}")
            return self._empty_image_feature()

        if is_local:
            image = Image.open(image_url).convert("RGB")
        else:
            try:
                # The context manager releases the streamed connection;
                # convert() fully decodes the image before the response closes.
                with requests.get(image_url, stream=True, timeout=timeout) as response:
                    image = Image.open(response.raw).convert("RGB")
            except Exception:
                # Best-effort: a bad or unreachable image must not abort
                # generation, but Ctrl-C / SystemExit still propagate.
                print(f"READ IMAGE ERR: {image_url}")
                return self._empty_image_feature()

        device = self._device
        inputs = self.vit_image_processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.vit_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        # Every image position is real data, so the mask is all ones.
        attention_mask = torch.ones(tuple(last_hidden_states.shape[:2]), device=device)
        return last_hidden_states.to(device), attention_mask

    def inference(self, content_feature, content_mask, image_feature, image_mask) -> list:
        """Generate comments from pre-computed text and image features.

        Only the first element along dimension 0 of each argument is used.
        The image sequence is placed before the text sequence.

        Args:
            content_feature: Feature tensor from ``get_text_feature``.
            content_mask: Attention mask from ``get_text_feature``.
            image_feature: Feature tensor from ``get_image_feature_from_url``.
            image_mask: Attention mask from ``get_image_feature_from_url``.

        Returns:
            A list of decoded strings, one per generated sequence.
        """
        # Join along the sequence axis, then restore a leading batch axis.
        inputs_embeds = torch.cat((image_feature[0], content_feature[0]), 0).unsqueeze(0)
        attention_mask = torch.cat((image_mask[0], content_mask[0]), 0).unsqueeze(0)

        with torch.no_grad():
            generated_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                num_beams=2,
                max_length=self.config["model"]["output_maxlen"],
            )
        return [
            self.tokenizer.decode(generated_id, skip_special_tokens=True)
            for generated_id in generated_ids
        ]