File size: 2,325 Bytes
08e5d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from langchain.llms.base import LLM
from typing import Dict, List, Any, Optional
import torch,sys,os
from transformers import AutoTokenizer


class LLM(LLM):
    """LangChain LLM wrapper around a locally loaded, quantized causal LM.

    The tokenizer and model are loaded with Hugging Face ``transformers``
    from ``model_name_or_path``. Note that this class deliberately keeps the
    name ``LLM`` and therefore shadows the imported LangChain base class
    inside this module.
    """

    # Declared for configuration purposes; the visible code does not read it
    # (generation length is fixed by ``max_new_tokens`` in ``_call``).
    max_token: int = 4000
    # Sampling temperature forwarded to ``generate``.
    temperature: float = 0.1
    # Nucleus-sampling threshold forwarded to ``generate``.
    top_p: float = 0.95
    # Populated in ``__init__``.
    tokenizer: Any
    model: Any

    def __init__(self, model_name_or_path, bit4=True):
        """Load the tokenizer and the quantized model.

        Args:
            model_name_or_path: Identifier or path passed unchanged to
                ``from_pretrained`` for both tokenizer and model.
            bit4: When truthy (default) the model is loaded in 4-bit with
                double quantization; otherwise it is loaded in 8-bit with
                ``torch.bfloat16`` as ``torch_dtype``.
        """
        super().__init__()
        from transformers import AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
        # Reuse EOS as the padding token so ``pad_token_id`` is always defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        if not bit4:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map='auto',
                torch_dtype=torch.bfloat16,
                load_in_8bit=True,
            )
        else:
            # Imported lazily: only the 4-bit path needs BitsAndBytesConfig.
            from transformers import BitsAndBytesConfig

            double_quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map='auto',
                quantization_config=double_quant_config,
            )
        self.model.eval()

        # Compare the numeric major version instead of relying on string
        # ordering, which is only correct when ``torch.__version__`` is a
        # ``TorchVersion`` object.
        major_version = torch.__version__.split(".")[0]
        if major_version.isdigit() and int(major_version) >= 2 and sys.platform != "win32":
            self.model = torch.compile(self.model)

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "Mistral"

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any stop sequence.

        Empty stop strings are ignored; when ``stop`` is ``None``/empty or
        no stop sequence occurs, ``text`` is returned unchanged.
        """
        if not stop:
            return text
        positions = [text.find(sequence) for sequence in stop if sequence]
        hits = [position for position in positions if position != -1]
        return text[:min(hits)] if hits else text

    def _call(self, prompt: str, stop: Optional[List[str]] = None, return_only_outputs = True) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text to complete. It is tokenized without adding special
                tokens.
            stop: Optional stop sequences; the completion is truncated at
                the first one that appears.
            return_only_outputs: Accepted for interface compatibility; not
                used by this implementation.

        Returns:
            The decoded completion (prompt tokens removed, special tokens
            skipped), truncated at the earliest stop sequence if any.
        """
        print('prompt:',prompt)
        # NOTE(review): inputs go to the default CUDA device; this assumes the
        # first shard placed by ``device_map='auto'`` lives there - confirm on
        # multi-GPU setups.
        input_ids = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": 1024,
            "do_sample": True,
            "top_k": 50,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": 1.2,
            "eos_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        generate_ids = self.model.generate(**generate_input)
        # Strip the echoed prompt. The trailing token is kept on purpose: when
        # generation stops at ``max_new_tokens`` it is real content, and a
        # trailing EOS is removed by ``skip_special_tokens`` during decoding.
        prompt_length = len(input_ids[0])
        completion_ids = [sequence[prompt_length:] for sequence in generate_ids]
        result_message = self.tokenizer.batch_decode(
            completion_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return self._truncate_at_stop(result_message, stop)