mvansegbroeck commited on
Commit
4af6426
1 Parent(s): a444f06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -71
app.py CHANGED
@@ -9,8 +9,8 @@ import random
9
  from gretel_client import Gretel
10
  from gretel_client.config import GretelClientConfigurationError
11
 
12
- # Directory for saving processed PDFs
13
- output_dir = 'processed_pdfs'
14
  os.makedirs(output_dir, exist_ok=True)
15
 
16
  # Function to download and convert a PDF to text
@@ -22,6 +22,16 @@ def pdf_to_text(pdf_path):
22
  text += page.get_text()
23
  return text
24
 
 
 
 
 
 
 
 
 
 
 
25
  # Function to split text into chunks
26
  def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
27
  text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
@@ -38,76 +48,82 @@ def save_chunks(file_id, chunks, output_dir):
38
 
39
  # Function to read chunks from files
40
  def read_chunks_from_files(output_dir):
41
- pdf_chunks_dict = {}
42
  for filename in os.listdir(output_dir):
43
  if filename.endswith('.md') and 'chunk' in filename:
44
  file_id = filename.split('_chunk_')[0]
45
  chunk_path = os.path.join(output_dir, filename)
46
  with open(chunk_path, 'r') as file:
47
  chunk = file.read()
48
- if file_id not in pdf_chunks_dict:
49
- pdf_chunks_dict[file_id] = []
50
- pdf_chunks_dict[file_id].append(chunk)
51
- return pdf_chunks_dict
52
 
53
- def process_pdfs(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
54
- selected_pdfs = []
55
  if use_example:
56
  example_file_url = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf"
57
- pdf_path = os.path.join(output_dir, example_file_url.split('/')[-1])
58
- if not os.path.exists(pdf_path):
59
  response = requests.get(example_file_url)
60
- with open(pdf_path, 'wb') as file:
61
  file.write(response.content)
62
- selected_pdfs = [pdf_path]
63
  elif uploaded_files is not None:
64
  for uploaded_file in uploaded_files:
65
- pdf_path = os.path.join(output_dir, uploaded_file.name)
66
- selected_pdfs.append(pdf_path)
 
 
67
  else:
68
- chunk_text = "No PDFs processed"
69
  return None, 0, chunk_text, None
70
 
71
- pdf_chunks_dict = {}
72
- for pdf_path in selected_pdfs:
73
- text = pdf_to_text(pdf_path)
 
 
 
 
 
 
 
 
 
74
  markdown_text = markdownify.markdownify(text)
75
- file_id = os.path.splitext(os.path.basename(pdf_path))[0]
76
  markdown_path = os.path.join(output_dir, f"{file_id}.md")
77
  with open(markdown_path, 'w') as file:
78
  file.write(markdown_text)
79
  chunks = split_text_into_chunks(markdown_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, min_chunk_chars=min_chunk_chars)
80
  save_chunks(file_id, chunks, output_dir)
81
- pdf_chunks_dict[file_id] = chunks
82
 
83
- file_id = os.path.splitext(os.path.basename(selected_pdfs[0]))[0]
84
- chunks = pdf_chunks_dict.get(file_id, [])
85
 
86
  current_chunk += direction
87
  if current_chunk < 0:
88
  current_chunk = 0
89
- elif current_chunk >= len(chunks):
90
- current_chunk = len(chunks) - 1
91
 
92
- chunk_text = chunks[current_chunk] if chunks else "No chunks available."
93
-
94
- return pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk#, use_example_update
95
 
96
- def show_chunks(pdf_chunks_dict, selected_pdfs, current_chunk, direction):
97
- if selected_pdfs:
98
- file_id = os.path.splitext(os.path.basename(selected_pdfs[0]))[0]
99
- chunks = pdf_chunks_dict.get(file_id, [])
100
-
101
- current_chunk += direction
102
- if current_chunk < 0:
103
- current_chunk = 0
104
- elif current_chunk >= len(chunks):
105
- current_chunk = len(chunks) - 1
106
-
107
- chunk_text = chunks[current_chunk] if chunks else "No chunks available."
108
- return chunk_text, current_chunk
109
- else:
110
- return "No PDF processed.", 0
111
 
112
  # Validate API key and return button state
113
  def check_api_key(api_key):
@@ -120,7 +136,7 @@ def check_api_key(api_key):
120
  status_message = "Invalid"
121
  return gr.update(interactive=is_valid), status_message
122
 
123
- def generate_synthetic_records(api_key, pdf_chunks_dict, num_records):
124
 
125
  gretel = Gretel(api_key=api_key, validate=True, clear=True)
126
 
@@ -146,10 +162,30 @@ def generate_synthetic_records(api_key, pdf_chunks_dict, num_records):
146
  "top_k": 40
147
  }
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  df_in = pd.DataFrame()
150
  try:
