# -*- coding: utf-8 -*-
"""LLM: a thin chat wrapper over textgen's ChatGlmModel and LlamaModel."""
from textgen import ChatGlmModel, LlamaModel
from loguru import logger

class LLM(object):
    """Thin wrapper exposing a unified chat interface over textgen models."""

    def __init__(
            self,
            gen_model_type: str = "chatglm",
            gen_model_name_or_path: str = "THUDM/chatglm-6b-int4",
            lora_model_name_or_path: str = None,
    ):
        self.model_type = gen_model_type
        if gen_model_type == "chatglm":
            self.gen_model = ChatGlmModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        elif gen_model_type == "llama":
            self.gen_model = LlamaModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        else:
            raise ValueError('gen_model_type must be "chatglm" or "llama".')
        self.history = None

    def generate_answer(self, query_str, context_str, history=None, max_length=1024, prompt_template=None):
        """Generate an answer for query_str grounded in the retrieved context_str."""
        if self.model_type == "t5":
            # Text2text path: only reachable if model_type is set to "t5" externally,
            # since __init__ currently accepts only "chatglm" and "llama".
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            return response, history
        if prompt_template is None:
            # Minimal fallback so a missing template does not raise; callers are
            # expected to pass their own prompt_template.
            prompt_template = "{context_str}\n\n{query_str}"
        prompt = prompt_template.format(context_str=context_str, query_str=query_str)
        response, out_history = self.gen_model.chat(prompt, history, max_length=max_length)
        return response, out_history

    def chat(self, query_str, history=None, max_length=1024):
        """Free-form chat without retrieved context."""
        if self.model_type == "t5":
            # Same text2text path as in generate_answer; history is passed through unchanged.
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            logger.debug(response)
            return response, history
        response, out_history = self.gen_model.chat(query_str, history, max_length=max_length)
        return response, out_history
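

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default ChatGLM checkpoint above and the
    # textgen ChatGlmModel.chat() interface used in this file; the prompt template
    # and questions below are illustrative placeholders, not part of the original module.
    llm = LLM(gen_model_type="chatglm", gen_model_name_or_path="THUDM/chatglm-6b-int4")

    answer, history = llm.chat("Briefly introduce yourself.")
    logger.info(answer)

    template = (
        "Answer the question based on the context.\n"
        "Context: {context_str}\n"
        "Question: {query_str}"
    )
    answer, history = llm.generate_answer(
        query_str="Which model families does this wrapper support?",
        context_str="The LLM wrapper loads either a ChatGLM or a LLaMA model via textgen.",
        history=history,
        prompt_template=template,
    )
    logger.info(answer)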