import torch
import os

# os.environ["TRANSFORMERS_CACHE"] = "./.cache"

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer
from vigogne.preprocess import generate_inference_chat_prompt


class CaesarFrenchLLM:
    """Thin chat wrapper around the bofenghuang/vigogne-2-7b-chat model."""

    def __init__(self) -> None:
        self.history = []

        base_model_name_or_path = "bofenghuang/vigogne-2-7b-chat"

        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_name_or_path,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            torch_dtype=torch.float32,
            device_map="auto",
            # load_in_8bit=True,
            # trust_remote_code=True,
            # low_cpu_mem_usage=True,
        )

        # Optional LoRA adapter (requires `from peft import PeftModel`):
        # lora_model_name_or_path = ""
        # self.model = PeftModel.from_pretrained(self.model, lora_model_name_or_path)

        self.model.eval()

        if torch.__version__ >= "2":
            self.model = torch.compile(self.model)

        # Streams decoded tokens to stdout as they are generated.
        # (`timeout` only applies to TextIteratorStreamer, so it is not passed here.)
        self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

    def infer(
        self,
        user_query,
        temperature=0.1,
        top_p=1.0,
        top_k=0,
        max_new_tokens=512,
        **kwargs,
    ):
        # `user_query` is a list of [user, assistant] turns, the last one with an empty answer.
        prompt = generate_inference_chat_prompt(user_query, tokenizer=self.tokenizer)
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.model.device)
        input_length = input_ids.shape[1]

        generated_outputs = self.model.generate(
            input_ids=input_ids,
            generation_config=GenerationConfig(
                temperature=temperature,
                do_sample=temperature > 0.0,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs,
            ),
            streamer=self.streamer,
            return_dict_in_generate=True,
        )
        # Drop the prompt tokens and decode only the newly generated continuation.
        generated_tokens = generated_outputs.sequences[0, input_length:]
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return generated_text

    def chat(self, user_input, **kwargs):
        print(f">> <|user|>: {user_input}")
        print(">> <|assistant|>: ", end="")

        model_response = self.infer([*self.history, [user_input, ""]], **kwargs)
        self.history.append([user_input, model_response])
        return self.history[-1][1]
        # print(f">> <|assistant|>: {self.history[-1][1]}")
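

# Minimal usage sketch (an assumption, not part of the original file): instantiate the
# wrapper and hold a short multi-turn conversation. Each `chat` call appends
# [user_input, model_response] to `self.history`, so follow-up questions are answered
# with the full conversation as context; the response is also streamed to stdout.
if __name__ == "__main__":
    llm = CaesarFrenchLLM()

    # Example prompts are illustrative only.
    llm.chat("Bonjour, peux-tu te présenter ?")
    llm.chat("Donne-moi trois idées de randonnées près de Grenoble.", temperature=0.3)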