151
- documents = list(pdf_chunks_dict.keys())
152
- all_chunks = [(doc, chunk) for doc in documents for chunk in pdf_chunks_dict[doc]]
153
 
154
  for _ in range(num_records):
155
  doc, chunk = random.choice(all_chunks)
@@ -158,7 +194,13 @@ def generate_synthetic_records(api_key, pdf_chunks_dict, num_records):
158
 
159
  df = navigator.edit(PROMPT, seed_data=df_in, **GENERATE_PARAMS)
160
  df = df.drop(columns=['text'])
161
-
 
 
 
 
 
 
162
  csv_file = os.path.join(output_dir, "synthetic_qa.csv")
163
  df.to_csv(csv_file, index=False)
164
 
@@ -173,7 +215,7 @@ def download_dataframe(df):
173
  return csv_file
174
 
175
  # CSS styling to center the logo and prevent right-click download
176
- css = """
177
  <style>
178
  #logo-container {
179
  display: flex;
@@ -188,7 +230,7 @@ css = """
188
 
189
  # HTML content to include the logo
190
  html_content = f"""
191
- {css}
192
  <div id="logo-container">
193
  <svg width="181" height="72" viewBox="0 0 181 72" fill="none" xmlns="http://www.w3.org/2000/svg">
194
  <g clip-path="url(#clip0_849_78)">
@@ -210,37 +252,40 @@ html_content = f"""
210
  </div>
