import torch
import requests
from PIL import Image

from utils.config import Config
from src.model.init import download_model
from transformers import AutoImageProcessor, ViTModel, AutoTokenizer, T5EncoderModel

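class CommentGenerator:
    """Generate comments from post text and an image using ViT and ViT5 encoder features."""

    def __init__(self) -> None:
        # Load settings and make sure the generation model file is available locally.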
        self.config = Config("./config/comment_generator.yaml").__get_config__()
        download_model(self.config['model']['url'], self.config['model']['dir'])
        
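        # Tokenizer and generation model (torch.load returns the full pickled model object)
        self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
        self.model = torch.load(self.config["model"]["dir"], map_location=torch.device(self.config["model"]["device"]))
        self.model.eval()  # inference only: disable dropout
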
        # Image encoder (ViT)
        self.vit_image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model.to(self.config["model"]["device"])
        
        # Text encoder (ViT5)
        self.vit5_model = T5EncoderModel.from_pretrained("VietAI/vit5-base")
        self.vit5_model.to(self.config["model"]["device"])
        
    def get_text_feature(self, content):
        """Encode text with ViT5; return (last_hidden_state, attention_mask) on the configured device."""
        device = self.config["model"]["device"]
        inputs = self.tokenizer(content,
                                padding="max_length",
                                truncation=True,
                                max_length=self.config["model"]["input_maxlen"],
                                return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.vit5_model(**inputs)
        # The model and its inputs already live on `device`, so no further move is needed.
        return outputs.last_hidden_state, inputs.attention_mask

    def get_image_feature_from_url(self, image_url, is_local=False):
        """Encode an image with ViT; return (last_hidden_state, attention_mask).

        Falls back to all-zero features and an all-zero mask when the URL is
        empty or the remote image cannot be read.
        """
        device = self.config["model"]["device"]
        # ViT-base/16 at 224x224 yields 196 patch tokens + 1 [CLS] token, hidden size 768.
        empty = (torch.zeros((1, 197, 768), device=device), torch.zeros((1, 197), device=device))
        if not image_url:
            print("WARNING: no image URL provided")
            return empty
        if not is_local:
            try:
                # The 10 s timeout is an assumed value; the original request had none.
                response = requests.get(image_url, stream=True, timeout=10)
                images = Image.open(response.raw).convert("RGB")
            except Exception:
                print(f"READ IMAGE ERR: {image_url}")
                return empty
        else:
            images = Image.open(image_url).convert("RGB")
        inputs = self.vit_image_processor(images, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.vit_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        # Every image token is real (no padding), so the mask is all ones.
        attention_mask = torch.ones(last_hidden_states.shape[:2], device=device)

        return last_hidden_states, attention_mask

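    def inference(self, content_feature, content_mask, image_feature, image_mask):
        """Generate comments with beam search from concatenated image and text features.

        Only the first item of each input batch is used; image tokens come
        before text tokens along the sequence axis.
        """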
        inputs_embeds = torch.cat((image_feature[0], content_feature[0]), 0)
        inputs_embeds = torch.unsqueeze(inputs_embeds, 0)
        attention_mask = torch.cat((image_mask[0], content_mask[0]), 0)
        attention_mask = torch.unsqueeze(attention_mask, 0)
        with torch.no_grad():
            generated_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                num_beams=2,
                max_length=self.config["model"]["output_maxlen"],
                # num_return_sequences=2
                # skip_special_tokens=True,
                # clean_up_tokenization_spaces=True
            )
        comments = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return comments