ChatWorld / src /Models /models.py
JiangYH's picture
Upload folder using huggingface_hub
6f179e7 verified
raw
history blame contribute delete
No virus
2.1 kB
import os
from string import Template
from typing import Dict, List, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from zhipuai import ZhipuAI
class GLM:
    """Locally hosted chat model loaded through ``transformers``.

    Holds a tokenizer and a causal-LM checkpoint, and exposes a helper that
    renders chat messages into a prompt string plus ``get_response``, which
    delegates to the loaded model's ``chat`` method.
    """

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        """Load the tokenizer and model weights for ``model_name``.

        Both loaders are called with ``trust_remote_code=True``, and the
        model is loaded with ``device_map="auto"``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )
        # Keep the model in evaluation mode; it is only used for inference.
        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Render chat messages as one prompt string.

        Args:
            messages: Iterable of mappings with ``role`` and ``content``
                keys, e.g.
                ``[{'role': 'user', 'content': 'Teacher: please introduce yourself'}]``.

        Returns:
            The concatenation of one ``<|role|>`` header line followed by the
            content line for every message; an empty string when there are
            no messages.

        Raises:
            KeyError: If a message lacks the ``role`` or ``content`` key.
        """
        # Target layout (ChatGLM3 style):
        #   <|system|>
        #   You are ChatGLM3, a large language model trained by Zhipu.AI. ...
        #   <|user|>
        #   Hello
        #   <|assistant|>
        #   Hello, I'm ChatGLM3. What can I assist you today?
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(message) for message in messages)

    def get_response(
        self,
        message: Union[str, List[Dict[str, str]]],
        history: Union[List[Dict[str, str]], None] = None,
    ) -> str:
        """Ask the model for a reply and return it.

        Args:
            message: Either the user's query as a plain string, or a list of
                ``{"role": ..., "content": ...}`` messages whose last entry
                is the current query and whose earlier entries are the
                conversation so far.
            history: Prior conversation turns, used only when ``message`` is
                a string. A list-form ``message`` carries its own history,
                so this argument is ignored in that case.

        Returns:
            The model's reply text. The reply is also printed to stdout.

        Raises:
            TypeError: If ``message`` is neither a ``str`` nor a ``list``.
            ValueError: If ``message`` is an empty list.
        """
        if isinstance(message, str):
            if history is None:
                response, _ = self.client.chat(self.tokenizer, message)
            else:
                # Pass a copy so the model cannot mutate the caller's list.
                response, _ = self.client.chat(
                    self.tokenizer, message, history=list(history)
                )
        elif isinstance(message, list):
            if not message:
                raise ValueError("message list must contain at least one entry")
            # The last entry is the query; everything before it is history.
            response, _ = self.client.chat(
                self.tokenizer, message[-1]["content"], history=message[:-1]
            )
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise TypeError(
                "message must be a str or a list of role/content dicts, "
                f"got {type(message).__name__}"
            )
        print(response)
        return response
class GLM_api:
    """Small wrapper around the ZhipuAI chat-completions client."""

    def __init__(self, model_name="glm-4"):
        """Build the API client and remember which model to query.

        The API key is read from the ``ZHIPU_API_KEY`` environment variable;
        ``None`` is passed to the client when the variable is unset.
        """
        api_key = os.getenv("ZHIPU_API_KEY")
        self.client = ZhipuAI(api_key=api_key)
        self.model = model_name

    def chat(self, message):
        """Send ``message`` to the configured model and return the reply text.

        Args:
            message: Passed unchanged as the ``messages`` argument of the
                completions call.

        Returns:
            The content of the first choice's message. If the request raises
            any exception, the exception is printed and the fixed fallback
            string is returned instead.
        """
        completions = self.client.chat.completions
        try:
            completion = completions.create(messages=message, model=self.model)
        except Exception as err:
            # Best-effort boundary: report the failure and return a fallback
            # ("model connection failed") instead of propagating.
            print(err)
            return "模型连接失败"
        else:
            first_choice = completion.choices[0]
            return first_choice.message.content