Liam Dyer commited on
Commit
83f2c7b
1 Parent(s): 15d68b8

letting it rip bud

Browse files
Files changed (2) hide show
  1. app.py +102 -23
  2. requirements.txt +1 -0
app.py CHANGED
@@ -7,6 +7,34 @@ import string
7
  import random
8
  from pypdf import PdfReader
9
  import ocrmypdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def random_word(length):
@@ -14,9 +42,8 @@ def random_word(length):
14
  return "".join(random.choice(letters) for _ in range(length))
15
 
16
 
17
- def convert_pdf(input_file):
18
  reader = PdfReader(input_file)
19
- metadata = extract_metadata_from_pdf(reader)
20
  text = extract_text_from_pdf(reader)
21
 
22
  # Check if there are any images
@@ -35,7 +62,7 @@ def convert_pdf(input_file):
35
  # Delete the OCR file
36
  os.remove(out_pdf_file)
37
 
38
- return text, metadata
39
 
40
 
41
  def extract_text_from_pdf(reader):
@@ -48,17 +75,7 @@ def extract_text_from_pdf(reader):
48
  return full_text.strip()
49
 
50
 
51
- def extract_metadata_from_pdf(reader):
52
- return {
53
- "author": reader.metadata.author,
54
- "creator": reader.metadata.creator,
55
- "producer": reader.metadata.producer,
56
- "subject": reader.metadata.subject,
57
- "title": reader.metadata.title,
58
- }
59
-
60
-
61
- def convert_pandoc(input_file, filename):
62
  # Temporarily copy the file
63
  shutil.copyfile(input_file, filename)
64
 
@@ -78,7 +95,7 @@ def convert_pandoc(input_file, filename):
78
 
79
 
80
  @spaces.GPU
81
- def convert(input_file, filename):
82
  plain_text_filetypes = [
83
  ".txt",
84
  ".csv",
@@ -91,23 +108,85 @@ def convert(input_file, filename):
91
  ".jsonc",
92
  ]
93
  # Already a plain text file that wouldn't benefit from pandoc so return the content
94
- if any(filename.endswith(ft) for ft in plain_text_filetypes):
95
  with open(input_file, "r") as f:
96
- return f.read(), {}
97
 
98
- if filename.endswith(".pdf"):
99
  return convert_pdf(input_file)
100
 
