# Source: NotebookLlamaGroq / extract_text_from_pdf.py
# Author: yasserrmd, commit c95c91f ("Update extract_text_from_pdf.py", verified)
# extract_text_from_pdf.py
import os
import torch
import spaces
from PyPDF2 import PdfReader
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import warnings
# Silence every warning category for the whole process. The "ignore" filter is
# inserted at the front of the filter list, so it wins over filters added
# earlier. NOTE(review): this also hides library warnings (e.g. from model
# generation) that could reveal misconfiguration; consider narrowing it.
warnings.simplefilter('ignore')
class PDFTextExtractor:
    """
    A class to handle PDF text extraction and preprocessing for podcast preparation.

    Pipeline (driven by ``clean_and_save_text``):
      1. read up to ``max_chars`` characters of text from the PDF,
      2. split that text into word-bounded chunks of at most ~``chunk_size``
         characters,
      3. send each chunk through a causal LM with a "clean this text" system
         prompt,
      4. write each cleaned chunk, followed by a newline, to ``output_path``.
    """

    @spaces.GPU
    def __init__(
        self,
        pdf_path,
        output_path,
        model_name="meta-llama/Llama-3.2-1B-Instruct",
        max_chars=100000,
        chunk_size=1000,
    ):
        """
        Initialize the PDFTextExtractor with paths and model details.

        Args:
            pdf_path (str | os.PathLike): Path to the PDF file.
            output_path (str | os.PathLike): Path to save the cleaned text file.
            model_name (str): Hugging Face model id loaded for text processing.
            max_chars (int): Maximum number of characters read from the PDF.
            chunk_size (int): Target maximum size, in characters, of each chunk
                sent to the model. A single word longer than this still forms
                its own chunk.
        """
        self.pdf_path = pdf_path
        self.output_path = output_path
        self.max_chars = max_chars
        self.chunk_size = chunk_size
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize model and tokenizer. The model is loaded in bfloat16,
        # moved to self.device, then handed to Accelerate.
        self.accelerator = Accelerator()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model, self.tokenizer = self.accelerator.prepare(self.model, self.tokenizer)

        # System prompt sent as the system message for every chunk
        # (the leading newline is part of the prompt text).
        self.system_prompt = (
            "\n"
            "You are a world class text pre-processor, here is the raw data from a PDF, please parse and return it in a way that is crispy and usable to send to a podcast writer.\n"
            "Be smart and aggressive with removing details; you're only cleaning up the text without summarizing.\n"
            "Here is the text:\n"
        )

    @spaces.GPU
    def validate_pdf(self):
        """
        Check that ``self.pdf_path`` exists and has a ``.pdf`` extension.

        Only the (case-insensitive) extension is checked; the file contents
        are not inspected. Problems are reported via ``print``.

        Returns:
            bool: True if both checks pass, False otherwise.
        """
        if not os.path.exists(self.pdf_path):
            print(f"Error: File not found at path: {self.pdf_path}")
            return False
        # os.fspath lets pathlib.Path inputs work too (Path has no .lower()).
        if not os.fspath(self.pdf_path).lower().endswith('.pdf'):
            print("Error: File is not a PDF")
            return False
        return True

    @spaces.GPU
    def extract_text(self):
        """
        Extract text from the PDF, limited to ``self.max_chars`` characters.

        Pages are read in order. The page that would cross the limit is
        truncated, and later pages are skipped. Page texts are joined with
        "\\n". The separators are not counted toward the limit.

        Returns:
            str | None: The extracted text, or None if ``validate_pdf`` fails.
        """
        if not self.validate_pdf():
            return None

        # The reader is used inside the ``with`` so the file stays open while
        # pages are accessed.
        with open(self.pdf_path, 'rb') as file:
            pdf_reader = PdfReader(file)
            num_pages = len(pdf_reader.pages)
            print(f"Processing PDF with {num_pages} pages...")

            extracted_text = []
            total_chars = 0
            for page_num, page in enumerate(pdf_reader.pages):
                # Treat a falsy result (e.g. None for pages without a text
                # layer) as an empty string.
                text = page.extract_text() or ""

                if total_chars + len(text) > self.max_chars:
                    remaining_chars = self.max_chars - total_chars
                    extracted_text.append(text[:remaining_chars])
                    print(f"Reached {self.max_chars} character limit at page {page_num + 1}")
                    break

                extracted_text.append(text)
                total_chars += len(text)
                print(f"Processed page {page_num + 1}/{num_pages}")

        final_text = '\n'.join(extracted_text)
        print(f"Extraction complete! Total characters: {len(final_text)}")
        return final_text

    @spaces.GPU
    def create_word_bounded_chunks(self, text):
        """
        Split text into chunks of whole words around ``self.chunk_size``.

        Text is split on any whitespace, and words are re-joined with single
        spaces, so newlines and runs of whitespace are normalized. A chunk is
        closed before a word that would push it past ``chunk_size``. A single
        word longer than ``chunk_size`` becomes its own chunk.

        Args:
            text (str): Text to split.

        Returns:
            list[str]: The chunks in order. Empty if ``text`` has no words.
        """
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0

        for word in words:
            word_length = len(word) + 1  # +1 for the joining space
            if current_length + word_length > self.chunk_size and current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk = [word]
                current_length = word_length
            else:
                current_chunk.append(word)
                current_length += word_length

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    @spaces.GPU(duration=120)
    def process_chunk(self, text_chunk):
        """
        Process a text chunk with the model and return the cleaned text.

        Args:
            text_chunk (str): Raw text, sent as the user message.

        Returns:
            str: Only the model's newly generated reply (the prompt is
            excluded), decoded without special tokens and stripped.
        """
        conversation = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": text_chunk},
        ]
        # add_generation_prompt=True appends the assistant-turn header so the
        # model replies as the assistant. Without it the model has to emit the
        # header itself, and that text leaks into the decoded reply.
        prompt = self.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        # A rendered chat template typically already contains special tokens
        # such as BOS, so don't let the tokenizer add them a second time.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.device)

        with torch.no_grad():
            output = self.model.generate(
                **inputs, temperature=0.7, top_p=0.9, max_new_tokens=512
            )

        # For decoder-only models generate() returns prompt + continuation.
        # Drop the prompt by token count. Slicing the decoded string by
        # len(prompt) is wrong: skip_special_tokens=True removes the template's
        # special tokens from the decoded text but not from `prompt`, so the
        # offset overshoots and cuts off the beginning of the reply.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = output[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    @spaces.GPU
    def clean_and_save_text(self):
        """
        Extract, clean, and save processed text to a file.

        ``self.output_path`` is opened for writing (UTF-8, overwritten). Each
        chunk's cleaned text is written followed by a newline and flushed
        right away, so already-processed chunks reach the file before the next
        model call.

        Returns:
            str | None: ``self.output_path`` on success, or None if no text
            could be extracted (invalid path or empty extraction).
        """
        extracted_text = self.extract_text()
        if not extracted_text:
            return None

        chunks = self.create_word_bounded_chunks(extracted_text)

        with open(self.output_path, 'w', encoding='utf-8') as out_file:
            for chunk in tqdm(chunks, desc="Processing chunks"):
                processed_chunk = self.process_chunk(chunk)
                out_file.write(processed_chunk + "\n")
                out_file.flush()

        print(f"\nExtracted and cleaned text has been saved to {self.output_path}")
        return self.output_path