IOTraining / app.py.bak.curr
JustKiddo's picture
Rename app.py to app.py.bak.curr
0123b44 verified
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
@st.cache_resource
def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'):
"""
Cached function to load model and tokenizer
This ensures the model is loaded only once and reused
"""
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
return tokenizer, model
class VietnameseChatbot:
def __init__(self, model_name='intfloat/multilingual-e5-small'):
"""
Initialize the Vietnamese chatbot with pre-loaded model and conversation data
"""
# Load pre-trained model and tokenizer using cached function
self.tokenizer, self.model = load_model_and_tokenizer(model_name)
# Load comprehensive conversation dataset
self.conversation_data = self._load_conversation_data()
# Pre-compute embeddings for faster response generation
print("Pre-computing conversation embeddings...")
self.conversation_embeddings = self._compute_embeddings()
def _load_conversation_data(self):
"""
Load a comprehensive conversation dataset
"""
return [
# Greeting conversations
{"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
{"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
{"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
# Identity and purpose
{"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
{"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
# Small talk
{"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
{"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
# Weather and time
{"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
{"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
# Assistance offers
{"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
{"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
# Farewell
{"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
{"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
]
@st.cache_data
def _compute_embeddings(_self): # Add underscore to self parameter
"""
Pre-compute embeddings for conversation queries
Cached to avoid recomputing on every run
"""
def embed_single_text(text, tokenizer, model):
try:
# Tokenize and generate embeddings
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = model(**inputs)
# Mean pooling
token_embeddings = model_output[0]
input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return embeddings.numpy()[0]
except Exception as e:
print(f"Embedding error: {e}")
return None
embeddings = []
for conversation in _self.conversation_data: # Use _self instead of self
embedding = embed_single_text(conversation['query'], _self.tokenizer, _self.model) # Use _self instead of self
if embedding is not None:
embeddings.append(embedding)
return np.array(embeddings)
def embed_text(self, text):
"""
Generate embeddings for input text
"""
try:
# Tokenize and generate embeddings
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = self.model(**inputs)
# Mean pooling
embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
return embeddings.numpy()
except Exception as e:
print(f"Embedding error: {e}")
return None
def mean_pooling(self, model_output, attention_mask):
"""
Perform mean pooling on model output
"""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_response(self, user_query):
"""
Find the most similar response from conversation data
"""
try:
# Embed user query
query_embedding = self.embed_text(user_query)
if query_embedding is None:
return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
# Calculate cosine similarities
similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
# Find most similar response
best_match_index = np.argmax(similarities)
# Return response if similarity is above threshold
if similarities[best_match_index] > 0.5:
return self.conversation_data[best_match_index]['response']
return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
except Exception as e:
print(f"Response generation error: {e}")
return "Đã xảy ra lỗi. Xin vui lòng thử lại."
@st.cache_resource
def initialize_chatbot():
"""
Cached function to initialize the chatbot
This ensures the chatbot is created only once
"""
return VietnameseChatbot()
def main():
st.set_page_config(
page_title="Trợ Lý AI Tiếng Việt",
page_icon="🤖",
)
st.title("🤖 Trợ Lý AI Tiếng Việt")
st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
# Initialize chatbot using cached initialization
chatbot = initialize_chatbot()
# Chat history in session state
if 'messages' not in st.session_state:
st.session_state.messages = []
# Sidebar for additional information
with st.sidebar:
st.header("Về Trợ Lý AI")
st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# User input
if prompt := st.chat_input("Hãy nói gì đó..."):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get chatbot response
response = chatbot.get_response(prompt)
# Display chatbot response
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant message to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()