101
- return convert_pandoc(input_file, filename), {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
  # We accept a filename because the gradio JS interface removes this information
105
  # and it's critical for choosing the correct processing pipeline
106
  gr.Interface(
107
  convert,
108
- inputs=[gr.File(label="Upload File", type="filepath"), gr.Text(label="Filename")],
109
- outputs=[
110
- gr.Text(label="Markdown"),
111
- gr.JSON(label="Metadata"),
112
  ],
 
113
  ).launch()
 
7
  import random
8
  from pypdf import PdfReader
9
  import ocrmypdf
10
+ from sentence_transformers import SentenceTransformer
11
+
12
+ model = SentenceTransformer("Snowflake/snowflake-arctic-embed-m")
13
+ model.to(device="cuda")
14
+
15
+
16
+ def chunk(text, max_length=512):
17
+ chunks = []
18
+ while len(text) > max_length:
19
+ chunks.append(text[:max_length])
20
+ text = text[max_length:]
21
+ chunks.append(text)
22
+ return chunks
23
+
24
+
25
+ @spaces.GPU
26
+ def embed(queries, chunks) -> dict[str, list[tuple[str, float]]]:
27
+ query_embeddings = model.encode(queries, prompt_name="query")
28
+ document_embeddings = model.encode(chunks)
29
+
30
+ scores = query_embeddings @ document_embeddings.T
31
+ results = {}
32
+ for query, query_scores in zip(queries, scores):
33
+ chunk_idxs = [i for i in range(len(chunks))]
34
+ # Get a structure like {query: [(chunk_idx, score), (chunk_idx, score), ...]}
35
+ results[query] = list(zip(chunk_idxs, query_scores))
36
+
37
+ return results
38
 
39
 
40
  def random_word(length):
 
42
  return "".join(random.choice(letters) for _ in range(length))
43
 
44
 
45
+ def convert_pdf(input_file) -> str:
46
  reader = PdfReader(input_file)
 
47
  text = extract_text_from_pdf(reader)
48
 
49
  # Check if there are any images
 
62
  # Delete the OCR file
63
  os.remove(out_pdf_file)
64
 
65
+ return text
66
 
67
 
68
  def extract_text_from_pdf(reader):
 
75
  return full_text.strip()
76
 
77
 
78
+ def convert_pandoc(input_file, filename) -> str:
 
 
 
 
 
 
 
 
 
 
79
  # Temporarily copy the file
80
  shutil.copyfile(input_file, filename)
81
 
 
95
 
96
 
97
  @spaces.GPU
98
+ def convert(input_file) -> str:
99
  plain_text_filetypes = [
100
  ".txt",
101
  ".csv",
 
108
  ".jsonc",
109
  ]
110
  # Already a plain text file that wouldn't benefit from pandoc so return the content
111
+ if any(input_file.endswith(ft) for ft in plain_text_filetypes):
112
  with open(input_file, "r") as f:
113
+ return f.read()
114
 
115
+ if input_file.endswith(".pdf"):
116
  return convert_pdf(input_file)
117
 
118
+ return convert_pandoc(input_file, input_file)
119
+
120
+
121
+ @spaces.GPU
122
+ def predict(queries, documents, max_characters) -> list[list[str]]:
123
+ queries = queries.split("\n")
124
+
125
+ # Conver the documents to text
126
+ converted_docs = [convert(doc) for doc in documents]
127
+
128
+ # Return if the total length is less than the max characters
129
+ total_doc_lengths = sum([len(doc) for doc, _ in converted_docs])
130
+ if total_doc_lengths < max_characters:
131
+ return [[doc] for doc, _ in converted_docs]
132
+
133
+ # Embed the documents in 512 character chunks
134
+ chunked_docs = [chunk(doc, 512) for doc in converted_docs]
135
+ embedded_docs = [embed(queries, chunks) for chunks in chunked_docs]
136
+
137
+ # Get a structure like {query: [(doc_idx, chunk_idx, score), (doc_idx, chunk_idx, score), ...]}
138
+ query_embeddings = {}
139
+ for doc_idx, embedded_doc in enumerate(embedded_docs):
140
+ for query, doc_scores in embedded_doc.items():
141
+ doc_scores_with_doc = [
142
+ (doc_idx, chunk_idx, score) for (chunk_idx, score) in doc_scores
143
+ ]
144
+ if query not in query_embeddings:
145
+ query_embeddings[query] = []
146
+ query_embeddings[query] = query_embeddings[query] + doc_scores_with_doc
147
+
148
+ # Sort the embeddings by score
149
+ for query, doc_scores in query_embeddings.items():
150
+ query_embeddings[query] = sorted(doc_scores, key=lambda x: x[2], reverse=True)
151
+
152
+ # Choose the top embedding from each query until we reach the max characters
153
+ # Getting a structure like [[chunk, ...]]
154
+ document_embeddings = [[] for _ in range(len(documents))]
155
+ total_chars = 0
156
+ while total_chars < max_characters:
157
+ for query, doc_scores in query_embeddings.items():
158
+ if len(doc_scores) == 0:
159
+ continue
160
+
161
+ # Grab the top score for the query
162
+ doc_idx, chunk_idx, _ = doc_scores.pop(0)
163
+ if doc_idx not in document_embeddings:
164
+ document_embeddings[doc_idx] = []
165
+
166
+ # Ensure we have space
167
+ chunk = chunked_docs[doc_idx][chunk_idx]
168
+ if total_chars + len(chunk) > max_characters:
169
+ continue
170
+
171
+ # Ensure we haven't already added this chunk from this document
172
+ if chunk_idx in document_embeddings[doc_idx]:
173
+ continue
174
+
175
+ # Add the chunk
176
+ document_embeddings[doc_idx].append(chunk_idx)
177
+ total_chars += len(chunk)
178
+
179
+ return document_embeddings
180
 
181
 
182
  # We accept a filename because the gradio JS interface removes this information
183
  # and it's critical for choosing the correct processing pipeline
184
  gr.Interface(
185
  convert,
186
+ inputs=[
187
+ gr.Textbox(label="Queries separated by newline"),
188
+ gr.Files(label="Upload File"),
189
+ gr.Number(label="Max output characters", value=16384),
190
  ],
191
+ outputs=[gr.JSON(label="Embedded documents")],
192
  ).launch()
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  ocrmypdf==16.3.1
2
  pypdf==4.2.0
 
 
1
  ocrmypdf==16.3.1
2
  pypdf==4.2.0
3
+ sentence-transformers==3.0.0