# Author: Liam Dyer
# Commit d6c1ef6 (unverified): add filenames because of a gradio client bug
import gradio as gr
import spaces
import subprocess
import os
import shutil
import string
import random
from pypdf import PdfReader
import ocrmypdf
from sentence_transformers import SentenceTransformer
# Load the embedding model once at import time so every request reuses it.
model = SentenceTransformer("Snowflake/snowflake-arctic-embed-m")
# Move the model to the GPU eagerly; assumes a CUDA device is available
# (on ZeroGPU Spaces the @spaces.GPU decorator provides it) — TODO confirm.
model.to(device="cuda")
@spaces.GPU
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    """Score every chunk against every query with the embedding model.

    Args:
        queries: list of query strings.
        chunks: list of document-chunk strings.

    Returns:
        Mapping of query -> [(chunk_idx, score), ...] covering all chunks,
        in chunk order (unsorted).
    """
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)
    # Similarity matrix: one row per query, one column per chunk.
    scores = query_embeddings @ document_embeddings.T
    # enumerate() pairs each score with its chunk index directly, replacing
    # the manual index-list construction of the original.
    # (Annotation fixed: chunk indices are ints, not strs.)
    return {
        query: list(enumerate(query_scores))
        for query, query_scores in zip(queries, scores)
    }
def random_word(length):
    """Return a random string of *length* lowercase ASCII letters."""
    return "".join(random.choices(string.ascii_lowercase, k=length))
def convert_pdf(input_file) -> str:
    """Extract text from a PDF, falling back to OCR for scanned documents.

    Args:
        input_file: path to the PDF on disk.

    Returns:
        The extracted text, with per-page headers (see extract_text_from_pdf).
    """
    reader = PdfReader(input_file)
    text = extract_text_from_pdf(reader)
    # Count embedded images to detect scanned / image-only PDFs.
    image_count = sum(len(page.images) for page in reader.pages)
    # Image-heavy PDF that yielded little text is likely a scan: OCR it and
    # extract again from the OCR'd copy.
    if image_count > 0 and len(text) < 1000:
        out_pdf_file = input_file.replace(".pdf", "_ocr.pdf")
        ocrmypdf.ocr(input_file, out_pdf_file, force_ocr=True)
        try:
            # BUG FIX: read the OCR output file, not the original input —
            # re-reading input_file silently discarded the OCR result.
            text = extract_text_from_pdf(PdfReader(out_pdf_file))
        finally:
            # Always delete the temporary OCR file, even if extraction raises.
            os.remove(out_pdf_file)
    return text
def extract_text_from_pdf(reader):
    """Concatenate per-page text with "---- Page N ----" headers.

    Pages with no extractable text are skipped; page indices are 0-based.

    Args:
        reader: a PdfReader-like object exposing ``.pages`` whose items
            implement ``extract_text()``.

    Returns:
        The joined text, stripped of trailing whitespace.
    """
    sections = []
    for idx, page in enumerate(reader.pages):
        # extract_text() can be expensive — call it once per page instead of
        # twice as the original did (which also risked inconsistent output).
        text = page.extract_text()
        if text:
            sections.append(f"---- Page {idx} ----\n{text}\n\n")
    # join() instead of repeated += avoids quadratic string building.
    return "".join(sections).strip()
def convert_pandoc(input_file, filename) -> str:
    """Convert *input_file* to markdown via the external pandoc binary.

    The upload is first copied to *filename* so pandoc can infer the format
    from the original extension (the gradio client strips filenames), then
    converted to a randomly-named markdown file that is read back.

    Args:
        input_file: path to the uploaded temp file.
        filename: original filename carrying the real extension.

    Returns:
        The pandoc-produced markdown.

    Raises:
        ValueError: if pandoc exits with a non-zero status.
    """
    # Temporarily copy the file under its original name for pandoc.
    shutil.copyfile(input_file, filename)
    output_file = f"{random_word(16)}.md"
    try:
        result = subprocess.call(
            ["pandoc", filename, "-t", "markdown", "-o", output_file]
        )
        if result != 0:
            raise ValueError("Error converting file to markdown with pandoc")
        with open(output_file, "r") as f:
            markdown = f.read()
    finally:
        # Remove temporaries even on failure — the original leaked both the
        # copied input and any partial output when pandoc errored.
        for path in (output_file, filename):
            if os.path.exists(path):
                os.remove(path)
    return markdown
