# ChatWorld / src/Models/models.py
# Author: JiangYH. Uploaded using huggingface_hub (commit 6f179e7, verified).
import os
from string import Template
from typing import Dict, List, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from zhipuai import ZhipuAI
class GLM:
    """Local ChatGLM3-family chat model loaded through Hugging Face transformers.

    Both the tokenizer and the model are loaded with ``trust_remote_code=True``,
    so generation goes through the checkpoint's own ``chat`` method.
    """

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        # device_map="auto" lets transformers decide where the weights are placed.
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )
        # Put the model in evaluation mode; it is only used for inference.
        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Render chat messages into a ChatGLM3-style prompt string.

        Each message dict must provide ``role`` and ``content`` keys and is
        rendered as ``<|role|>`` on one line followed by ``content`` on the
        next, e.g. ``[{"role": "user", "content": "Hello"}]`` becomes
        ``"<|user|>\\nHello\\n"``. Messages are concatenated in order.

        Args:
            messages: Iterable of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The concatenated prompt text ("" for no messages).

        Raises:
            KeyError: if a message lacks a ``role`` or ``content`` key.
        """
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(message) for message in messages)

    def get_response(
        self,
        message: Union[str, list[dict[str, str]]],
        history: Union[List[Dict[str, str]], None] = None,
    ):
        """Generate a reply with the model's ``chat`` method and print it.

        Args:
            message: Either a single query string, or a whole conversation as a
                list of ``{"role": ..., "content": ...}`` dicts whose last entry
                holds the query to answer.
            history: Prior turns passed as context when ``message`` is a string.
                Ignored when ``message`` is a list, because the earlier entries
                of that list are used as the history instead. The caller's list
                object is not passed to the model directly.

        Returns:
            The generated response text.

        Raises:
            ValueError: if ``message`` is an empty list.
            TypeError: if ``message`` is neither a string nor a list.
        """
        if isinstance(message, str):
            query = message
            # Pass a copy so the caller's list is safe even if the model's
            # chat() appends to the history it receives.
            context = list(history) if history is not None else []
        elif isinstance(message, list):
            if not message:
                raise ValueError("message list must contain at least one entry")
            query = message[-1]["content"]
            # Slicing already produces a fresh list.
            context = message[:-1]
        else:
            raise TypeError(
                f"message must be a str or a list of dicts, got {type(message).__name__}"
            )
        response, _ = self.client.chat(self.tokenizer, query, history=context)
        print(response)
        return response
class GLM_api:
    """Chat client for Zhipu AI's hosted GLM models, called via the ZhipuAI SDK."""

    def __init__(self, model_name="glm-4"):
        """Create a ZhipuAI client that will query ``model_name``.

        The API key is read from the ``ZHIPU_API_KEY`` environment variable.
        NOTE(review): if that variable is unset, ``api_key=None`` is passed to
        ZhipuAI; how the SDK handles that is not visible here.

        Args:
            model_name: Model identifier sent with every request.
        """
        api_key = os.environ.get("ZHIPU_API_KEY")
        self.client = ZhipuAI(api_key=api_key)
        self.model = model_name

    def chat(self, message):
        """Send a conversation to the hosted model and return the reply text.

        Args:
            message: A list of ``{"role": ..., "content": ...}`` dicts, or a
                plain string, which is sent as a single user message (matching
                the string form accepted by ``GLM.get_response``).

        Returns:
            The content of the first choice in the response, or the sentinel
            string "模型连接失败" ("model connection failed") if the request
            raises.
        """
        if isinstance(message, str):
            message = [{"role": "user", "content": message}]
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=message
            )
        except Exception as e:
            # Deliberate best-effort boundary: report the error and hand back
            # a sentinel reply instead of propagating API/network failures.
            print(e)
            return "模型连接失败"
        return response.choices[0].message.content