"""Gradio app: convert uploaded documents to text, chunk them, and return the
chunks most relevant to a set of queries within a character budget."""

import gradio as gr
import spaces
import subprocess
import os
import shutil
import string
import random
from pypdf import PdfReader
import ocrmypdf
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Snowflake/snowflake-arctic-embed-m")
model.to(device="cuda")


@spaces.GPU
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    # Encode the queries (with the model's query prompt) and the chunks,
    # then score every query/chunk pair with a dot product.
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)

    scores = query_embeddings @ document_embeddings.T

    results = {}
    for query, query_scores in zip(queries, scores):
        chunk_idxs = list(range(len(chunks)))
        # For each query, keep (chunk index, similarity score) pairs.
        results[query] = list(zip(chunk_idxs, query_scores))

    return results


def random_word(length):
    letters = string.ascii_lowercase
    return "".join(random.choice(letters) for _ in range(length))


def convert_pdf(input_file) -> str:
    reader = PdfReader(input_file)
    text = extract_text_from_pdf(reader)

    # Count the embedded images to detect scanned PDFs.
    image_count = 0
    for page in reader.pages:
        image_count += len(page.images)

    # If the PDF contains images but very little extractable text, it is
    # likely scanned: run OCR and extract the text from the OCRed copy.
    if image_count > 0 and len(text) < 1000:
        out_pdf_file = input_file.replace(".pdf", "_ocr.pdf")
        ocrmypdf.ocr(input_file, out_pdf_file, force_ocr=True)

        text = extract_text_from_pdf(PdfReader(out_pdf_file))

        os.remove(out_pdf_file)

    return text


def extract_text_from_pdf(reader):
    full_text = ""
    for idx, page in enumerate(reader.pages):
        text = page.extract_text()
        if len(text) > 0:
            full_text += f"---- Page {idx} ----\n" + text + "\n\n"

    return full_text.strip()


def convert_pandoc(input_file, filename) -> str:
    # Copy the upload to its original filename so pandoc can infer the
    # input format from the extension.
    shutil.copyfile(input_file, filename)

    # Convert to markdown via pandoc, writing to a temporary random filename.
    output_file = f"{random_word(16)}.md"
    result = subprocess.call(["pandoc", filename, "-t", "markdown", "-o", output_file])
    if result != 0:
        raise ValueError("Error converting file to markdown with pandoc")

    with open(output_file, "r") as f:
        markdown = f.read()

    os.remove(output_file)
    os.remove(filename)

    return markdown


@spaces.GPU
def convert(input_file, filename) -> str:
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]

    # Already plain text: read the file as-is.
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(input_file, "r") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return convert_pdf(input_file)

    # Everything else is converted to markdown with pandoc.
    return convert_pandoc(input_file, filename)


def chunk_to_length(text, max_length=512):
    chunks = []
    while len(text) > max_length:
        chunks.append(text[:max_length])
        text = text[max_length:]
    chunks.append(text)
    return chunks


@spaces.GPU
def predict(queries, documents, document_filenames, max_characters) -> list[list[str]]:
    queries = queries.split("\n")
    document_filenames = document_filenames.split("\n")

    # Convert every uploaded file to plain text / markdown.
    converted_docs = [
        convert(doc, filename) for doc, filename in zip(documents, document_filenames)
    ]

    # If everything fits within the character budget, return the documents whole.
    total_doc_lengths = sum(len(doc) for doc in converted_docs)
    if total_doc_lengths < max_characters:
        return [[doc] for doc in converted_docs]

    # Otherwise chunk each document and score every chunk against every query.
    chunked_docs = [chunk_to_length(doc, 512) for doc in converted_docs]
    embedded_docs = [embed(queries, chunks) for chunks in chunked_docs]

    # Collect (doc_idx, chunk_idx, score) triples per query across all documents.
    query_embeddings = {}
    for doc_idx, embedded_doc in enumerate(embedded_docs):
        for query, doc_scores in embedded_doc.items():
            doc_scores_with_doc = [
                (doc_idx, chunk_idx, score) for (chunk_idx, score) in doc_scores
            ]
            if query not in query_embeddings:
                query_embeddings[query] = []
            query_embeddings[query] = query_embeddings[query] + doc_scores_with_doc

    # Sort each query's candidates by descending similarity score.
    for query, doc_scores in query_embeddings.items():
        query_embeddings[query] = sorted(doc_scores, key=lambda x: x[2], reverse=True)

    # Round-robin over the queries, taking each query's best remaining chunk,
    # until the character budget is spent or no candidates are left.
    document_embeddings = [[] for _ in range(len(documents))]
    total_chars = 0
    while (
        total_chars < max_characters
        and sum(len(x) for x in query_embeddings.values()) > 0
    ):
        for query, doc_scores in query_embeddings.items():
            if len(doc_scores) == 0:
                continue

            doc_idx, chunk_idx, _ = doc_scores.pop(0)

            chunk = chunked_docs[doc_idx][chunk_idx]
            if total_chars + len(chunk) > max_characters:
                continue

            # Skip chunks already selected for another query.
            if chunk_idx in document_embeddings[doc_idx]:
                continue

            document_embeddings[doc_idx].append(chunk_idx)
            total_chars += len(chunk)

    # Map the selected chunk indices back to their text.
    document_embeddings = [
        [chunked_docs[doc_idx][chunk_idx] for chunk_idx in chunks]
        for doc_idx, chunks in enumerate(document_embeddings)
    ]

    return document_embeddings


gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Queries separated by newline"),
        gr.File(label="Upload File", file_count="multiple"),
        gr.Textbox(label="Filenames separated by newline"),
        gr.Number(label="Max output characters", value=16384),
    ],
    outputs=[gr.JSON(label="Embedded documents")],
).launch()