@spaces.GPU
def convert(input_file, filename) -> str:
    """Dispatch a document to the appropriate text-conversion pipeline.

    Args:
        input_file: path to the uploaded temp file.
        filename: original filename (needed because gradio strips it).

    Returns:
        The document content as plain text / markdown.
    """
    # Plain-text formats that wouldn't benefit from pandoc: return content as-is.
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    # str.endswith accepts a tuple of suffixes — no any() loop needed.
    if filename.endswith(plain_text_filetypes):
        with open(input_file, "r") as f:
            return f.read()
    if filename.endswith(".pdf"):
        return convert_pdf(input_file)
    return convert_pandoc(input_file, filename)
def chunk_to_length(text, max_length=512):
    """Split *text* into consecutive pieces of at most *max_length* characters.

    Always returns at least one element: an empty input yields [""], matching
    the behavior of the original while-loop implementation.
    """
    pieces = [
        text[start : start + max_length]
        for start in range(0, len(text), max_length)
    ]
    return pieces or [""]
@spaces.GPU
def predict(queries, documents, document_filenames, max_characters) -> list[list[str]]:
    """Return the most query-relevant text chunks from each document.

    Args:
        queries: newline-separated query strings.
        documents: uploaded file paths (gradio File component, multiple).
        document_filenames: newline-separated original filenames — required
            because the gradio client strips filenames from uploads.
        max_characters: total character budget for the returned chunks.

    Returns:
        One list of text chunks per input document.
    """
    queries = queries.split("\n")
    document_filenames = document_filenames.split("\n")
    # Convert every document to plain text / markdown.
    converted_docs = [
        convert(doc, filename) for doc, filename in zip(documents, document_filenames)
    ]
    # If everything already fits in the budget, return the documents whole.
    total_doc_lengths = sum(len(doc) for doc in converted_docs)
    if total_doc_lengths < max_characters:
        # BUG FIX: converted_docs holds plain strings, so the original
        # "for doc, _ in converted_docs" unpacking raised ValueError whenever
        # this early-return path was taken.
        return [[doc] for doc in converted_docs]
    # Embed the documents in 512-character chunks.
    chunked_docs = [chunk_to_length(doc, 512) for doc in converted_docs]
    embedded_docs = [embed(queries, chunks) for chunks in chunked_docs]
    # Flatten to {query: [(doc_idx, chunk_idx, score), ...]}.
    query_embeddings = {}
    for doc_idx, embedded_doc in enumerate(embedded_docs):
        for query, doc_scores in embedded_doc.items():
            query_embeddings.setdefault(query, []).extend(
                (doc_idx, chunk_idx, score) for chunk_idx, score in doc_scores
            )
    # Sort each query's candidates by score, best first.
    for query, doc_scores in query_embeddings.items():
        query_embeddings[query] = sorted(doc_scores, key=lambda x: x[2], reverse=True)
    # Round-robin across queries, taking each query's best remaining chunk
    # until the character budget is exhausted or candidates run out.
    document_embeddings = [[] for _ in range(len(documents))]
    total_chars = 0
    while total_chars < max_characters and any(query_embeddings.values()):
        for doc_scores in query_embeddings.values():
            if not doc_scores:
                continue
            doc_idx, chunk_idx, _ = doc_scores.pop(0)
            chunk = chunked_docs[doc_idx][chunk_idx]
            # Skip chunks that would blow the budget; they stay popped, so
            # the outer while loop still terminates.
            if total_chars + len(chunk) > max_characters:
                continue
            # Skip chunks another query already selected from this document.
            if chunk_idx in document_embeddings[doc_idx]:
                continue
            document_embeddings[doc_idx].append(chunk_idx)
            total_chars += len(chunk)
    # Replace chunk indices with the chunk text itself:
    # a structure like [[chunk, ...], ...], one inner list per document.
    return [
        [chunked_docs[doc_idx][chunk_idx] for chunk_idx in chunks]
        for doc_idx, chunks in enumerate(document_embeddings)
    ]
# We accept a filename because the gradio JS interface removes this information
# and it's critical for choosing the correct processing pipeline
# Build and launch the Gradio UI; the four inputs map positionally onto
# predict(queries, documents, document_filenames, max_characters).
gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Queries separated by newline"),
        gr.File(label="Upload File", file_count="multiple"),
        gr.Textbox(label="Filenames separated by newline"),
        gr.Number(label="Max output characters", value=16384),
    ],
    # predict returns list[list[str]], rendered as JSON.
    outputs=[gr.JSON(label="Embedded documents")],
).launch()