WebSearchLLM / app.py
sohoso's picture
Update app.py
02d6c1e verified
raw
history blame contribute delete
No virus
7.03 kB
import time
import os
import multiprocessing
import torch
import requests
import asyncio
import json
import aiohttp
from minivectordb.embedding_model import EmbeddingModel
from minivectordb.vector_database import VectorDatabase
from text_util_en_pt.cleaner import structurize_text, detect_language, Language
from webtextcrawler.webtextcrawler import extract_text_from_url
import gradio as gr
from googlesearch import search
# Limit torch to 2 intra-op CPU threads.
torch.set_num_threads(2)

# OpenRouter API key, read from the environment. Secrets must never be
# hard-coded: the previous revision passed a leaked API key as the *name* of
# the environment variable, so the lookup almost certainly returned None and
# every request was sent with "Authorization: Bearer None".
openrouter_key = os.environ.get("OPENROUTER_API_KEY")
if not openrouter_key:
    print("Warning: OPENROUTER_API_KEY is not set; OpenRouter requests will fail.")

# Shared embedding model, used for both the query and the crawled sentences.
model = EmbeddingModel(use_quantized_onnx_model=True)
def fetch_links(query, max_results=10):
    """Run a Google search for ``query`` and return the result URLs as a list.

    ``max_results`` is forwarded to ``googlesearch.search`` as ``num_results``.
    """
    hits = search(query, num_results=max_results)
    return list(hits)
def fetch_texts(links):
    """Extract the text of every URL in ``links`` using a pool of 10 processes.

    Results keep the order of ``links``. Pages for which
    ``extract_text_from_url`` returns a falsy value are dropped, and the
    remaining texts are joined with newlines.
    """
    worker_count = 10
    with multiprocessing.Pool(worker_count) as workers:
        page_texts = workers.map(extract_text_from_url, links)
    # filter(None, ...) keeps only truthy entries (drops None / "").
    return '\n'.join(filter(None, page_texts))
def index_and_search(query, text):
    """Index the sentences of ``text`` and retrieve the ones closest to ``query``.

    A fresh in-memory ``VectorDatabase`` is built for every call.

    Returns:
        ``(context, embedding_time, retrieval_time)``:
        ``context`` is the newline-joined text of the top-30 matches.
        ``embedding_time`` covers embedding the query plus indexing every
        sentence. ``retrieval_time`` covers only the similarity search.
        Both times are in seconds.
    """
    embed_started = time.time()
    query_vector = model.extract_embeddings(query)

    # Indexing: one entry per sentence, ids starting at 1.
    index = VectorDatabase()
    all_sentences = [piece['sentence'] for piece in structurize_text(text)]
    for sentence_id, sentence in enumerate(all_sentences, start=1):
        vector = model.extract_embeddings(sentence)
        index.store_embedding(sentence_id, vector, {'sentence': sentence})
    embedding_time = time.time() - embed_started

    # Retrieval
    retrieve_started = time.time()
    matches = index.find_most_similar(query_vector, k=30)
    retrieval_time = time.time() - retrieve_started

    # NOTE(review): element 2 of the result appears to be the list of metadata
    # dicts stored above (each carries 'sentence') — confirm against minivectordb.
    context = '\n'.join(meta['sentence'] for meta in matches[2])
    return context, embedding_time, retrieval_time
