mishig's picture
mishig HF staff
Update app.py
be701db verified
raw
history blame
2.78 kB
import time
import os
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import meilisearch
# --- Embedding model ---------------------------------------------------------
# The tokenizer and the encoder are loaded from the same checkpoint name.
_EMBEDDING_CHECKPOINT = 'BAAI/bge-base-en-v1.5'
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_CHECKPOINT)
model = AutoModel.from_pretrained(_EMBEDDING_CHECKPOINT)
model.eval()  # put the module in evaluation mode; this app only runs inference

# Report GPU visibility at startup. Within this setup section the flag is only
# printed; the model is not moved to another device here.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# --- Meilisearch -------------------------------------------------------------
# The API key is read from the environment; a missing MEILISEARCH_KEY raises
# KeyError at import time.
meilisearch_index_name = "docs-embed"
meilisearch_client = meilisearch.Client(
    "https://edge.meilisearch.com",
    os.environ["MEILISEARCH_KEY"],
)
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

# --- UI ----------------------------------------------------------------------
# Choices offered by the output-format radio button.
output_options = ["RAG-friendly", "human-friendly"]
def search_embeddings(query_text, output_option):
    """Embed a query and return the closest documentation chunks.

    The query is prefixed with a retrieval instruction, encoded with the
    module-level ``tokenizer`` / ``model`` pair, L2-normalised, and sent to
    ``meilisearch_index`` as a vector search (``semanticRatio`` 1.0, at most
    five hits).

    Args:
        query_text: Free-text query to search for.
        output_option: ``"human-friendly"`` for Markdown containing timing
            stats, each hit's text and a link to its source, or
            ``"RAG-friendly"`` for the hit texts joined by a dashed separator.

    Returns:
        A string formatted according to ``output_option``.

    Raises:
        ValueError: If ``output_option`` is not one of the two supported
            values.
    """
    # Fail fast: without this check an unknown option would run the model and
    # the network request and then silently return None.
    if output_option not in ("human-friendly", "RAG-friendly"):
        raise ValueError(f"Unsupported output option: {output_option!r}")

    # step1: tokenize and embed the query
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time_embedding = time.perf_counter()
    query_prefix = 'Represent this sentence for searching code documentation: '
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**query_tokens)
        # Use position 0 along the second axis of the first model output as
        # the sentence embedding.
        sentence_embeddings = model_output[0][:, 0]
        # Normalize to unit L2 length.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    # Only one query is encoded, so take row 0 as a plain list of floats.
    sentence_embeddings_list = sentence_embeddings[0].tolist()
    elapsed_time_embedding = time.perf_counter() - start_time_embedding

    # step2: search meilisearch
    start_time_meilisearch = time.perf_counter()
    response = meilisearch_index.search(
        "",  # empty text query: the search is driven by the vector alone
        opt_params={
            "vector": sentence_embeddings_list,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": ["text", "source", "library"],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - start_time_meilisearch
    hits = response["hits"]

    # step3: present the results
    if output_option == "RAG-friendly":
        return "\n------------\n".join(hit["text"] for hit in hits)

    # human-friendly: stats header followed by one Markdown section per hit.
    sections = [
        f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\n"
        f"meilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
    ]
    for hit in hits:
        text, source, library = hit["text"], hit["source"], hit["library"]
        link = f"[source](https://huggingface.co/docs/{library}/{source})"
        sections.append(f"{text}\n\n{link}\n\n---\n\n")
    return "".join(sections)
# Gradio UI: a multi-line query box plus an output-format selector, rendered
# as Markdown. Flagging is turned off.
demo = gr.Interface(
    fn=search_embeddings,
    inputs=[
        gr.Textbox(label="enter your query", placeholder="Type Markdown here...", lines=10),
        gr.Radio(label="Select an output option", choices=output_options, value="RAG-friendly"),
    ],
    outputs=gr.Markdown(),
    # Title typo fixed ("Emebddings" -> "Embeddings").
    title="HF Docs Embeddings Explorer",
    allow_flagging="never",
)

# Start the web server only when the file is executed as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()