import gradio as gr
import spaces
from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers.util import cos_sim
from sentence_transformers.quantization import quantize_embeddings
import pymssql
import os
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel, Field
import json
from torch import nn
import time

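# Configuration: SQL connection settings and the OpenAI-compatible endpoint (Gemini),
# all read from environment variables.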
SqlServer = os.environ['SQL_SERVER']
SqlDatabase = os.environ['SQL_DB']
SqlUser = os.environ['SQL_USER']
SqlPass = os.environ['SQL_PASS']

OpenaiApiKey = os.environ.get("OPENAI_API_KEY")
OpenaiBaseUrl = os.environ.get("OPENAI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai")

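# Run a query against the database and return the rows as a list of dicts.
# The connection is retried every second, up to login_timeout seconds, which gives a
# paused free-tier Azure SQL database time to resume (see the Help tab).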
def sql(query, db=SqlDatabase, login_timeout = 120, onConnectionError = None):

    start_time = time.time()

    while True:
        try:
            cnxn = pymssql.connect(SqlServer, SqlUser, SqlPass, db, login_timeout = 5)
            break
        except Exception as e:
            if onConnectionError:
                onConnectionError(e)

            if time.time() - start_time > login_timeout:
                raise TimeoutError("SQL Connection Timeout")

            time.sleep(1)

    cursor = cnxn.cursor()
    cursor.execute(query)

    columns = [column[0] for column in cursor.description]
    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

    # Release the connection once the rows have been fetched.
    cnxn.close()

    return results

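# Embedding helper, executed on ZeroGPU. Returns the embedding as a plain Python list
# so it can be interpolated into the T-SQL query built by search().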
@spaces.GPU
def embed(text):

    query_embedding = Embedder.encode(text)
    return query_embedding.tolist()

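# Reranking helper, also executed on ZeroGPU; a thin wrapper around CrossEncoder.rank.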
@spaces.GPU
def rerank(query, documents, **kwargs):
    return Reranker.rank(query, documents, **kwargs)

ClientOpenai = OpenAI(
    api_key=OpenaiApiKey
    ,base_url=OpenaiBaseUrl
)

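# Chat-completion helper. When a response format (Pydantic model) is given, the
# structured-output parse endpoint is used; otherwise the regular create endpoint.
# If stream=True is passed through kwargs, the raw streaming response is returned;
# otherwise only the first choice is returned.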
def llm(messages, ResponseFormat = None, **kwargs):

    fn = ClientOpenai.chat.completions.create

    if ResponseFormat:
        fn = ClientOpenai.beta.chat.completions.parse

    params = {
        'model':"gemini-2.0-flash"
        ,'n':1
        ,'messages':messages
        ,'response_format':ResponseFormat
    }

    params.update(kwargs)

    response = fn(**params)

    if params.get('stream'):
        return response

    return response.choices[0]

def ai(system, user, schema, **kwargs):
    msg = [
        {'role':"system",'content':system}
        ,{'role':"user",'content':user}
    ]

    return llm(msg, schema, **kwargs)

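# Vector search: embeds the text, then asks SQL Server for the nearest chunks in the
# Scripts table using VECTOR_DISTANCE with cosine distance, returning the `top` rows
# ordered from most to least similar.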
def search(text, top = 10, onConnectionError = None):

    EnglishText = text

    embeddings = embed(text)

    query = f"""
        declare @search vector(1024) = '{embeddings}'

        select top {top}
            *
        from (
            select
                RelPath
                ,Similaridade = 1-CosDistance
                ,ScriptContent = ChunkContent
                ,ContentLength = LEN(ChunkContent)
                ,CosDistance
            from
                (
                    select
                        *
                        ,CosDistance = vector_distance('cosine',embeddings,@search)
                    from
                        Scripts
                ) C
        ) v
        order by
            CosDistance
    """

    queryResults = sql(query, onConnectionError = onConnectionError)

    return queryResults

print("Loading embedding model")
Embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

print("Loading reranker")
Reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1", activation_fn=nn.Sigmoid())

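# Response schemas used for structured (parsed) LLM outputs.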
class rfTranslatedText(BaseModel):
    text: str = Field(description='Translated text')

class rfGenericText(BaseModel):
    text: str = Field(description='The text result')

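# Chat pipeline: translate the question to English, run the vector search, rerank the
# hits with the cross-encoder, publish the ranked table to the "Rank" tab, and finally
# stream an LLM-written summary of the best scripts back to the user.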
def ChatFunc(message, history):

    IsNewSearch = True

    messages = []
    CurrentTable = None

    def ChatBotOutput():
        return [messages, CurrentTable]

    # Helper that wraps a gr.ChatMessage so progress text can be appended or replaced in place.
    class BotMessage():
        def __init__(self, *args, **kwargs):
            self.Message = gr.ChatMessage(*args, **kwargs)
            self.LastContent = None
            messages.append(self.Message)

        def __call__(self, content, noNewLine = False):
            if not content:
                return

            self.Message.content += content
            self.LastContent = None

            if not noNewLine:
                self.Message.content += "\n"

            return ChatBotOutput()

        def update(self, content):

            if not self.LastContent:
                self.LastContent = self.Message.content

            self.Message.content = self.LastContent + " " + content + "\n"

            return ChatBotOutput()

        def done(self):
            self.Message.metadata['status'] = "done"
            return ChatBotOutput()

    def Reply(msg):
        m = BotMessage(msg)
        return ChatBotOutput()

    m = BotMessage("", metadata={"title":"Procurando scripts...","status":"pending"})

    def OnConnError(err):
        print("Sql connection error:", err)

    try:

        if IsNewSearch:

            yield m("Melhorando o prompt")

            LLMResult = ai("""
                Translate the user's message to English.
                The message is a question related to a SQL Server T-SQL script that the user is searching for.
                You only need to translate the message to English.
            """, message, rfTranslatedText)
            Question = LLMResult.message.parsed.text

            yield m(f"Melhorado: {Question}")

            yield m("Procurando scripts...")
            try:
                FoundScripts = search(message, onConnectionError = OnConnError)
            except Exception:
                yield m("Houve alguma falha ao executar a consulta no banco. Tente novamente. Se persistir, veja orientações na aba Help!")
                return

            yield m("Fazendo o rerank")
            doclist = [doc['ScriptContent'] for doc in FoundScripts]

            for score in rerank(Question, doclist):
                i = score['corpus_id']
                FoundScripts[i]['rank'] = str(score['score'])

            RankedScripts = sorted(FoundScripts, key=lambda item: float(item['rank']), reverse=True)

            ScriptTable = []
            for script in RankedScripts:
                link = "https://github.com/rrg92/sqlserver-lib/tree/main/" + script['RelPath']
                script['link'] = link

                ScriptTable.append({
                    'Link': f'<a title="{link}" href="{link}" target="_blank">{script["RelPath"]}</a>'
                    ,'Length': script['ContentLength']
                    ,'Cosine Similarity': script['Similaridade']
                    ,'Rank': script['rank']
                })

            CurrentTable = pd.DataFrame(ScriptTable)
            yield m("Scripts encontrados, a aba Rank foi atualizada!")

            WaitMessage = ai("""
                You will analyze some T-SQL scripts in order to check which is best for the user.
                You found scripts, presented them to the user, and now will do some work that takes time.
                Generate a message to tell the user to wait while you work, in the same language as the user.
                You will receive the question the user sent that triggered this process.
                Use the user's original question to customize the message.
            """, message, rfGenericText).message.parsed.text

            yield Reply(WaitMessage)

            yield m("Analisando scripts...")

            ResultJson = json.dumps(RankedScripts)

            SystemPrompt = f"""
                You are an assistant that helps users find the best T-SQL scripts for their specific needs.
                These scripts were created by Rodrigo Ribeiro Gomes and are publicly available for users to query and use.

                The user will provide a short description of what they are looking for, and your task is to present the most relevant scripts.

                To assist you, here is a JSON object with the top matches based on the current user query:
                {ResultJson}

                ---
                This JSON contains all the scripts that matched the user's input.
                Analyze each script's name and content, and create a ranked summary of the best recommendations according to the user's need.

                Only use the information available in the provided JSON. Do not reference or mention anything outside of this list.
                You can include parts of the scripts in your answer to illustrate or give usage examples based on the user's request.

                Re-rank the results if necessary, presenting them from the most to the least relevant.
                You may filter out scripts that appear unrelated to the user query.

                Respond in the user's original language.
                ---
                ### Output Rules

                - Review each script and evaluate how well it matches the user's request.
                - Summarize each script, ordering from the most relevant to the least relevant.
                - Write personalized and informative review text for each recommendation.
                - If applicable, explain how the user should run the script, including parameters or sections (like `WHERE` clauses) they might need to customize.
                - When referencing a script, include the link provided in the JSON — all scripts are hosted on GitHub.
            """

            ScriptPrompt = [
                { 'role':'system', 'content':SystemPrompt }
                ,{ 'role':'user', 'content':message }
            ]

            llmanswer = llm(ScriptPrompt, stream = True)
            yield m.done()

            answer = BotMessage("")

            for chunk in llmanswer:
                content = chunk.choices[0].delta.content
                # Skip empty/None deltas so nothing is yielded for them.
                if content:
                    yield answer(content, noNewLine = True)
    finally:
        yield m.done()

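# Gradio UI: the chat is wired to ChatFunc; the ranked results table (rendered in the
# "Rank" tab) is updated through ChatInterface's additional_outputs.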
resultTable = gr.Dataframe(datatype = ['html','number','number','str'], interactive = False, show_search = "search")

with gr.Blocks(fill_height=True) as demo:

    with gr.Column():

        with gr.Tab("Chat", scale = 1):
            ChatTextBox = gr.Textbox(max_length = 100, info = "Que script precisa?", submit_btn = True)

            gr.ChatInterface(
                ChatFunc
                ,additional_outputs=[resultTable]
                ,type="messages"
                ,textbox = ChatTextBox
            )

        with gr.Tab("Rank"):
            resultTable.render()

        with gr.Tab("Help"):
            gr.Markdown("""
# Bem-vindo ao Space SQL Server Lib

Este space permite que você encontre scripts SQL do https://github.com/rrg92/sqlserver-lib com base nas suas necessidades.

## Instruções de Uso
Apenas descreva o que você precisa no campo de chat e aguarde a IA analisar os melhores scripts do repositório para você.
Além da explicação feita pela IA, a aba "Rank" contém uma tabela com os scripts encontrados e seus respectivos ranks.
A coluna Cosine Similarity é o nível de similaridade entre a sua pergunta e o script (calculado com base nos embeddings do seu texto e do script).
A coluna Rank é um score em que, quanto maior o valor, mais relacionado o script é ao seu texto (calculado usando rerank/cross encoders). A tabela vem ordenada por essa coluna.

## Fluxo básico
- Quando você digita o texto, fazemos uma busca usando embeddings em um banco Azure SQL Database.
- Os embeddings são calculados usando um modelo carregado no próprio script, via ZeroGPU.
- Os 10 resultados mais similares são retornados e então um rerank é feito.
- O rerank também é feito por um modelo que roda no próprio script, em ZeroGPU.
- Esses resultados, ordenados pelo rerank, são então enviados ao LLM para que analise e monte uma resposta para você.

## Sobre o uso e eventuais erros
Eu tento usar o máximo de recursos FREE e open possíveis e, portanto, eventualmente o Space pode falhar por alguma limitação.
Alguns possíveis pontos de falha:
- Créditos free do Google ou rate limit
- Azure SQL Database offline devido a créditos ou ao auto-pause (devido ao free tier)
- Limites de uso do ZeroGPU do Hugging Face

Você pode me procurar no [linkedin](https://www.linkedin.com/in/rodrigoribeirogomes/), caso receba erros.
""")

        with gr.Tab("Other", visible = False):
            txtEmbed = gr.Text(label="Text to embed", visible=False)
            btnEmbed = gr.Button("embed")
            btnEmbed.click(embed, [txtEmbed], [txtEmbed])

if __name__ == "__main__":
    demo.launch(
        share=False,
        debug=False,
        server_port=7860,
        server_name="0.0.0.0",
        allowed_paths=[]
    )