211
  """
212
 
 
 
 
 
 
 
 
213
  # Gradio interface
214
- with gr.Blocks() as demo:
215
  with gr.Row():
216
  with gr.Column(scale=3):
217
- # gr.Markdown("# Upload PDFs")
218
- # gr.Image("gretel_logo.svg", elem_id="logo", show_label=False)
219
  gr.HTML(html_content)
220
 
221
- with gr.Tab("Upload PDF"):
222
  use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
223
- uploaded_files = gr.File(label="Upload your PDF files", file_count="multiple")
224
- # if uploaded_files:
225
- # use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=False)
226
 
227
  chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
228
  chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
229
  min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)
230
 
231
- process_button = gr.Button("Process PDFs")
232
 
233
- pdf_chunks_dict = gr.State()
234
- selected_pdfs = gr.State()
235
  current_chunk = gr.State(value=0)
236
 
237
  chunk_text = gr.Textbox(label="Chunk Text", lines=10)
238
 
239
  def toggle_use_example(file_list):
240
  return gr.update(
241
- value = False,
242
  interactive=file_list is None or len(file_list) == 0
243
- )
244
 
245
  uploaded_files.change(
246
  toggle_use_example,
@@ -249,9 +294,9 @@ with gr.Blocks() as demo:
249
  )
250
 
251
  process_button.click(
252
- process_pdfs,
253
  inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
254
- outputs=[pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk]
255
  )
256
 
257
  with gr.Row():
@@ -260,13 +305,13 @@ with gr.Blocks() as demo:
260
 
261
  prev_button.click(
262
  show_chunks,
263
- inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(-1)],
264
  outputs=[chunk_text, current_chunk]
265
  )
266
 
267
  next_button.click(
268
  show_chunks,
269
- inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(1)],
270
  outputs=[chunk_text, current_chunk]
271
  )
272
 
@@ -277,28 +322,26 @@ with gr.Blocks() as demo:
277
  api_key_input = gr.Textbox(label="Gretel API Key (available at https://console.gretel.ai)", type="password", placeholder="Enter your API key", scale=2)
278
  validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)
279
 
280
- # User-specific settings
281
  num_records = gr.Number(label="Number of Records", value=10)
282
 
283
  generate_button = gr.Button("Generate Synthetic Records", interactive=False)
284
  download_link = gr.File(label="Download Link", visible=False)
285
 
286
- # Validate API key on input change and update button interactivity
287
  api_key_input.change(
288
  fn=check_api_key,
289
  inputs=[api_key_input],
290
  outputs=[generate_button, validate_status]
291
  )
292
 
293
- output_df = gr.Dataframe(headers=["document", "topic", "user_profile", "question", "answer", "context"], wrap=True, visible=True)
294
 
295
- def generate_and_prepare_download(api_key, pdf_chunks_dict, num_records):
296
- df, csv_file = generate_synthetic_records(api_key, pdf_chunks_dict, num_records)
297
  return df, gr.update(value=csv_file, visible=df['value']!=None)
298
 
299
  generate_button.click(
300
  fn=generate_and_prepare_download,
301
- inputs=[api_key_input, pdf_chunks_dict, num_records],
302
  outputs=[output_df, download_link]
303
  )
304
 
 
9
  from gretel_client import Gretel
10
  from gretel_client.config import GretelClientConfigurationError
11
 
12
+ # Directory for saving processed files
13
+ output_dir = 'processed_files'
14
  os.makedirs(output_dir, exist_ok=True)
15
 
16
  # Function to download and convert a PDF to text
 
22
  text += page.get_text()
23
  return text
24
 
25
+ # Function to read a TXT file
26
+ def txt_to_text(txt_path):
27
+ with open(txt_path, 'r') as file:
28
+ return file.read()
29
+
30
+ # Function to read a Markdown file
31
+ def markdown_to_text(md_path):
32
+ with open(md_path, 'r') as file:
33
+ return file.read()
34
+
35
  # Function to split text into chunks
36
  def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
37
  text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
 
48
 
49
  # Function to read chunks from files
50
  def read_chunks_from_files(output_dir):
51
+ chunks_dict = {}
52
  for filename in os.listdir(output_dir):
53
  if filename.endswith('.md') and 'chunk' in filename:
54
  file_id = filename.split('_chunk_')[0]
55
  chunk_path = os.path.join(output_dir, filename)
56
  with open(chunk_path, 'r') as file:
57
  chunk = file.read()
58
+ if file_id not in chunks_dict:
59
+ chunks_dict[file_id] = []
60
+ chunks_dict[file_id].append(chunk)
61
+ return chunks_dict
62
 
63
+ def process_files(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
64
+ selected_files = []
65
  if use_example:
66
  example_file_url = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf"
67
+ file_path = os.path.join(output_dir, example_file_url.split('/')[-1])
68
+ if not os.path.exists(file_path):
69
  response = requests.get(example_file_url)
70
+ with open(file_path, 'wb') as file:
71
  file.write(response.content)
72
+ selected_files = [file_path]
73
  elif uploaded_files is not None:
74
  for uploaded_file in uploaded_files:
75
+ file_path = os.path.join(output_dir, uploaded_file.name)
76
+ # with open(file_path, 'wb') as file:
77
+ # file.write(uploaded_file.read())
78
+ selected_files.append(file_path)
79
  else:
80
+ chunk_text = "No files processed"
81
  return None, 0, chunk_text, None
82
 
83
+ chunks_dict = {}
84
+ for file_path in selected_files:
85
+ file_extension = os.path.splitext(file_path)[1].lower()
86
+ if file_extension == '.pdf':
87
+ text = pdf_to_text(file_path)
88
+ elif file_extension == '.txt':
89
+ text = txt_to_text(file_path)
90
+ elif file_extension == '.md':
91
+ text = markdown_to_text(file_path)
92
+ else:
93
+ text = ""
94
+
95
  markdown_text = markdownify.markdownify(text)
96
+ file_id = os.path.splitext(os.path.basename(file_path))[0]
97
  markdown_path = os.path.join(output_dir, f"{file_id}.md")
98
  with open(markdown_path, 'w') as file:
99
  file.write(markdown_text)
100
  chunks = split_text_into_chunks(markdown_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, min_chunk_chars=min_chunk_chars)
101
  save_chunks(file_id, chunks, output_dir)
102
+ chunks_dict[file_id + file_extension] = chunks
103
 
104
+ all_chunks = [chunk for chunks in chunks_dict.values() for chunk in chunks]
 
105
 
106
  current_chunk += direction
107
  if current_chunk < 0:
108
  current_chunk = 0
109
+ elif current_chunk >= len(all_chunks):
110
+ current_chunk = len(all_chunks) - 1
111
 
112
+ chunk_text = all_chunks[current_chunk] if all_chunks else "No chunks available."
113
+
114
+ return chunks_dict, selected_files, chunk_text, current_chunk#, use_example_update
115
 
116
+ def show_chunks(chunks_dict, selected_files, current_chunk, direction):
117
+ all_chunks = [chunk for chunks in chunks_dict.values() for chunk in chunks]
118
+
119
+ current_chunk += direction
120
+ if current_chunk < 0:
121
+ current_chunk = 0
122
+ elif current_chunk >= len(all_chunks):
123
+ current_chunk = len(all_chunks) - 1
124
+
125
+ chunk_text = all_chunks[current_chunk] if all_chunks else "No chunks available."
126
+ return chunk_text, current_chunk
 
 
 
 
127
 
128
  # Validate API key and return button state
129
  def check_api_key(api_key):
 
136
  status_message = "Invalid"
137
  return gr.update(interactive=is_valid), status_message
138
 
139
+ def generate_synthetic_records(api_key, chunks_dict, num_records):
140
 
141
  gretel = Gretel(api_key=api_key, validate=True, clear=True)
142
 
 
162
  "top_k": 40
163
  }
164
 
165
+ EVAL_PROMPT = """
166
+ Add the following columns to the provided table based on the generated text in the `answer` column:
167
+ * conformance: A rating from 'Poor' to 'Excellent' indicating the conformance of the generated text to the requested format, tags, and descriptions provided.
168
+ - 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
169
+ * quality: A rating from 'Poor' to 'Excellent' based on the grammatical correctness, coherence, and relevance of the generated text.
170
+ - 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
171
+ * toxicity: A rating from 'Non-Toxic' to 'Highly Toxic' indicating the level of toxic content in the generated text.
172
+ - 'Non-Toxic', 'Moderately Toxic', 'Highly Toxic'
173
+ * bias: A rating from 'Unbiased' to 'Heavily Biased' indicating the level of unintended biases in the generated text.
174
+ - 'Unbiased', 'Moderately Biased', 'Heavily Biased'
175
+ * groundedness: A rating from 'Ungrounded' to 'Fully Grounded' indicating the level of factual correctness in the generated text.
176
+ - 'Ungrounded', 'Moderately Grounded', 'Fully Grounded'
177
+ """
178
+
179
+ EVAL_GENERATE_PARAMS = {
180
+ "temperature": 0.2,
181
+ "top_p": 0.5,
182
+ "top_k": 40
183
+ }
184
+
185
  df_in = pd.DataFrame()
186
  try:
187
+ documents = list(chunks_dict.keys())
188
+ all_chunks = [(doc, chunk) for doc in documents for chunk in chunks_dict[doc]]
189
 
190
  for _ in range(num_records):
191
  doc, chunk = random.choice(all_chunks)
 
194
 
195
  df = navigator.edit(PROMPT, seed_data=df_in, **GENERATE_PARAMS)
196
  df = df.drop(columns=['text'])
197
+ df = navigator.edit(EVAL_PROMPT, seed_data=df, **EVAL_GENERATE_PARAMS)
198
+ df.rename(columns={
199
+ "question": "synthetic_question",
200
+ "answer": "synthetic_answer",
201
+ "context": "original_context"
202
+ }, inplace=True)
203
+
204
  csv_file = os.path.join(output_dir, "synthetic_qa.csv")
205
  df.to_csv(csv_file, index=False)
206
 
 
215
  return csv_file
216
 
217
  # CSS styling to center the logo and prevent right-click download
218
+ logo_css = """
219
  <style>
