File size: 11,128 Bytes
e92e757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from typing import Optional, Any
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from pydantic import Field
import pymysql

from llama_index.core.storage.chat_store import BaseChatStore
from llama_index.core.llms import ChatMessage
from llama_index.core.memory import ChatMemoryBuffer



class MySQLChatStore(BaseChatStore):
    """
    Implementação de um ChatStore que armazena mensagens em uma tabela MySQL,
    unindo a pergunta do usuário e a resposta do assistente na mesma linha.
    """
    table_name: Optional[str] = Field(default="chatstore", description="Nome da tabela MySQL.")

    _session: Optional[sessionmaker] = None
    _async_session: Optional[sessionmaker] = None

    def __init__(self, session: sessionmaker, async_session: sessionmaker, table_name: str):
        super().__init__(table_name=table_name.lower())
        self._session = session
        self._async_session = async_session
        self._initialize()

    @classmethod
    def from_params(cls, host: str, port: str, database: str, user: str, password: str, table_name: str = "chatstore") -> "MySQLChatStore":
        """
        Cria o sessionmaker síncrono e assíncrono, retornando a instância da classe.
        """
        conn_str = f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"
        async_conn_str = f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}"
        session, async_session = cls._connect(conn_str, async_conn_str)
        return cls(session=session, async_session=async_session, table_name=table_name)

    @classmethod
    def _connect(cls, connection_string: str, async_connection_string: str) -> tuple[sessionmaker, sessionmaker]:
        """
        Cria e retorna um sessionmaker síncrono e um sessionmaker assíncrono.
        """
        engine = create_engine(connection_string, echo=False)
        session = sessionmaker(bind=engine)

        async_engine = create_async_engine(async_connection_string)
        async_session = sessionmaker(bind=async_engine, class_=AsyncSession)

        return session, async_session

    def _initialize(self):
        """
        Garante que a tabela exista, com colunas para armazenar user_input e response.
        """
        with self._session() as session:
            session.execute(text(f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    chat_store_key VARCHAR(255) NOT NULL,
                    user_input TEXT,
                    response TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """))
            session.commit()

    def get_keys(self) -> list[str]:
        """
        Retorna todas as chaves armazenadas.
        """
        with self._session() as session:
            result = session.execute(text(f"""
                SELECT DISTINCT chat_store_key FROM {self.table_name}
            """))
            return [row[0] for row in result.fetchall()]

    def get_messages(self, key: str) -> list[ChatMessage]:
        """
        Retorna a conversa inteira (perguntas e respostas), na ordem de inserção (id).
        Cada linha pode conter o user_input, o response ou ambos (caso já respondido).
        """
        with self._session() as session:
            rows = session.execute(text(f"""
                SELECT user_input, response
                FROM {self.table_name}
                WHERE chat_store_key = :key
                ORDER BY id
            """), {"key": key}).fetchall()

            messages = []
            for user_in, resp in rows:
                if user_in is not None:
                    messages.append(ChatMessage(role='user', content=user_in))
                if resp is not None:
                    messages.append(ChatMessage(role='assistant', content=resp))
            return messages

    def set_messages(self, key: str, messages: list[ChatMessage]) -> None:
        """
        Sobrescreve o histórico de mensagens de uma chave (apaga tudo e insere novamente).
        Se quiser somente acrescentar, use add_message.
        
        Aqui, cada pergunta do usuário gera uma nova linha.
        Assim que encontrar uma mensagem de assistente, atualiza essa mesma linha.
        Se houver assistentes sem usuários, insere normalmente.
        """
        with self._session() as session:
            # Limpa histórico anterior
            session.execute(text(f"""
                DELETE FROM {self.table_name} WHERE chat_store_key = :key
            """), {"key": key})

            # Reinsere na ordem
            current_id = None
            for msg in messages:
                if msg.role == 'user':
                    # Cria nova linha com user_input
                    result = session.execute(text(f"""
                        INSERT INTO {self.table_name} (chat_store_key, user_input)
                        VALUES (:key, :ui)
                    """), {"key": key, "ui": msg.content})
                    # Pega o id do insert
                    current_id = result.lastrowid

                else:
                    # Tenta atualizar a última linha se existir
                    if current_id is not None:
                        session.execute(text(f"""
                            UPDATE {self.table_name}
                            SET response = :resp
                            WHERE id = :id
                        """), {"resp": msg.content, "id": current_id})
                        # Depois de atualizar a linha, zera o current_id
                        current_id = None
                    else:
                        # Se não houver pergunta pendente, insere como nova linha
                        session.execute(text(f"""
                            INSERT INTO {self.table_name} (chat_store_key, response)
                            VALUES (:key, :resp)
                        """), {"key": key, "resp": msg.content})

            session.commit()

    def add_message(self, key: str, message: ChatMessage) -> None:
        """
        Acrescenta uma nova mensagem no fluxo. Se for do usuário, insere nova linha;
        se for do assistente, tenta preencher a linha pendente que não tenha resposta.
        """

        with self._session() as session:
            if message.role == 'user':
                # Sempre cria uma nova linha para mensagens de usuário
                insert_stmt = text(f"""
                    INSERT INTO {self.table_name} (chat_store_key, user_input)
                    VALUES (:key, :ui)
                """)
                session.execute(insert_stmt, {
                    "key": key,
                    "ui": message.content
                })
            else:
                # Tenta encontrar a última linha sem resposta
                
                row = session.execute(text(f"""
                    SELECT id
                    FROM {self.table_name}
                    WHERE chat_store_key = :key
                      AND user_input IS NOT NULL
                      AND response IS NULL
                    ORDER BY id DESC
                    LIMIT 1
               """), {"key": key}).fetchone()

                if row:
                    # Atualiza com a resposta
                    msg_id = row[0]
                    
                    update_stmt = text(f"""
                        UPDATE {self.table_name}
                        SET response = :resp
                        WHERE id = :id
                    """)
                    session.execute(update_stmt, {
                        "resp": message.content,
                        "id": msg_id
                    })
                else:
                    # Se não achar linha pendente, insere como nova
                   
                    insert_stmt = text(f"""
                        INSERT INTO {self.table_name} (chat_store_key, response)
                        VALUES (:key, :resp)
                    """)
                    session.execute(insert_stmt, {
                        "key": key,
                        "resp": message.content
                    })

            session.commit()
            


    def delete_messages(self, key: str) -> None:
        """
        Remove todas as linhas associadas a 'key'.
        """
        with self._session() as session:
            session.execute(text(f"""
                DELETE FROM {self.table_name} WHERE chat_store_key = :key
            """), {"key": key})
            session.commit()

    def delete_last_message(self, key: str) -> Optional[ChatMessage]:
        """
        Apaga a última mensagem da conversa (considerando a ordem de inserção).
        Se a última linha tiver pergunta e resposta, remove primeiro a resposta;
        caso não exista resposta, remove a linha inteira.
        """
        with self._session() as session:
            # Localiza a última linha
            row = session.execute(text(f"""
                SELECT id, user_input, response
                FROM {self.table_name}
                WHERE chat_store_key = :key
                ORDER BY id DESC
                LIMIT 1
            """), {"key": key}).fetchone()

            if not row:
                return None

            row_id, user_in, resp = row

            # Se a linha tiver somente pergunta, apagamos a linha inteira.
            # Se tiver também a resposta, apagamos só a parte do assistente.
            if user_in and resp:
                # Remove a resposta
                session.execute(text(f"""
                    UPDATE {self.table_name}
                    SET response = NULL
                    WHERE id = :id
                """), {"id": row_id})
                session.commit()
                return ChatMessage(role='assistant', content=resp)
            else:
                # Deleta a linha inteira
                session.execute(text(f"""
                    DELETE FROM {self.table_name}
                    WHERE id = :id
                """), {"id": row_id})
                session.commit()

                if user_in:
                    return ChatMessage(role='user', content=user_in)
                elif resp:
                    return ChatMessage(role='assistant', content=resp)
                else:
                    return None

    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
        """
        Deleta a mensagem com base na ordem total do histórico. O índice 'idx' é
        calculado após reconstruir a lista de ChatMessages (user e assistant).
        """
        messages = self.get_messages(key)
        if idx < 0 or idx >= len(messages):
            return None

        removed = messages[idx]

        # Agora precisamos traduzir 'idx' para saber qual registro no banco será modificado.
        # É mais simples recriar todos os dados com set_messages sem a mensagem em 'idx':
        messages.pop(idx)
        self.set_messages(key, messages)

        return removed