from typing import Optional

from loguru import logger
from textgen import ChatGlmModel, LlamaModel
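
# Fallback template used when generate_answer() is called without a prompt_template,
# so that prompt_template.format(...) below cannot fail on None. The placeholder
# names (context_str, query_str) match the format() call; the wording itself is an
# illustrative assumption, not an upstream default.
DEFAULT_PROMPT_TEMPLATE = (
    "Based on the following context, answer the question.\n\n"
    "Context:\n{context_str}\n\n"
    "Question: {query_str}\n"
    "Answer:"
)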

class LLM(object):
    """Thin wrapper that selects a textgen chat model and exposes a uniform chat API."""

    def __init__(
            self,
            gen_model_type: str = "chatglm",
            gen_model_name_or_path: str = "THUDM/chatglm-6b-int4",
            lora_model_name_or_path: Optional[str] = None,
    ):
        self.model_type = gen_model_type
        if gen_model_type == "chatglm":
            self.gen_model = ChatGlmModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        elif gen_model_type == "llama":
            self.gen_model = LlamaModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        else:
            raise ValueError('gen_model_type must be "chatglm" or "llama".')
        self.history = None
    def generate_answer(self, query_str, context_str, history=None, max_length=1024, prompt_template=None):
        """Generate an answer for query_str, grounded in the retrieved context_str."""
        if self.model_type == "t5":
            # This branch treats self.gen_model as a pipeline-style callable; note that
            # the constructor above only builds chatglm/llama chat models, so it is
            # reachable only if self.model_type is set externally.
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            return response, history
        if prompt_template is None:
            prompt_template = DEFAULT_PROMPT_TEMPLATE
        prompt = prompt_template.format(context_str=context_str, query_str=query_str)
        response, out_history = self.gen_model.chat(prompt, history, max_length=max_length)
        return response, out_history
    def chat(self, query_str, history=None, max_length=1024):
        """Free-form multi-turn chat without retrieved context."""
        if self.model_type == "t5":
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            logger.debug(response)
            return response, history
        response, out_history = self.gen_model.chat(query_str, history, max_length=max_length)
        return response, out_history
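

# Minimal usage sketch (assumes the chatglm weights named above are available
# locally or on the Hugging Face Hub; the queries and context are illustrative):
if __name__ == "__main__":
    llm = LLM(gen_model_type="chatglm", gen_model_name_or_path="THUDM/chatglm-6b-int4")
    # Retrieval-augmented answer: pass the retrieved passage as context_str.
    answer, history = llm.generate_answer(
        query_str="What license does the project use?",
        context_str="The project is released under the Apache-2.0 license.",
    )
    logger.info(answer)
    # Follow-up turn, reusing the returned history for multi-turn chat.
    reply, history = llm.chat("Summarize that in one sentence.", history=history)
    logger.info(reply)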