heerjtdev committed on
Commit 7bf9c65 · verified · 1 Parent(s): 6b41319

Update app.py

Files changed (1)
app.py +76 -39
app.py CHANGED
@@ -2,32 +2,75 @@ import gradio as gr
 import PyPDF2
 import re
 import json
-from typing import List, Dict, Tuple
-from transformers import pipeline
+from typing import List, Dict
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 import tempfile
 import os
 
-# Initialize the question generation pipeline using a small CPU-friendly model
+# Initialize the model and tokenizer directly
 print("Loading models... This may take a minute on first run.")
-qa_generator = pipeline(
-    "text2text-generation",
-    model="valhalla/t5-small-qg-hl",
-    tokenizer="valhalla/t5-small-qg-hl",
-    device=-1  # Force CPU
-)
+
+model_name = "valhalla/t5-small-qg-hl"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+# Set to evaluation mode and CPU
+model.eval()
+device = torch.device("cpu")
+model.to(device)
+
+def generate_questions(context: str, answer: str, max_length: int = 128) -> str:
+    """Generate a question using T5 model."""
+    try:
+        # Format: "generate question: <hl> answer <hl> context"
+        input_text = f"generate question: <hl> {answer} <hl> {context}"
+
+        # Tokenize
+        inputs = tokenizer(
+            input_text,
+            return_tensors="pt",
+            max_length=512,
+            truncation=True,
+            padding=True
+        ).to(device)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_length=max_length,
+                num_beams=4,
+                early_stopping=True,
+                do_sample=True,
+                temperature=0.7
+            )
+
+        # Decode
+        question = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Clean up
+        question = re.sub(r'^(question:|q:)', '', question, flags=re.IGNORECASE).strip()
+
+        return question if len(question) > 10 else ""
+
+    except Exception as e:
+        print(f"Error generating question: {e}")
+        return ""
 
 def extract_text_from_pdf(pdf_file) -> str:
     """Extract text from uploaded PDF file."""
     text = ""
     try:
-        # Handle both file path and file object
         if isinstance(pdf_file, str):
             pdf_reader = PyPDF2.PdfReader(pdf_file)
         else:
             pdf_reader = PyPDF2.PdfReader(pdf_file)
 
         for page in pdf_reader.pages:
-            text += page.extract_text() + "\n"
+            page_text = page.extract_text()
+            if page_text:
+                text += page_text + "\n"
     except Exception as e:
         return f"Error reading PDF: {str(e)}"
 
@@ -74,44 +117,32 @@ def generate_qa_pairs(chunk: str, num_questions: int = 2) -> List[Dict[str, str]
     flashcards = []
 
     # Skip chunks that are too short
-    if len(chunk.split()) < 20:
+    words = chunk.split()
+    if len(words) < 20:
         return []
 
     try:
-        # Generate highlight format for T5 question generation
-        # We'll create simple highlight by taking key sentences
-        sentences = chunk.split('. ')
-        if len(sentences) < 2:
+        # Split into sentences to use as answers
+        sentences = [s.strip() for s in chunk.split('. ') if len(s.strip()) > 20]
+
+        if len(sentences) < 1:
             return []
 
-        # Generate questions for different parts of the chunk
+        # Generate questions for different sentences
         for i in range(min(num_questions, len(sentences))):
-            # Create highlight context
-            highlight = sentences[i]
-            context = chunk
-
-            # Format for T5: "generate question: <hl> highlight <hl> context"
-            input_text = f"generate question: <hl> {highlight} <hl> {context}"
+            answer = sentences[i]
 
-            # Generate question
-            outputs = qa_generator(
-                input_text,
-                max_length=128,
-                num_return_sequences=1,
-                do_sample=True,
-                temperature=0.7
-            )
+            # Skip very short answers
+            if len(answer.split()) < 3:
+                continue
 
-            question = outputs[0]['generated_text'].strip()
+            question = generate_questions(chunk, answer)
 
-            # Clean up question
-            question = re.sub(r'^(question:|q:)', '', question, flags=re.IGNORECASE).strip()
-
-            if question and len(question) > 10:
+            if question and question != answer:  # Make sure they're different
                 flashcards.append({
                     "question": question,
-                    "answer": highlight.strip(),
-                    "context": context[:200] + "..." if len(context) > 200 else context
+                    "answer": answer,
+                    "context": chunk[:200] + "..." if len(chunk) > 200 else chunk
                 })
 
     except Exception as e:
@@ -305,12 +336,18 @@ with gr.Blocks(css=custom_css, title="PDF to Flashcards") as demo:
     gr.Markdown("*Raw JSON data for custom applications*")
 
     # Event handlers
+    def update_display(status):
+        """Update display when processing is done."""
+        if status and not status.startswith(("📄", "🧹", "✂️", "🎴", "✅")):
+            return status
+        return gr.update()
+
     process_btn.click(
         fn=process_pdf,
         inputs=[pdf_input, questions_per_chunk, max_chunks],
        outputs=[status_text, csv_output, json_output]
     ).then(
-        fn=lambda x: x if not isinstance(x, str) or not x.startswith("📄") else gr.update(),
+        fn=update_display,
         inputs=status_text,
         outputs=output_display
     )
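
For a quick smoke test of the new direct-generation path, generate_questions can be exercised outside Gradio. This is a minimal sketch, not part of the commit; the sample context and answer strings are invented purely for illustration:

from app import generate_questions  # loads the model and tokenizer at import time

# Hypothetical context/answer pair, for illustration only
context = ("The mitochondrion is the powerhouse of the cell. "
           "It produces ATP through oxidative phosphorylation.")
answer = "It produces ATP through oxidative phosphorylation"

# Returns "" if generation fails or the question is 10 characters or shorter
question = generate_questions(context, answer)
print(question or "(no usable question generated)")

Because generation combines num_beams=4 with do_sample=True and temperature=0.7, repeated calls may yield different questions for the same input.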