"""Text and PDF classification helpers built on the InLegalBERT classifier."""

from model.inlegalbert_loader import InLegalBERTClassifier
from utils.pdf_extractor import extract_pdf_text
from utils.text_cleaner import clean_text

# Single classifier instance, created at import time and shared by every call
# in this module.
classifier = InLegalBERTClassifier()


def filter_top_k(results, top_k=5, threshold=0.5):
    """Filter multi-label predictions by score threshold and keep top-k.

    Args:
        results: Iterable of mappings, each with a ``"score"`` entry.
        top_k: Maximum number of predictions to return. ``None`` means no
            limit; zero or a negative value yields an empty list.
        threshold: Minimum score (inclusive) a prediction needs to be kept.

    Returns:
        A new list of the surviving predictions, highest score first.
        Predictions with equal scores keep their original relative order.
    """
    kept = [r for r in results if r["score"] >= threshold]
    kept.sort(key=lambda r: r["score"], reverse=True)
    if top_k is None:
        return kept
    # Clamp to zero: a negative bound in a slice would drop items from the
    # *end* of the list rather than returning nothing.
    return kept[:max(top_k, 0)]


def classify_text(text: str, top_k: int = 5, threshold: float = 0.5) -> dict:
    """Classify a piece of text with the ILSI and ILDC predictors.

    The text is passed through ``clean_text`` before prediction.

    Args:
        text: Input text to classify.
        top_k: Maximum number of ILSI predictions to keep.
        threshold: Minimum score for an ILSI prediction to be kept.

    Returns:
        A dict with ``"ILSI"`` (predictions filtered by ``filter_top_k``) and
        ``"ILDC"`` (the predictor's output, returned unfiltered).
    """
    cleaned = clean_text(text)
    ilsi_raw = classifier.predict_ilsi(cleaned)
    ildc_raw = classifier.predict_ildc(cleaned)
    return {
        "ILSI": filter_top_k(ilsi_raw, top_k=top_k, threshold=threshold),
        "ILDC": ildc_raw,
    }


def classify_pdf(file_bytes: bytes, top_k: int = 5, threshold: float = 0.5) -> dict:
    """Classify a PDF at document level and run ISS prediction per page.

    Args:
        file_bytes: Raw bytes of the PDF, handed to ``extract_pdf_text``.
        top_k: Maximum number of predictions kept for ILSI and for each page's
            ISS result.
        threshold: Minimum score for an ILSI or ISS prediction to be kept.

    Returns:
        A dict with ``"ILSI"`` and ``"ILDC"`` computed on the space-joined
        text of all pages, plus ``"ISS"``: a dict mapping ``"page_<n>"``
        (1-based) to that page's filtered ISS predictions. Pages that contain
        only whitespace are skipped and have no entry.
    """
    # Materialise the pages: they are traversed twice below (join + per-page
    # loop), which would silently yield nothing the second time for a
    # one-shot iterator.
    pages = list(extract_pdf_text(file_bytes))

    # Delegate to classify_text so the document-level predictions get the
    # same clean_text preprocessing as the plain-text entry point.
    results = classify_text(" ".join(pages), top_k=top_k, threshold=threshold)
    results["ISS"] = {}

    for page_number, page in enumerate(pages, start=1):
        if not page.strip():
            continue
        # NOTE(review): ISS receives the raw page text, as before; confirm
        # whether predict_iss expects clean_text-processed input as well.
        iss_raw = classifier.predict_iss(page)
        results["ISS"][f"page_{page_number}"] = filter_top_k(
            iss_raw, top_k=top_k, threshold=threshold
        )

    return results