minar09's picture
Update main.py
58a106e verified
raw
history blame
7.77 kB
import os
import json
import time
import logging
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from fastapi.encoders import jsonable_encoder
# from sentence_transformers import SentenceTransformer
# from llama_cpp import Llama
# Configure logging BEFORE the guarded import below. A module-level logging.*
# call on an unconfigured root logger installs a default WARNING-level handler,
# which would drop the INFO hint and turn a later basicConfig() into a no-op.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# magic_pdf (MinerU) is required. On a failed import, log an actionable hint
# and re-raise the original error so the failure is not hidden.
try:
    from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
    from magic_pdf.config.enums import SupportedPdfParseMethod
except ModuleNotFoundError as e:
    logger.error("Failed to import magic_pdf modules: %s", e)
    logger.info("Ensure that the magic_pdf package is installed and accessible in your Python environment.")
    raise
@dataclass
class ProductSpec:
    """A single product extracted from a catalog PDF.

    Only ``name`` is required; every other field defaults to ``None`` when not
    supplied (``attributes`` and ``tables`` default to ``None``, not to empty
    containers).
    """

    name: str
    description: Optional[str] = None
    price: Optional[float] = None
    # Free-form key/value specifications.
    attributes: Optional[Dict[str, str]] = None
    # Table dicts associated with this product; may be assigned after construction.
    tables: Optional[List[Dict]] = None

    def to_dict(self):
        """Return a JSON-serializable representation via FastAPI's ``jsonable_encoder``."""
        return jsonable_encoder(self)
