from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import time
import torch


class SweetCommander:
    """Character chat wrapper around a causal language model.

    The instance loads a tokenizer and model once, reads a prompt template
    from disk, and is then callable: each call formats the template, asks the
    model for a continuation and returns the character's reply as text.
    """

    def __init__(
        self,
        path: str = "BlueDice/Katakuri-350m",
        template_path: str = "character_card.txt",
    ) -> None:
        """Load the tokenizer, the model and the prompt template.

        Args:
            path: Model identifier or directory passed to ``from_pretrained``
                for both the tokenizer and the model.
            template_path: Text file holding the prompt template. It is
                formatted with the ``char_name``, ``user_name`` and
                ``user_input`` keyword fields on every call.

        Raises:
            OSError: If the template file cannot be opened or read.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            low_cpu_mem_usage=True,
            trust_remote_code=False,
            torch_dtype=torch.float32,
        )
        # Context manager closes the handle deterministically; the explicit
        # encoding keeps the template independent of the platform default.
        with open(template_path, "r", encoding="utf-8") as template_file:
            self.default_template = template_file.read()
        self.star_line = "***********************************************************"

    @staticmethod
    def _extract_reply(generated_text: str, user_name: str) -> str:
        """Return the text before the first ``"<user_name>:"`` marker, stripped."""
        reply, _, _ = generated_text.partition(f"{user_name}:")
        return reply.strip()

    def __call__(self, char_name: str, user_name: str, user_input: str) -> str:
        """Generate the character's reply to ``user_input``.

        The prompt, the reply and the elapsed time are printed to stdout,
        separated by ``self.star_line``.

        Args:
            char_name: Name of the character; also used as the speaker cue
                (``"\\n<char_name>:"``) appended to the prompt.
            user_name: Name of the user; the reply is cut at the first
                ``"<user_name>:"`` the model produces.
            user_input: The user's message, inserted into the template.

        Returns:
            The generated reply with surrounding whitespace removed. May be
            an empty string if the model produces nothing before the
            ``"<user_name>:"`` marker.
        """
        started = time.perf_counter()
        prompt = self.default_template.format(
            char_name=char_name,
            user_name=user_name,
            user_input=user_input,
        )
        print(self.star_line)
        print(prompt)

        inputs = self.tokenizer(prompt + f"\n{char_name}:", return_tensors="pt")
        input_ids = inputs["input_ids"]
        # NOTE(review): temperature/top_p/top_k are sampling parameters; no
        # do_sample flag is passed here, so whether they take effect depends
        # on the model's generation defaults - confirm the intended mode.
        encoded_output = self.model.generate(
            input_ids,
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=50,
            temperature=0.5,
            top_p=0.9,
            top_k=0,
            repetition_penalty=1.1,
            pad_token_id=50256,
            num_return_sequences=1,
        )

        # Decode only the tokens produced after the prompt. Removing the
        # prompt by string replacement breaks whenever decoding does not
        # reproduce the prompt text exactly, and the follow-up split on
        # "<char_name>:" could then return prompt text or raise IndexError.
        new_tokens = encoded_output[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        decoded_output = self._extract_reply(generated_text, user_name)
        # Further clean-up (dropping *action* spans, trimming to the last
        # sentence terminator) was disabled in the original code and is
        # intentionally not applied here.

        print(self.star_line)
        print("Response:", decoded_output)
        print("Eval time:", time.perf_counter() - started)
        print(self.star_line)
        return decoded_output