# video_bot_999 / chatbot.py
# youngtsai's picture
# update
# 8080a34
# raw
# history blame contribute delete
# No virus
# 6.57 kB
import gradio as gr
import json
import requests
class Chatbot:
    """Chat front end that forwards a conversation to one configured LLM backend.

    The backend is chosen by ``config['ai_model_name']``:

    * ``'openai'``       -> HTTP call to the Jutor chat endpoint (``chat_with_jutor``)
    * ``'groq_llama3'``  -> ``ai_client.chat.completions.create`` (``chat_with_groq``)
    * ``'groq_mixtral'`` -> ``ai_client.chat.completions.create`` (``chat_with_groq``)
    * ``'claude3'``      -> ``ai_client.invoke_model`` (``chat_with_claude3``)
    """

    # Seconds to wait on the Jutor HTTP endpoint; without a timeout
    # ``requests.post`` can block forever on a stalled connection.
    REQUEST_TIMEOUT_SECONDS = 60

    # How many of the most recent (user, assistant) turns are replayed.
    MAX_HISTORY_TURNS = 10

    # Appended to every non-empty user message to steer language and formatting.
    REPLY_STYLE_SUFFIX = (
        "\n (請一定要用繁體中文回答 zh-TW,並用台灣人的禮貌口語表達,回答時不要特別說明這是台灣人的語氣,不要提到「台灣腔」,不用提到「逐字稿」這個詞,用「內容」代替),回答時如果有用到數學式,請用數學符號代替純文字(Latex 用 $ 字號 render)"
    )

    # Service name -> model identifier passed to the groq client.
    GROQ_MODELS = {
        "groq_llama3": "llama3-70b-8192",
        "groq_mixtral": "mixtral-8x7b-32768",
    }

    def __init__(self, config):
        """Read all settings from *config* (a dict-like object with ``.get``).

        ``transcript`` and ``key_moments`` may each be a JSON string or an
        already-decoded list of dicts; both are stored as trimmed JSON text.
        """
        self.video_id = config.get('video_id')
        self.content_subject = config.get('content_subject')
        self.content_grade = config.get('content_grade')
        self.jutor_chat_key = config.get('jutor_chat_key')
        self.transcript_text = self.get_transcript_text(config.get('transcript'))
        self.key_moments_text = self.get_key_moments_text(config.get('key_moments'))
        self.ai_model_name = config.get('ai_model_name')
        self.ai_client = config.get('ai_client')
        self.instructions = config.get('instructions')

    @staticmethod
    def _load_json(data):
        """Decode *data* when it is a JSON string; ``None``/``""`` become ``[]``."""
        if data is None or data == "":
            return []
        if isinstance(data, str):
            return json.loads(data)
        return data

    @staticmethod
    def _without_keys(entries, dropped):
        """Return copies of the dicts in *entries* with the *dropped* keys removed.

        Copies are built instead of popping in place so that data handed in
        by the caller is never modified.
        """
        return [
            {key: value for key, value in entry.items() if key not in dropped}
            for entry in entries
        ]

    def get_transcript_text(self, transcript_data):
        """Serialize the transcript to JSON text with every ``end_time`` removed.

        Args:
            transcript_data: JSON string or list of dicts; ``None`` or an
                empty string is treated as an empty transcript.

        Returns:
            JSON text (non-ASCII characters kept verbatim).
        """
        entries = self._without_keys(self._load_json(transcript_data), {'end_time'})
        return json.dumps(entries, ensure_ascii=False)

    def get_key_moments_text(self, key_moments_data):
        """Serialize key moments to JSON text without ``images``/``end``/``transcript``.

        Args:
            key_moments_data: JSON string or list of dicts; ``None`` or an
                empty string is treated as no key moments.

        Returns:
            JSON text (non-ASCII characters kept verbatim).
        """
        moments = self._without_keys(
            self._load_json(key_moments_data),
            {'images', 'end', 'transcript'},
        )
        return json.dumps(moments, ensure_ascii=False)

    def chat(self, user_message, chat_history):
        """Answer *user_message* given *chat_history*; never raises.

        Any failure while preparing or sending the request is printed and
        replaced by a fixed apology string, so the caller always gets text.
        """
        try:
            messages = self.prepare_messages(chat_history, user_message)
            response_text = self.chat_with_service(
                self.ai_model_name, self.instructions, messages
            )
        except Exception as e:  # boundary handler: degrade to a friendly reply
            print(f"Error: {e}")
            response_text = "學習精靈有點累,請稍後再試!"
        return response_text

    def prepare_messages(self, chat_history, user_message):
        """Build the role/content message list sent to the model.

        Args:
            chat_history: iterable of ``(user_msg, assistant_msg)`` pairs, or
                ``None``. Only the last ``MAX_HISTORY_TURNS`` pairs are used;
                empty halves of a pair are skipped.
            user_message: the new user input; when non-empty it is appended
                with ``REPLY_STYLE_SUFFIX`` attached.

        Returns:
            A new list of ``{"role": ..., "content": ...}`` dicts.
        """
        messages = []
        if chat_history is not None:
            if len(chat_history) > self.MAX_HISTORY_TURNS:
                chat_history = chat_history[-self.MAX_HISTORY_TURNS:]
            for user_msg, assistant_msg in chat_history:
                if user_msg:
                    messages.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": assistant_msg})
        if user_message:
            messages.append(
                {"role": "user", "content": user_message + self.REPLY_STYLE_SUFFIX}
            )
        return messages

    def chat_with_service(self, service_type, system_prompt, messages):
        """Dispatch to the backend named by *service_type*.

        Raises:
            gr.Error: if *service_type* is not one of the supported names.
        """
        if service_type == 'openai':
            return self.chat_with_jutor(system_prompt, messages)
        if service_type in self.GROQ_MODELS:
            return self.chat_with_groq(service_type, system_prompt, messages)
        if service_type == 'claude3':
            return self.chat_with_claude3(system_prompt, messages)
        raise gr.Error("不支持的服务类型")

    def chat_with_jutor(self, system_prompt, messages):
        """POST the conversation to the Jutor chat endpoint and return the reply text.

        The system prompt is sent as the first message. *messages* itself is
        left untouched.

        Raises:
            requests.RequestException: on timeout, connection failure or a
                non-2xx HTTP status.
        """
        api_endpoint = "https://ci-live-feat-video-ai-dot-junyiacademy.appspot.com/api/v2/jutor/hf-chat"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": self.jutor_chat_key,
        }
        # Alternative model previously used here: "gpt-3.5-turbo-0125".
        model = "gpt-4o"
        print("======model======")
        print(model)
        data = {
            "data": {
                "messages": [{"role": "system", "content": system_prompt}, *messages],
                "max_tokens": 512,
                "temperature": 0.9,
                "model": model,
                "stream": False,
            }
        }
        response = requests.post(
            api_endpoint,
            headers=headers,
            data=json.dumps(data),
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )
        # Surface HTTP failures directly instead of a confusing KeyError below.
        response.raise_for_status()
        response_data = response.json()
        return response_data['data']['choices'][0]['message']['content'].strip()

    def chat_with_groq(self, model_name, system_prompt, messages):
        """Send the conversation through ``self.ai_client`` (groq-style API).

        The system prompt is placed at the front of a copy of *messages*, so
        the caller's list is not modified.

        Args:
            model_name: service name, looked up in ``GROQ_MODELS``.
        """
        model = self.GROQ_MODELS.get(model_name)
        print("======model======")
        print(model)
        request_payload = {
            "model": model,
            "messages": [{"role": "system", "content": system_prompt}, *messages],
            "max_tokens": 500,  # generous cap; adjust as needed
        }
        groq_client = self.ai_client
        response = groq_client.chat.completions.create(**request_payload)
        return response.choices[0].message.content.strip()

    def chat_with_claude3(self, system_prompt, messages):
        """Invoke Claude 3 through ``self.ai_client.invoke_model`` and return the reply text.

        Raises:
            ValueError: if *system_prompt* is ``None``, empty or whitespace only.
        """
        if not system_prompt or not system_prompt.strip():
            raise ValueError("System prompt cannot be empty")
        # Alternative model previously used here:
        # "anthropic.claude-3-haiku-20240307-v1:0".
        model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
        print("======model_id======")
        print(model_id)
        kwargs = {
            "modelId": model_id,
            "contentType": "application/json",
            "accept": "application/json",
            "body": json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 500,
                "system": system_prompt,
                "messages": messages,
            }),
        }
        # Call the model and decode the JSON body of the response.
        bedrock_client = self.ai_client
        response = bedrock_client.invoke_model(**kwargs)
        response_body = json.loads(response.get('body').read())
        return response_body.get('content')[0].get('text').strip()