Spaces:
Sleeping
Sleeping
File size: 7,484 Bytes
dc14176 a4c3bcc dc14176 7f79d8b 985ad3e dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b 47ed12b 7f79d8b dc14176 47ed12b dc14176 47ed12b 7f79d8b dc14176 47ed12b dc14176 a4c3bcc dc14176 47ed12b dc14176 47ed12b dc14176 47ed12b dc14176 a4c3bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
class VietnameseChatbot:
def __init__(self, model_name='intfloat/multilingual-e5-small'):
"""
Initialize the Vietnamese chatbot with pre-loaded model and conversation data
"""
# Load pre-trained model and tokenizer
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
# Load comprehensive conversation dataset
self.conversation_data = self._load_conversation_data()
# Pre-compute embeddings for faster response generation
print("Pre-computing conversation embeddings...")
self.conversation_embeddings = self._precompute_embeddings()
def _load_conversation_data(self):
"""
Load a comprehensive conversation dataset
"""
return [
# Greeting conversations
{"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
{"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
{"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
# Identity and purpose
{"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
{"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
# Small talk
{"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
{"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
# Weather and time
{"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
{"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
# Assistance offers
{"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
{"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
# Farewell
{"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
{"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
]
def _precompute_embeddings(self):
"""
Pre-compute embeddings for all conversation queries
"""
embeddings = []
for item in self.conversation_data:
embedding = self.embed_text(item['query'])
if embedding is not None:
embeddings.append(embedding[0])
return np.array(embeddings)
def embed_text(self, text):
"""
Generate embeddings for input text
"""
try:
# Tokenize and generate embeddings
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = self.model(**inputs)
# Mean pooling
embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
return embeddings.numpy()
except Exception as e:
print(f"Embedding error: {e}")
return None
def mean_pooling(self, model_output, attention_mask):
"""
Perform mean pooling on model output
"""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_response(self, user_query):
"""
Find the most similar response from conversation data
"""
try:
# Embed user query
query_embedding = self.embed_text(user_query)
if query_embedding is None:
return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
# Calculate cosine similarities
similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
# Find most similar response
best_match_index = np.argmax(similarities)
# Return response if similarity is above threshold
if similarities[best_match_index] > 0.5:
return self.conversation_data[best_match_index]['response']
return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
except Exception as e:
print(f"Response generation error: {e}")
return "Đã xảy ra lỗi. Xin vui lòng thử lại."
def main():
st.set_page_config(
page_title="Trợ Lý AI Tiếng Việt",
page_icon="🤖",
)
st.title("🤖 Trợ Lý AI Tiếng Việt")
st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
# Initialize chatbot (this will pre-load models and embeddings)
chatbot = VietnameseChatbot()
# Chat history in session state
if 'messages' not in st.session_state:
st.session_state.messages = []
# Sidebar for additional information
with st.sidebar:
st.header("Về Trợ Lý AI")
st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# User input
if prompt := st.chat_input("Hãy nói gì đó..."):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get chatbot response
response = chatbot.get_response(prompt)
# Display chatbot response
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant message to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main() |