import gradio as gr
import spaces  # Spaces-specific import; kept even though it is not used directly
import os
import glob

from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
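# Configuration comes from environment variables so the Space can be tuned
# without code changes.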
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
chunk_size = int(os.environ.get("CHUNK_SIZE", 128))
default_max_characters = int(os.environ.get("DEFAULT_MAX_CHARACTERS", 258))

model = SentenceTransformer(model_name)
# model.to(device="cuda")
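# embed() scores chunks against queries with a dot product over the model's
# embeddings. The default arctic-embed model applies a dedicated prompt to
# queries, hence prompt_name="query" on the query side only. Note that this
# helper is not called by the app itself; predict() below does its own
# per-document scoring.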
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)

    # Dot-product similarity between every query and every chunk
    scores = query_embeddings @ document_embeddings.T

    results = {}
    for query, query_scores in zip(queries, scores):
        chunk_idxs = list(range(len(chunks)))
        # Produces a structure like {query: [(chunk_idx, score), (chunk_idx, score), ...]}
        results[query] = list(zip(chunk_idxs, query_scores))
    return results
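# Flatten a PDF into plain text, labeling each page (0-indexed) so retrieved
# chunks can be traced back to their source page.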
def extract_text_from_pdf(reader):
    full_text = ""
    for idx, page in enumerate(reader.pages):
        text = page.extract_text()
        if len(text) > 0:
            full_text += f"---- Page {idx} ----\n" + text + "\n\n"
    return full_text.strip()
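# Dispatch on file extension: plain-text formats are read verbatim, PDFs go
# through pypdf, and anything else is rejected.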
def convert(filename) -> str:
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]
    # Already a plain-text format that wouldn't benefit from pandoc, so
    # return the file's contents directly
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")
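# Split text into fixed-size character windows with no overlap, e.g.
# chunk_to_length("abcdef", 4) -> ["abcd", "ef"].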
def chunk_to_length(text, max_length=512):
    chunks = []
    while len(text) > max_length:
        chunks.append(text[:max_length])
        text = text[max_length:]
    chunks.append(text)
    return chunks
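# The retrieval step exposed to the UI: embed the query, score it against
# every pre-computed chunk embedding, and return the best chunks grouped by
# source file, within a max_characters budget.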
def predict(query, max_characters) -> dict[str, list[str]]:
    # Embed the query
    query_embedding = model.encode(query, prompt_name="query")

    # Collect every chunk and its similarity score across all documents
    all_chunks = []
    for filename, doc in docs.items():
        # Dot product between the query embedding and each chunk embedding
        similarities = doc["embeddings"] @ query_embedding.T
        all_chunks.extend(
            [(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)]
        )

    # Sort all chunks by similarity, most similar first
    all_chunks.sort(key=lambda x: x[2], reverse=True)

    # Add the most relevant chunks, grouped by document, until the
    # max_characters budget is exhausted
    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in all_chunks:
        if total_chars + len(chunk) > max_characters:
            break
        if filename not in relevant_chunks:
            relevant_chunks[filename] = []
        relevant_chunks[filename].append(chunk)
        total_chars += len(chunk)

    return relevant_chunks
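# Build the in-memory index at startup: convert every file in sources/
# (skipping the repo's placeholder file), chunk it, and pre-compute the
# chunk embeddings once.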
docs = {}

for filename in glob.glob("sources/*"):
    if filename.endswith("add_your_files_here"):
        continue

    converted_doc = convert(filename)
    chunks = chunk_to_length(converted_doc, chunk_size)
    embeddings = model.encode(chunks)
    docs[filename] = {
        "chunks": chunks,
        "embeddings": embeddings,
    }
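# Expose predict as a Gradio app, with JSON output so the Space can serve as
# a Hugging Chat RAG tool, per the title.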
gr.Interface(
    predict,
    theme="Nymbo/Nymbo_Theme",
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Max output characters", value=default_max_characters),
    ],
    outputs=[gr.JSON(label="Relevant chunks")],
    title="Hugging Chat RAG Tool",
).launch()
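# Example client call (a sketch; assumes the Space is running locally and the
# separate gradio_client package is installed; the query string is just an
# illustration):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("What do the documents cover?", 258))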