Spaces:
Sleeping
Sleeping
File size: 9,200 Bytes
dc14176 a4c3bcc dc14176 7f79d8b 985ad3e 80eee0f dc14176 7f79d8b dc14176 7f79d8b dc14176 80eee0f dc14176 7f79d8b 289913c 7f79d8b 289913c 7f79d8b dc14176 7f79d8b 80eee0f 289913c 7f79d8b 289913c 80eee0f 7f79d8b 289913c 7f79d8b 289913c 7f79d8b 289913c 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 7f79d8b dc14176 80eee0f dc14176 7f79d8b dc14176 7f79d8b 47ed12b 80eee0f 47ed12b dc14176 47ed12b 7f79d8b dc14176 47ed12b dc14176 a4c3bcc dc14176 47ed12b dc14176 47ed12b dc14176 47ed12b dc14176 a4c3bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
@st.cache_resource
def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'):
"""
Cached function to load model and tokenizer
This ensures the model is loaded only once and reused
"""
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
return tokenizer, model
class VietnameseChatbot:
def __init__(self, model_name='intfloat/multilingual-e5-small'):
"""
Initialize the Vietnamese chatbot with pre-loaded model and conversation data
"""
# Load pre-trained model and tokenizer using cached function
self.tokenizer, self.model = load_model_and_tokenizer(model_name)
# Load comprehensive conversation dataset
self.conversation_data = self._load_conversation_data()
# Pre-compute embeddings for faster response generation
print("Pre-computing conversation embeddings...")
self.conversation_embeddings = self._compute_embeddings()
def _load_conversation_data(self):
"""
Load a comprehensive conversation dataset
"""
return [
# Greeting conversations
{"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
{"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
{"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
# Identity and purpose
{"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
{"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
# Small talk
{"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
{"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
# Weather and time
{"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
{"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
# Assistance offers
{"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
{"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
# Farewell
{"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
{"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
]
@st.cache_data
def _compute_embeddings(queries):
"""
Pre-compute embeddings for conversation queries
Cached to avoid recomputing on every run
"""
def embed_single_text(text, tokenizer, model):
try:
# Tokenize and generate embeddings
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = model(**inputs)
# Mean pooling
token_embeddings = model_output[0]
input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return embeddings.numpy()[0]
except Exception as e:
print(f"Embedding error: {e}")
return None
# Import these arguments to make the function self-contained
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-small')
model = AutoModel.from_pretrained('intfloat/multilingual-e5-small', torch_dtype=torch.float16)
embeddings = []
for query in queries:
embedding = embed_single_text(query['query'], tokenizer, model)
if embedding is not None:
embeddings.append(embedding)
return np.array(embeddings)
def embed_text(self, text):
"""
Generate embeddings for input text
"""
try:
# Tokenize and generate embeddings
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = self.model(**inputs)
# Mean pooling
embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
return embeddings.numpy()
except Exception as e:
print(f"Embedding error: {e}")
return None
def mean_pooling(self, model_output, attention_mask):
"""
Perform mean pooling on model output
"""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_response(self, user_query):
"""
Find the most similar response from conversation data
"""
try:
# Embed user query
query_embedding = self.embed_text(user_query)
if query_embedding is None:
return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
# Calculate cosine similarities
similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
# Find most similar response
best_match_index = np.argmax(similarities)
# Return response if similarity is above threshold
if similarities[best_match_index] > 0.5:
return self.conversation_data[best_match_index]['response']
return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
except Exception as e:
print(f"Response generation error: {e}")
return "Đã xảy ra lỗi. Xin vui lòng thử lại."
@st.cache_resource
def initialize_chatbot():
"""
Cached function to initialize the chatbot
This ensures the chatbot is created only once
"""
return VietnameseChatbot()
def main():
st.set_page_config(
page_title="Trợ Lý AI Tiếng Việt",
page_icon="🤖",
)
st.title("🤖 Trợ Lý AI Tiếng Việt")
st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
# Initialize chatbot using cached initialization
chatbot = initialize_chatbot()
# Chat history in session state
if 'messages' not in st.session_state:
st.session_state.messages = []
# Sidebar for additional information
with st.sidebar:
st.header("Về Trợ Lý AI")
st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# User input
if prompt := st.chat_input("Hãy nói gì đó..."):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get chatbot response
response = chatbot.get_response(prompt)
# Display chatbot response
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant message to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main() |