File size: 1,543 Bytes
01a2ce5
44d180e
 
 
2b167f5
09f135e
 
9779cd8
 
09f135e
2b167f5
a1fddf9
01a2ce5
143b62d
 
3f40f6e
c9c9f16
e2bb507
143b62d
01a2ce5
 
a1fddf9
 
01a2ce5
 
44d180e
6b6861a
01a2ce5
44d180e
 
d5f9bcf
44d180e
 
9779cd8
01a2ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3984f1
44d180e
b9e79ed
fcae4c8
da4acea
44d180e
 
 
 
 
d546b0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from transformers import AutoTokenizer
import transformers
import torch

from huggingface_hub import login
import os 

import logging

# Authenticate with the Hugging Face Hub (needed for gated checkpoints such as
# Llama) only when a token is supplied via HF_TOKEN. Calling login(token=None)
# falls back to an interactive prompt, which blocks or fails in non-interactive
# runs; without a token, any credentials cached by `huggingface-cli login`
# are still used.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token:
    login(token=_hf_token)

class Model(torch.nn.Module):
    """Wrap a Hugging Face ``transformers`` pipeline behind a small generation API.

    The pipeline task is chosen from the checkpoint's configuration:
    encoder-decoder checkpoints (e.g. ``google-t5/t5-large``) use the
    ``"summarization"`` pipeline, and decoder-only checkpoints (e.g. Vicuna,
    Mistral, Llama) use ``"text-generation"``. The summarization pipeline loads
    its model as a sequence-to-sequence model, so it cannot serve decoder-only
    checkpoints.

    Attributes:
        tokenizer: tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        name: checkpoint identifier passed to the constructor.
        task: pipeline task string selected for ``name``.
        pipeline: the ``transformers.pipeline`` object used by :meth:`gen`.
    """

    # Count of Model instances whose pipeline loaded successfully; shared by
    # all instances and incremented through update() at the end of __init__.
    number_of_models = 0
    # Checkpoint identifiers this wrapper is intended to be used with.
    __model_list__ = [
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    ]

    # Pipeline task -> key that holds the generated string in each output dict.
    _OUTPUT_KEYS = {
        "summarization": "summary_text",
        "text-generation": "generated_text",
    }

    def __init__(self, model_name="lmsys/vicuna-7b-v1.5") -> None:
        """Load the tokenizer and build a generation pipeline for ``model_name``.

        Args:
            model_name: Hugging Face Hub checkpoint id or local path.
        """
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        # Read only the checkpoint config (no weights) to pick a pipeline task
        # that the model's architecture actually supports.
        self.task = self._task_for(transformers.AutoConfig.from_pretrained(model_name))
        self.pipeline = transformers.pipeline(
            self.task,
            model=model_name,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        logging.info('Load model %s', self.name)
        # Count the instance only after the pipeline has loaded successfully.
        self.update()

    @staticmethod
    def _task_for(config):
        """Return the pipeline task suited to a model ``config``.

        Encoder-decoder configs map to ``"summarization"``; anything else
        (including objects without an ``is_encoder_decoder`` flag) maps to
        ``"text-generation"``.
        """
        if getattr(config, "is_encoder_decoder", False):
            return "summarization"
        return "text-generation"

    @staticmethod
    def _sampling_kwargs(temp):
        """Return the decoding keyword arguments for temperature ``temp``.

        A positive ``temp`` enables sampling at that temperature. ``None`` or a
        non-positive value selects greedy decoding instead, because sampling
        with ``temperature <= 0`` is rejected by ``generate``.
        """
        if temp is None or temp <= 0:
            return {"do_sample": False}
        return {"do_sample": True, "temperature": temp}

    @classmethod
    def update(cls):
        """Increment the class-wide ``number_of_models`` counter."""
        cls.number_of_models += 1

    def return_mode_name(self):
        """Return the checkpoint name (kept under its original, misspelled name)."""
        return self.name

    def return_model_name(self):
        """Return the checkpoint name this instance was built from."""
        return self.name

    def return_tokenizer(self):
        """Return the tokenizer used by the pipeline."""
        return self.tokenizer

    def return_model(self):
        """Return the underlying ``transformers`` pipeline object."""
        return self.pipeline

    def gen(self, content, temp=0.1, max_length=500):
        """Generate text for a single prompt.

        Args:
            content: input text; expected to be a single string (only the last
                pipeline result is returned).
            temp: sampling temperature; ``None`` or ``<= 0`` means greedy decoding.
            max_length: maximum number of *new* tokens to generate (forwarded
                as ``max_new_tokens``).

        Returns:
            The generated string. For summarization pipelines this is the
            summary. For text-generation pipelines it is only the newly
            generated continuation, without the prompt.
        """
        decode_kwargs = self._sampling_kwargs(temp)
        if self.task == "text-generation":
            # text-generation echoes the prompt by default; return new text only.
            decode_kwargs["return_full_text"] = False

        sequences = self.pipeline(
            content,
            max_new_tokens=max_length,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            **decode_kwargs,
        )

        return sequences[-1][self._OUTPUT_KEYS[self.task]]