class PDFProcessor:
    """Extract product specifications from catalog PDFs.

    MinerU (``magic_pdf``) parses the PDF into a "middle JSON" structure.
    Tables and text blocks are read out of it, and each text block is sent to
    ``self.llm`` with a prompt that asks for a JSON product description.
    """

    def __init__(self):
        # Both models are loaded eagerly here, so construction is expensive.
        self.emb_model = self._initialize_emb_model("all-MiniLM-L6-v2")
        self.llm = self._initialize_llm("deepseek-llm-7b-base.Q5_K_M.gguf")
        self.output_dir = Path("./output")
        self.output_dir.mkdir(exist_ok=True)

    def _initialize_emb_model(self, model_name):
        """Load the embedding model ``model_name`` with ``transformers.AutoModel``.

        A bare name (no ``/``) is resolved under the ``sentence-transformers/``
        Hub namespace. For example, ``"all-MiniLM-L6-v2"`` becomes
        ``"sentence-transformers/all-MiniLM-L6-v2"``. A full ``org/name`` id
        is used unchanged.
        """
        from transformers import AutoModel

        repo_id = model_name if "/" in model_name else f"sentence-transformers/{model_name}"
        return AutoModel.from_pretrained(repo_id)

    def _initialize_llm(self, model_name):
        """Load the model stored in ``self.llm``.

        NOTE(review): ``model_name`` (a GGUF file name) is currently ignored;
        the Hub repo below is hard-coded.
        """
        from transformers import AutoModel

        # NOTE(review): _process_text_block calls self.llm.create_chat_completion(...),
        # which is the llama-cpp-python ``Llama`` API (its import is commented out
        # at the top of this file). A transformers AutoModel does not appear to
        # provide that method. Until a Llama-based loader is restored, e.g.
        # Llama.from_pretrained(repo_id=..., filename=model_name), every text
        # block fails inside _process_text_block and is skipped. Restoring it
        # requires adding llama-cpp-python as a dependency.
        return AutoModel.from_pretrained("TheBloke/deepseek-llm-7B-base-GGUF")

    def process_pdf(self, pdf_path: str) -> Dict:
        """Run the MinerU pipeline on ``pdf_path`` and extract products.

        Args:
            pdf_path: Path of the PDF file to read.

        Returns:
            ``{"products": [...], "tables": [...]}``. Each product is a
            ``ProductSpec.to_dict()`` result; ``tables`` is every table found
            in the document.

        Raises:
            RuntimeError: If any parsing/extraction step fails. The original
                exception is chained as ``__cause__``.
        """
        start_time = time.perf_counter()
        local_image_dir = self.output_dir / "images"
        local_image_dir.mkdir(parents=True, exist_ok=True)
        try:
            image_writer = FileBasedDataWriter(str(local_image_dir))
            pdf_bytes = FileBasedDataReader("").read(pdf_path)

            # Choose OCR or text-layer parsing based on MinerU's classification.
            ds = PymuDocDataset(pdf_bytes)
            use_ocr = ds.classify() == SupportedPdfParseMethod.OCR
            infer_result = ds.apply(doc_analyze, ocr=use_ocr)
            if use_ocr:
                pipe_result = infer_result.pipe_ocr_mode(image_writer)
            else:
                pipe_result = infer_result.pipe_txt_mode(image_writer)

            # NOTE(review): magic_pdf's get_middle_json() returns serialized JSON
            # text (confirm for the installed version). Decode it once here;
            # a dict passes through unchanged.
            middle_json = self._load_middle_json(pipe_result.get_middle_json())
            tables = self._extract_tables(middle_json)
            text_blocks = self._extract_text_blocks(middle_json)

            products = []
            for block in text_blocks:
                product = self._process_text_block(block)
                if product is not None:
                    # Every product receives the full document-level table list.
                    product.tables = tables
                    products.append(product.to_dict())

            logger.info("Processed %d products in %.2fs", len(products), time.perf_counter() - start_time)
            return {"products": products, "tables": tables}
        except Exception as e:
            logger.error("Error during PDF processing: %s", e)
            raise RuntimeError("PDF processing failed.") from e

    @staticmethod
    def _load_middle_json(middle_json) -> Dict:
        """Return MinerU's middle JSON as a dict.

        JSON text (``str``/``bytes``) is decoded with ``json.loads``, and a
        dict is returned as-is. ``None`` or any non-object payload yields
        ``{}``. Malformed JSON text raises ``json.JSONDecodeError``.
        """
        if isinstance(middle_json, (str, bytes, bytearray)):
            middle_json = json.loads(middle_json)
        return middle_json if isinstance(middle_json, dict) else {}

    def _extract_tables(self, middle_json: Dict) -> List[Dict]:
        """Collect every table from the middle JSON, tagged with its page number.

        Missing or ``null`` lists and non-dict entries are skipped.
        NOTE(review): the ``pages``/``tables`` keys are assumed. Confirm them
        against the installed magic_pdf middle-JSON layout, which may nest
        pages under a different key (e.g. ``pdf_info``).
        """
        tables = []
        for page in middle_json.get('pages') or []:
            if not isinstance(page, dict):
                continue
            for table in page.get('tables') or []:
                if not isinstance(table, dict):
                    continue
                tables.append({
                    "page": page.get('page_number'),
                    "cells": table.get('cells', []),
                    "header": table.get('header', []),
                    "content": table.get('content', []),
                })
        return tables

    def _extract_text_blocks(self, middle_json: Dict) -> List[str]:
        """Return the text of every ``type == 'text'`` block, in page order.

        Blocks whose text is missing, non-string, or whitespace-only are
        skipped, because each returned block costs one LLM call.
        NOTE(review): the ``pages``/``blocks``/``text`` keys are assumed;
        confirm them against the installed magic_pdf version.
        """
        text_blocks = []
        for page in middle_json.get('pages') or []:
            if not isinstance(page, dict):
                continue
            for block in page.get('blocks') or []:
                if not isinstance(block, dict) or block.get('type') != 'text':
                    continue
                text = block.get('text')
                if isinstance(text, str) and text.strip():
                    text_blocks.append(text)
        return text_blocks

    def _process_text_block(self, text: str) -> Optional["ProductSpec"]:
        """Ask the LLM to extract a product spec from a single text block.

        Returns the parsed ProductSpec, or ``None`` if the LLM call fails or
        the reply cannot be parsed. LLM errors are logged as warnings and
        swallowed, so one bad block does not abort the whole document.
        """
        prompt = self._generate_query_prompt(text)
        try:
            # llama-cpp-python style chat-completion call; see the note in
            # _initialize_llm about the currently loaded model type.
            response = self.llm.create_chat_completion(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=512,
            )
            content = response['choices'][0]['message']['content']
        except Exception as e:
            logger.warning("Error processing text block: %s", e)
            return None
        return self._parse_response(content)

    def _generate_query_prompt(self, text: str) -> str:
        """Build the extraction prompt: the block text followed by the expected JSON shape."""
        return (
            "Extract product specifications from this text:\n"
            f"{text}\n"
            "Return JSON format:\n"
            "{\n"
            '"name": "product name",\n'
            '"description": "product description",\n'
            '"price": numeric_price,\n'
            '"attributes": { "key": "value" }\n'
            "}"
        )

    def _parse_response(self, response: str) -> Optional["ProductSpec"]:
        """Parse the span from the first ``{`` to the last ``}`` of ``response``.

        Returns a ProductSpec, or ``None`` (after logging a warning) when the
        response is not a string or contains no decodable JSON object. A
        ``null`` name or attributes value is normalized to ``''`` / ``{}``.
        """
        if not isinstance(response, str):
            logger.warning("Parse error: expected str response, got %s", type(response).__name__)
            return None
        json_start = response.find('{')
        json_end = response.rfind('}') + 1
        if json_start == -1 or json_end <= json_start:
            logger.warning("Parse error: no JSON object in response")
            return None
        try:
            # A slice that starts with '{' and ends with '}' can only decode to a dict.
            data = json.loads(response[json_start:json_end])
        except json.JSONDecodeError as e:
            logger.warning("Parse error: %s", e)
            return None
        return ProductSpec(
            name=data.get('name') or '',
            description=data.get('description'),
            price=data.get('price'),
            attributes=data.get('attributes') or {},
        )
def process_pdf_catalog(pdf_path: str):
    """Process the catalog at ``pdf_path`` and return ``(result, status_message)``.

    On success returns ``(result, "Processing completed successfully!")``,
    where ``result`` is ``PDFProcessor.process_pdf``'s dict. Any ``Exception``,
    including one raised while constructing ``PDFProcessor`` (model loading),
    is logged with its traceback and turned into
    ``({}, "Error processing PDF")`` instead of propagating.
    """
    # NOTE(review): a new PDFProcessor (and so both models) is built on every
    # call; consider reusing one instance if this is called repeatedly.
    try:
        processor = PDFProcessor()
        result = processor.process_pdf(pdf_path)
    except Exception as e:
        logger.exception("Processing failed: %s", e)
        return {}, "Error processing PDF"
    return result, "Processing completed successfully!"