import os.path as osp
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from typing import List, Optional, Tuple, Union


class QwenWrapper:
    """Thin chat wrapper around a Qwen chat checkpoint from Hugging Face."""

    def __init__(self, model_path: str = 'Qwen/Qwen-7B-Chat-Int4',
                 system_prompt: Optional[str] = None, **model_kwargs):
        self.system_prompt = system_prompt
        self.model_path = model_path
        # Accept either a local checkpoint directory or an 'org/name' hub id.
        assert osp.exists(model_path) or len(model_path.split('/')) == 2
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, device_map='auto')
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                model_path, trust_remote_code=True)
        except Exception:
            # Fall back to the model's default generation settings.
            pass
        self.model = model.eval()
        self.context_length = model.config.seq_length
        # Reserve room in the context window for the generated answer.
        self.answer_buffer = 192
        for k, v in model_kwargs.items():
            print(f'Argument {k}={v} was passed but is not used to initialize the model.')

    def length_ok(self, inputs: List[str]) -> bool:
        """Check that all turns plus the answer buffer fit in the context window."""
        tot = sum(len(self.tokenizer.encode(s)) for s in inputs)
        return tot + self.answer_buffer < self.context_length

    def chat(self, full_inputs: Union[str, List[str]], offset: int = 0) -> Tuple[str, int]:
        """Run one chat turn, dropping the oldest turns until the prompt fits."""
        inputs = full_inputs[offset:]
        if not self.length_ok(inputs):
            # Too long: drop the oldest remaining turn and retry.
            return self.chat(full_inputs, offset + 1)
        history = []
        if len(inputs) % 2 == 1:
            # Odd length: alternating (user, assistant) pairs, then the new query.
            for i in range(len(inputs) // 2):
                history.append((inputs[2 * i], inputs[2 * i + 1]))
        else:
            # Even length: the conversation opens with an assistant message,
            # so pair it with an empty user turn.
            history.append(('', inputs[0]))
            for i in range(len(inputs) // 2 - 1):
                history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
        input_msgs = inputs[-1]
        response, _ = self.model.chat(
            self.tokenizer, input_msgs, history=history, system=self.system_prompt)
        return response, offset
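

# A minimal usage sketch, assuming the default 'Qwen/Qwen-7B-Chat-Int4'
# checkpoint is reachable (locally or via the hub) and that the turn list
# alternates user/assistant messages, ending with the pending user query.
if __name__ == '__main__':
    wrapper = QwenWrapper(system_prompt='You are a helpful assistant.')
    # Odd-length list: one completed (user, assistant) pair plus the new query.
    turns = [
        'Hello!',
        'Hi there! How can I help you?',
        'Summarize the plot of Hamlet in one sentence.',
    ]
    response, offset = wrapper.chat(turns)
    print(f'Dropped {offset} oldest turn(s). Response: {response}')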