def generate_search_terms(message, lang, timeout=60):
    """Ask an OpenRouter-hosted LLM to turn ``message`` into a web search query.

    Args:
        message: The user's chat message.
        lang: Language detected for ``message``. ``Language.ptbr`` selects a
            Portuguese prompt; any other value selects an English one.
        timeout: Seconds to wait for the OpenRouter HTTP request.

    Returns:
        The suggested search term, with surrounding whitespace and double
        quotes removed. Returns ``None`` in any of these cases:
        the request failed, the body was not JSON, the response lacked
        ``choices[0].message.content``, or the term was empty.
        Each failure is printed.
    """
    if lang == Language.ptbr:
        prompt = f"A partir do texto a seguir, gere alguns termos de pesquisa: \"{message}\"\nSua resposta deve ser apenas o termo de busca mais adequado, e nada mais."
    else:
        prompt = f"From the following text, generate some search terms: \"{message}\"\nYour answer should be just the most appropriate search term, and nothing else."
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openrouter_key}",
    }
    body = {
        "stream": False,
        # OpenRouter tries these models in order ("fallback" routing).
        "models": [
            "mistralai/mistral-7b-instruct:free",
            "openchat/openchat-7b:free",
        ],
        "route": "fallback",
        "max_tokens": 1024,
        "messages": [
            {"role": "user", "content": prompt},
        ],
    }

    try:
        # Without a timeout a stalled API would hang the chat handler forever.
        response = requests.post(url, headers=headers, json=body, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: search-term request failed: {err}")
        return None

    try:
        response_json = response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        print(f"Error: non-JSON response (HTTP {response.status_code}): {response.text[:500]}")
        return None

    try:
        content = response_json['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError):
        # Covers a missing key, an empty 'choices' list, or a non-dict payload.
        print(f"Error: 'choices' key not found in the response. Response: {response_json}")
        return None

    if not isinstance(content, str):
        print(f"Error: empty completion content. Response: {response_json}")
        return None

    # The prompt wraps the message in double quotes and models often echo them.
    # A quoted query would likely be treated by Google as an exact phrase.
    term = content.strip().strip('"').strip()
    return term or None
async def _run_blocking(func, *args):
    """Run the blocking call ``func(*args)`` in the default thread pool and await it.

    This keeps the event loop free while searches, crawling, embedding and
    synchronous HTTP calls run. Progress messages that were already yielded
    can then be delivered, and other sessions are not frozen.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def _extract_stream_text(event_data):
    """Return the text carried by one SSE ``data:`` payload, or ``""`` if none.

    Understands three payload shapes:
    - chat-style chunks (``choices[0].delta.content``);
    - completion-style chunks (``choices[0].text``);
    - in-stream ``{"error": ...}`` objects, returned as a readable message.

    Unparseable or empty chunks yield ``""``.
    """
    try:
        payload = json.loads(event_data)
    except ValueError:
        return ""
    if not isinstance(payload, dict):
        return ""
    if "error" in payload:
        return f"\n[Error from OpenRouter: {payload['error']}]"
    choices = payload.get("choices") or []
    if not choices or not isinstance(choices[0], dict):
        return ""
    choice = choices[0]
    delta = choice.get("delta")
    text = delta.get("content") if isinstance(delta, dict) else None
    if text is None:
        text = choice.get("text")
    return text if isinstance(text, str) else ""


async def _stream_answer(prompt):
    """Yield text fragments of the streamed OpenRouter completion for ``prompt``.

    Raw bytes are buffered and only complete lines are decoded. Decoding each
    network chunk directly raised ``UnicodeDecodeError`` whenever a multi-byte
    UTF-8 character was split across two chunks.

    A non-200 HTTP status is yielded as a single error message instead of
    silently producing nothing.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openrouter_key}",
    }
    body = {
        "stream": True,
        # OpenRouter tries these models in order ("fallback" routing).
        "models": [
            "mistralai/mistral-7b-instruct:free",
            "openchat/openchat-7b:free",
        ],
        "route": "fallback",
        "max_tokens": 1024,
        "messages": [
            {"role": "user", "content": prompt},
        ],
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=body) as response:
            if response.status != 200:
                error_text = await response.text()
                yield f"[Error {response.status} from OpenRouter: {error_text}]"
                return
            buffer = b""  # holds an incomplete trailing line between chunks
            async for chunk in response.content.iter_any():
                buffer += chunk
                while b"\n" in buffer:
                    raw_line, buffer = buffer.split(b"\n", 1)
                    line = raw_line.decode("utf-8", errors="replace").rstrip("\r")
                    # Skip blank event separators and SSE comment lines (": ...").
                    if not line.startswith("data: "):
                        continue
                    event_data = line[len("data: "):]
                    if event_data == "[DONE]":
                        return
                    piece = _extract_stream_text(event_data)
                    if piece:
                        yield piece


