import os.path as osp
from typing import List, Optional, Tuple

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig


class QwenWrapper:

    def __init__(self,
                 model_path: str = 'Qwen/Qwen-7B-Chat-Int4',
                 system_prompt: Optional[str] = None,
                 **model_kwargs):
        self.system_prompt = system_prompt
        self.model_path = model_path
        # model_path must be a local directory or a Hugging Face repo id ('org/name').
        assert osp.exists(model_path) or len(model_path.split('/')) == 2

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, device_map='auto')
        try:
            # Some checkpoints ship a generation_config.json; fall back to the
            # model's default generation settings if it is missing.
            model.generation_config = GenerationConfig.from_pretrained(
                model_path, trust_remote_code=True)
        except Exception:
            pass

        self.model = model.eval()
        self.context_length = model.config.seq_length
        # Tokens reserved for the model's reply when checking prompt length.
        self.answer_buffer = 192
        for k, v in model_kwargs.items():
            print(f'The following kwarg was passed but is not used to initialize the model: {k}={v}.')

    def length_ok(self, inputs: List[str]) -> bool:
        # True if the prompt plus the reserved answer budget fits in the context window.
        tot = sum(len(self.tokenizer.encode(s)) for s in inputs)
        return tot + self.answer_buffer < self.context_length

    def chat(self, full_inputs: List[str], offset: int = 0) -> Tuple[str, int]:
        # Drop the oldest turns one at a time until the conversation fits the
        # context window, then answer from the truncated conversation.
        inputs = full_inputs[offset:]
        if not self.length_ok(inputs):
            return self.chat(full_inputs, offset + 1)

        # Rebuild Qwen's (user, assistant) history pairs; the last item in
        # `inputs` is always the new user query.
        history = []
        if len(inputs) % 2 == 1:
            for i in range(len(inputs) // 2):
                history.append((inputs[2 * i], inputs[2 * i + 1]))
        else:
            # Even length: the leading assistant turn has no matching user
            # turn, so pair it with an empty query.
            history.append(('', inputs[0]))
            for i in range(len(inputs) // 2 - 1):
                history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
        input_msgs = inputs[-1]

        response, _ = self.model.chat(
            self.tokenizer, input_msgs, history=history, system=self.system_prompt)
        return response, offset
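

# Minimal usage sketch, assuming the default 'Qwen/Qwen-7B-Chat-Int4' checkpoint
# (and its auto-gptq dependency) is available locally or on the Hugging Face Hub.
# The conversation is a flat list of alternating user/assistant strings that
# ends with the new user query, matching the pairing logic in `chat` above.
if __name__ == '__main__':
    wrapper = QwenWrapper(system_prompt='You are a helpful assistant.')
    conversation = [
        'What is the capital of France?',    # user turn
        'The capital of France is Paris.',   # assistant turn
        'And what is its population?',       # new user query
    ]
    response, offset = wrapper.chat(conversation)
    print(response)
    # A nonzero offset means the oldest turns were dropped to fit the context window.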