import os.path as osp
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from typing import List, Optional, Tuple, Union


class QwenWrapper:

    def __init__(self,
                 model_path: str = 'Qwen/Qwen-7B-Chat-Int4',
                 system_prompt: Optional[str] = None,
                 **model_kwargs):
        self.system_prompt = system_prompt
        self.model_path = model_path
        # Accept either a local checkpoint directory or a HuggingFace hub id
        # of the form 'org/model'.
        assert osp.exists(model_path) or len(model_path.split('/')) == 2, (
            f'model_path must be a local directory or a hub id: {model_path}')
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map='auto')
        try:
            model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
        except Exception:
            # Fall back to the default generation config if the checkpoint
            # does not ship one.
            pass

        self.model = model.eval()
        # Maximum context window of the model, in tokens.
        self.context_length = model.config.seq_length
        # Number of tokens reserved for the generated answer.
        self.answer_buffer = 192
        for k, v in model_kwargs.items():
            print(f'The following argument was passed but is not used to initialize the model: {k}={v}.')
    def length_ok(self, inputs: List[str]) -> bool:
        # Check that the prompt plus the reserved answer buffer fits in the
        # model's context window.
        total = sum(len(self.tokenizer.encode(s)) for s in inputs)
        return total + self.answer_buffer < self.context_length

    def chat(self, full_inputs: List[str], offset: int = 0) -> Tuple[str, int]:
        """Chat with the model on a list of alternating user/assistant turns.

        If the conversation is too long for the context window, the oldest
        turns are dropped one at a time (by increasing `offset`).
        """
        inputs = full_inputs[offset:]
        assert inputs, 'No message fits within the context window.'
        if not self.length_ok(inputs):
            return self.chat(full_inputs, offset + 1)

        # Pack the preceding turns into (user, assistant) pairs; the final
        # element is the new user message.
        history = []
        if len(inputs) % 2 == 1:
            for i in range(len(inputs) // 2):
                history.append((inputs[2 * i], inputs[2 * i + 1]))
        else:
            # With an even number of turns, the leading message has no user
            # turn of its own, so pair it with an empty string.
            history.append(('', inputs[0]))
            for i in range(len(inputs) // 2 - 1):
                history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
        input_msgs = inputs[-1]

        response, _ = self.model.chat(
            self.tokenizer, input_msgs, history=history, system=self.system_prompt)
        return response, offset