Spaces:
Runtime error
Runtime error
import time | |
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
import meilisearch | |
# Embedding model used to encode queries. NOTE(review): presumably this must be
# the same checkpoint that produced the vectors stored in the Meilisearch index
# — confirm before changing it.
EMBEDDING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
model.eval()  # inference only: disable dropout and other training-time behaviour

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Read the API key explicitly so a missing secret fails at startup with an
# actionable message instead of a bare KeyError traceback.
_meilisearch_key = os.environ.get("MEILISEARCH_KEY")
if not _meilisearch_key:
    raise RuntimeError(
        "The MEILISEARCH_KEY environment variable is not set (or is empty); "
        "set it to a Meilisearch API key that can search the 'docs-embed' index."
    )
meilisearch_client = meilisearch.Client("https://edge.meilisearch.com", _meilisearch_key)
meilisearch_index_name = "docs-embed"
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

# Choices offered by the output-format radio button in the UI.
output_options = ["RAG-friendly", "human-friendly"]
def _source_link(hit):
    """Return a markdown link ``["<title>"](<url>)`` pointing at a hit's source page."""
    return f"[\"{hit['source_page_title']}\"]({hit['source_page_url']})"


def search_embeddings(query_text, output_option):
    """Embed a query and return the 5 nearest documentation chunks from Meilisearch.

    Args:
        query_text: Free-text query typed by the user.
        output_option: "RAG-friendly" returns the raw hit texts separated by a
            dashed rule; "human-friendly" returns a markdown page with timing
            stats and a source link under every hit.

    Returns:
        A ``(results_md, sources_md)`` tuple of strings, one per Gradio
        Markdown output. ``sources_md`` is a comma-separated list of markdown
        links to the source pages of the hits.

    Raises:
        ValueError: if ``output_option`` is not one of the two options above.
    """
    # Validate before doing any model/network work. Previously an unknown option
    # fell through and returned None, which Gradio cannot map onto two outputs.
    if output_option not in ("RAG-friendly", "human-friendly"):
        raise ValueError(
            f"Unknown output option {output_option!r}; "
            "expected 'RAG-friendly' or 'human-friendly'."
        )

    # step1: tokenize and embed the query
    start_time_embedding = time.perf_counter()
    # Instruction prefix prepended to the query before embedding. NOTE(review):
    # assumed to match how this index is meant to be queried — confirm.
    query_prefix = "Represent this sentence for searching code documentation: "
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    ).to(model.device)  # keep inputs on the same device as the model
    with torch.no_grad():
        model_output = model(**query_tokens)
        # Use the first token's vector from the model's first output as the
        # sentence embedding.
        sentence_embeddings = model_output[0][:, 0]
    # L2-normalize so the stored vectors and the query are compared as unit vectors.
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    query_vector = sentence_embeddings[0].tolist()
    elapsed_time_embedding = time.perf_counter() - start_time_embedding

    # step2: vector search in meilisearch
    start_time_meilisearch = time.perf_counter()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": query_vector,
            # semanticRatio 1.0 requests a purely semantic ranking.
            # NOTE(review): recent Meilisearch versions also require
            # hybrid["embedder"]; confirm against the server version.
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": ["text", "source_page_url", "source_page_title", "library"],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - start_time_meilisearch
    hits = response["hits"]
    sources_md = ", ".join(_source_link(hit) for hit in hits)

    # step3: present the results in markdown
    if output_option == "RAG-friendly":
        return "\n------------\n".join(hit["text"] for hit in hits), sources_md

    # human-friendly: timing stats, then each hit followed by its source link.
    sections = [
        f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\n"
        f"meilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
    ]
    for hit in hits:
        sections.append(f"{hit['text']}\n\nsrc: {_source_link(hit)}\n\n---\n\n")
    return "".join(sections), sources_md
# Gradio UI: a query textbox plus an output-format radio feed search_embeddings,
# whose two return values are rendered as (results, sources) markdown.
demo = gr.Interface(
    fn=search_embeddings,
    inputs=[
        gr.Textbox(label="enter your query", placeholder="Type your query here...", lines=10),
        gr.Radio(label="Select an output option", choices=output_options, value="RAG-friendly"),
    ],
    outputs=[gr.Markdown(), gr.Markdown()],
    title="HF Docs Embeddings Explorer",
    # NOTE(review): newer Gradio releases replace allow_flagging with
    # flagging_mode; update this if the pinned Gradio version warns or rejects it.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()