async def predict(message, history):
    """Answer ``message`` from live web search results, streaming progress.

    Pipeline:
    1. LLM-generated search terms.
    2. Google search.
    3. Page text extraction.
    4. Sentence embedding and vector retrieval.
    5. A streamed LLM answer built from the retrieved context.

    Every yield carries the full transcript so far. Gradio's ChatInterface
    replaces the displayed bot message with each yielded value.

    Args:
        message: The user's question.
        history: Chat history passed by ``gr.ChatInterface``; unused.

    Yields:
        The accumulated response text after each progress update or token.
    """
    full_response = ""
    query_language = detect_language(message)

    start = time.time()
    full_response += "Generating search terms...\n"
    yield full_response
    search_query = await _run_blocking(generate_search_terms, message, query_language)
    search_terms_time = time.time() - start
    if not search_query:
        # generate_search_terms returns None on API failure. Searching for None
        # would crash fetch_links, so fall back to the user's own message.
        full_response += "Could not generate search terms; searching with the original message instead.\n"
        search_query = message
    full_response += f"Search terms: \"{search_query}\"\n"
    yield full_response
    full_response += f"Search terms took: {search_terms_time:.4f} seconds\n"
    yield full_response

    start = time.time()
    full_response += "\nSearching the web...\n"
    yield full_response
    links = await _run_blocking(fetch_links, search_query)
    websearch_time = time.time() - start
    full_response += f"Web search took: {websearch_time:.4f} seconds\n"
    yield full_response
    full_response += "Links visited:\n"
    yield full_response
    for link in links:
        full_response += f"{link}\n"
        yield full_response

    full_response += "\nExtracting text from web pages...\n"
    yield full_response
    start = time.time()
    text = await _run_blocking(fetch_texts, links)
    webcrawl_time = time.time() - start
    full_response += f"Text extraction took: {webcrawl_time:.4f} seconds\n"

    full_response += "\nIndexing in vector database and building prompt...\n"
    yield full_response
    context, embedding_time, retrieval_time = await _run_blocking(index_and_search, message, text)
    if query_language == Language.ptbr:
        prompt = f"Contexto:\n{context}\n\nResponda: \"{message}\"\n(Você pode utilizar o contexto para responder)\n(Sua resposta deve ser completa, detalhada e bem estruturada)"
    else:
        prompt = f"Context:\n{context}\n\nAnswer: \"{message}\"\n(You can use the context to answer)\n(Your answer should be complete, detailed and well-structured)"
    full_response += f"Embedding time: {embedding_time:.4f} seconds\n"
    full_response += f"Retrieval from VectorDB time: {retrieval_time:.4f} seconds\n"
    yield full_response

    full_response += "\nGenerating response...\n"
    yield full_response
    full_response += "\nResponse: "
    async for piece in _stream_answer(prompt):
        full_response += piece
        yield full_response
        # Brief pause between token updates so the UI can render them.
        await asyncio.sleep(0.01)
# Chat UI: each user message is answered by the streaming `predict` generator.
demo = gr.ChatInterface(
    predict,
    title="Live Web Chat",
    description="",
    retry_btn=None,
    undo_btn=None,
    examples=[
        'What is the current sentiment of the Brazil election?',
        'Compare the current economies of China and India?',
        'What are new shoe design trends in 2024',
    ],
)

# Guard the launch. With the "spawn" multiprocessing start method (the
# default on Windows and macOS), every worker that fetch_texts starts
# re-imports this module. Each worker would then try to launch its own
# web server.
if __name__ == "__main__":
    demo.launch()