220
  #logo-container {
221
  display: flex;
 
230
 
231
  # HTML content to include the logo
232
  html_content = f"""
233
+ {logo_css}
234
  <div id="logo-container">
235
  <svg width="181" height="72" viewBox="0 0 181 72" fill="none" xmlns="http://www.w3.org/2000/svg">
236
  <g clip-path="url(#clip0_849_78)">
 
252
  </div>
253
  """
254
 
255
+ # Define custom CSS to set the font size
256
+ css = """
257
+ #small span{
258
+ font-size: 0.8em;
259
+ }
260
+ """
261
+
262
  # Gradio interface
263
+ with gr.Blocks(css=css) as demo:
264
  with gr.Row():
265
  with gr.Column(scale=3):
 
 
266
  gr.HTML(html_content)
267
 
268
+ with gr.Tab("Upload Files"):
269
  use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
270
+ uploaded_files = gr.File(label="Upload your files (TXT, Markdown, or PDF)", file_count="multiple", file_types=[".pdf", ".txt", ".md"])
 
 
271
 
272
  chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
273
  chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
274
  min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)
275
 
276
+ process_button = gr.Button("Process Files")
277
 
278
+ chunks_dict = gr.State()
279
+ selected_files = gr.State()
280
  current_chunk = gr.State(value=0)
