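# Gradio Space: retrieval-augmented search over local documents.
# Everything in sources/ is converted to text, chunked, and embedded once at
# startup; predict() then returns the highest-scoring chunks for a query.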
import gradio as gr
import spaces
import os
import glob

from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
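
# Configuration, overridable via environment variables. CHUNK_SIZE and
# DEFAULT_MAX_CHARACTERS are measured in characters, not tokens.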
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
chunk_size = int(os.environ.get("CHUNK_SIZE", 128))
default_max_characters = int(os.environ.get("DEFAULT_MAX_CHARACTERS", 258))

model = SentenceTransformer(model_name)
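

# Batch scoring helper: embeds queries and chunks together and returns every
# (chunk index, similarity) pair per query, e.g.
# {"what is gradio?": [(0, 0.42), ...]}. Not called by the Gradio UI below.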
@spaces.GPU
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    # Arctic-embed is an asymmetric retrieval model: queries get the "query"
    # prompt, documents are encoded as-is.
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)

    # One row of chunk similarities per query.
    scores = query_embeddings @ document_embeddings.T
    results = {}
    for query, query_scores in zip(queries, scores):
        chunk_idxs = list(range(len(chunks)))
        results[query] = list(zip(chunk_idxs, query_scores))

    return results
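

# Flatten a pypdf PdfReader into one string, inserting "---- Page N ----"
# markers (0-indexed) between pages; pages with no extractable text are skipped.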
def extract_text_from_pdf(reader):
    full_text = ""
    for idx, page in enumerate(reader.pages):
        text = page.extract_text()
        if len(text) > 0:
            full_text += f"---- Page {idx} ----\n" + text + "\n\n"

    return full_text.strip()
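

# Convert a source file to plain text, dispatching on file extension.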
def convert(filename) -> str:
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]

    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")
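

# Split text into fixed-size character windows with no overlap; the final
# chunk holds whatever remains, e.g. chunk_to_length("abcdefgh", 3) returns
# ["abc", "def", "gh"].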
def chunk_to_length(text, max_length=512):
    chunks = []
    while len(text) > max_length:
        chunks.append(text[:max_length])
        text = text[max_length:]
    chunks.append(text)
    return chunks
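

# Entry point for the Gradio app: rank every pre-computed chunk against the
# query and return the best ones within a character budget.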
@spaces.GPU
def predict(query, max_characters) -> dict[str, list[str]]:
    # Embed the query with the model's query prompt.
    query_embedding = model.encode(query, prompt_name="query")

    # Score every chunk of every document against the query.
    all_chunks = []
    for filename, doc in docs.items():
        similarities = doc["embeddings"] @ query_embedding.T
        all_chunks.extend((filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities))

    # Highest-similarity chunks first.
    all_chunks.sort(key=lambda x: x[2], reverse=True)

    # Greedily take the top chunks, grouped by source file, until the character
    # budget is spent. Chunks are near-uniform in length, so stopping at the
    # first one that does not fit loses little.
    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in all_chunks:
        if total_chars + len(chunk) > max_characters:
            break
        if filename not in relevant_chunks:
            relevant_chunks[filename] = []
        relevant_chunks[filename].append(chunk)
        total_chars += len(chunk)

    return relevant_chunks
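

# Build the in-memory index at import time: convert, chunk, and embed every
# file shipped in sources/. Documents are encoded without the query prompt,
# matching the asymmetric query/document scheme used in predict().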
docs = {}

for filename in glob.glob("sources/*"):
    if filename.endswith("add_your_files_here"):
        continue

    converted_doc = convert(filename)
    chunks = chunk_to_length(converted_doc, chunk_size)
    embeddings = model.encode(chunks)

    docs[filename] = {
        "chunks": chunks,
        "embeddings": embeddings,
    }
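

# Minimal UI: the dict returned by predict() (filename -> list of chunks)
# renders directly in the JSON output component.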
gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Max output characters", value=default_max_characters),
    ],
    outputs=[gr.JSON(label="Relevant chunks")],
    title="Gradio Docs",
    description="A RAG tool over the Gradio documentation, for use with Hugging Face Chat tools.",
).launch()