File size: 7,484 Bytes
dc14176
a4c3bcc
 
dc14176
 
 
7f79d8b
985ad3e
dc14176
7f79d8b
dc14176
7f79d8b
dc14176
7f79d8b
 
 
 
 
 
dc14176
7f79d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc14176
7f79d8b
 
 
 
 
 
 
 
 
 
 
 
dc14176
 
 
 
 
 
7f79d8b
dc14176
 
7f79d8b
dc14176
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
 
 
 
 
 
7f79d8b
dc14176
 
7f79d8b
dc14176
 
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
 
7f79d8b
 
 
 
dc14176
 
7f79d8b
47ed12b
7f79d8b
dc14176
47ed12b
dc14176
 
 
47ed12b
7f79d8b
 
 
 
 
 
dc14176
 
 
 
47ed12b
dc14176
 
 
 
a4c3bcc
dc14176
 
 
47ed12b
dc14176
 
47ed12b
dc14176
 
 
47ed12b
dc14176
 
 
a4c3bcc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

class VietnameseChatbot:
    def __init__(self, model_name='intfloat/multilingual-e5-small'):
        """
        Initialize the Vietnamese chatbot with pre-loaded model and conversation data
        """
        # Load pre-trained model and tokenizer
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        print("Loading model...")
        self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
        
        # Load comprehensive conversation dataset
        self.conversation_data = self._load_conversation_data()
        
        # Pre-compute embeddings for faster response generation
        print("Pre-computing conversation embeddings...")
        self.conversation_embeddings = self._precompute_embeddings()

    def _load_conversation_data(self):
        """
        Load a comprehensive conversation dataset
        """
        return [
            # Greeting conversations
            {"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
            {"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
            {"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
            
            # Identity and purpose
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
            {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
            
            # Small talk
            {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
            {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
            
            # Weather and time
            {"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
            {"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
            
            # Assistance offers
            {"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
            {"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
            
            # Farewell
            {"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
            {"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
        ]

    def _precompute_embeddings(self):
        """
        Pre-compute embeddings for all conversation queries
        """
        embeddings = []
        for item in self.conversation_data:
            embedding = self.embed_text(item['query'])
            if embedding is not None:
                embeddings.append(embedding[0])
        return np.array(embeddings)

    def embed_text(self, text):
        """
        Generate embeddings for input text
        """
        try:
            # Tokenize and generate embeddings
            inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            
            with torch.no_grad():
                model_output = self.model(**inputs)
            
            # Mean pooling
            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            return embeddings.numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def mean_pooling(self, model_output, attention_mask):
        """
        Perform mean pooling on model output
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_response(self, user_query):
        """
        Find the most similar response from conversation data
        """
        try:
            # Embed user query
            query_embedding = self.embed_text(user_query)
            
            if query_embedding is None:
                return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
            
            # Calculate cosine similarities
            similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
            
            # Find most similar response
            best_match_index = np.argmax(similarities)
            
            # Return response if similarity is above threshold
            if similarities[best_match_index] > 0.5:
                return self.conversation_data[best_match_index]['response']
            
            return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."

def main():
    st.set_page_config(
        page_title="Trợ Lý AI Tiếng Việt",
        page_icon="🤖",
    )

    st.title("🤖 Trợ Lý AI Tiếng Việt")
    st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
    
    # Initialize chatbot (this will pre-load models and embeddings)
    chatbot = VietnameseChatbot()
    
    # Chat history in session state
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    
    # Sidebar for additional information
    with st.sidebar:
        st.header("Về Trợ Lý AI")
        st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
        st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
    
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    # User input
    if prompt := st.chat_input("Hãy nói gì đó..."):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        
        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)
        
        # Get chatbot response
        response = chatbot.get_response(prompt)
        
        # Display chatbot response
        with st.chat_message("assistant"):
            st.markdown(response)
        
        # Add assistant message to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()