"""Custom handler for Hugging Face Inference Endpoints: serves the
mistralai/Mistral-7B-v0.1 base model, 4-bit quantized, with a PEFT (LoRA)
fine-tuned adapter loaded on top."""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

class EndpointHandler:
    def __init__(self, path=""):
        base_model_id = "mistralai/Mistral-7B-v0.1"
        # 4-bit NF4 quantization with double quantization so the 7B model fits
        # on a single GPU. On GPUs without native bfloat16 support (e.g. a T4),
        # torch.float16 may be the safer compute dtype.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=bnb_config,  # same quantization config as used for fine-tuning
            device_map="auto",
            trust_remote_code=True,
        )

        self.eval_tokenizer = AutoTokenizer.from_pretrained(
            base_model_id,
            add_bos_token=True,
            trust_remote_code=True,
        )

        # Attach the LoRA adapter on top of the quantized base model. Do not
        # call .to("cuda") here: device_map="auto" has already placed the
        # weights, and .to() is unsupported for 4-bit bitsandbytes models.
        self.ft_model = PeftModel.from_pretrained(base_model, "FloVolo/mistral-flo-finetune-2-T4")
        self.ft_model.eval()

    def __call__(self, data):
        # Inference Endpoints sends a payload like {"inputs": "..."}; fall back
        # to the raw payload if the "inputs" key is absent.
        inputs = data.pop("inputs", data)

        model_input = self.eval_tokenizer(inputs, return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_ids = self.ft_model.generate(**model_input, max_new_tokens=100)[0]
        return self.eval_tokenizer.decode(output_ids, skip_special_tokens=True)