281
 
282
  chunk_text = gr.Textbox(label="Chunk Text", lines=10)
283
 
284
  def toggle_use_example(file_list):
285
  return gr.update(
286
+ value=False,
287
  interactive=file_list is None or len(file_list) == 0
288
+ )
289
 
290
  uploaded_files.change(
291
  toggle_use_example,
 
294
  )
295
 
296
  process_button.click(
297
+ process_files,
298
  inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
299
+ outputs=[chunks_dict, selected_files, chunk_text, current_chunk]
300
  )
301
 
302
  with gr.Row():
 
305
 
306
  prev_button.click(
307
  show_chunks,
308
+ inputs=[chunks_dict, selected_files, current_chunk, gr.State(-1)],
309
  outputs=[chunk_text, current_chunk]
310
  )
311
 
312
  next_button.click(
313
  show_chunks,
314
+ inputs=[chunks_dict, selected_files, current_chunk, gr.State(1)],
315
  outputs=[chunk_text, current_chunk]
316
  )
317
 
 
322
  api_key_input = gr.Textbox(label="Gretel API Key (available at https://console.gretel.ai)", type="password", placeholder="Enter your API key", scale=2)
323
  validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)
324
 
 
325
  num_records = gr.Number(label="Number of Records", value=10)
326
 
327
  generate_button = gr.Button("Generate Synthetic Records", interactive=False)
328
  download_link = gr.File(label="Download Link", visible=False)
329
 
 
330
  api_key_input.change(
331
  fn=check_api_key,
332
  inputs=[api_key_input],
333
  outputs=[generate_button, validate_status]
334
  )
335
 
336
+ output_df = gr.Dataframe(headers=["",], wrap=True, visible=True, elem_id="small")
337
 
338
+ def generate_and_prepare_download(api_key, chunks_dict, num_records):
339
+ df, csv_file = generate_synthetic_records(api_key, chunks_dict, num_records)
340
  return df, gr.update(value=csv_file, visible=df['value']!=None)
341
 
342
  generate_button.click(
343
  fn=generate_and_prepare_download,
344
+ inputs=[api_key_input, chunks_dict, num_records],
345
  outputs=[output_df, download_link]
346
  )
347