Commit 2d0ef3b · Parent(s): aa47fca
feat: update classifier model to local zero-shot NLI and enhance language detection with local library
Files changed:
- .env.example +2 -4
- README.md +4 -2
- app/core/config.py +2 -4
- app/pipelines/classification_pipeline.py +3 -1
- app/services/classifier_service.py +61 -20
- app/services/extraction_service.py +5 -4
- app/services/language_service.py +13 -29
- docker-compose.yml +2 -4
- docs/explanation/architecture.md +7 -7
- docs/explanation/decisions.md +5 -6
- docs/how-to/deploy-with-docker-compose.md +1 -1
- docs/how-to/run-locally.md +3 -4
- docs/reference/api.md +1 -1
- docs/reference/configuration.md +3 -4
- docs/reference/runtime-state.md +2 -2
- docs/tutorials/getting-started.md +1 -1
- requirements.txt +1 -1
- tests/test_classification_pipeline_behavior.py +37 -0
- tests/test_classifier_service.py +94 -0
- tests/test_language_service.py +34 -0
.env.example
CHANGED
@@ -5,10 +5,8 @@ DEBUG=false
 STATIC_DIR=static
 UPLOAD_SUBDIR=uploads
 
-CLASSIFIER_MODEL=AyoubChLin/
+CLASSIFIER_MODEL=AyoubChLin/bert-base-uncased-zeroshot-nli
+ENABLE_MODEL_QUANTIZATION=true
 HUGGINGFACE_TOKEN=
 
-LANGUAGE_DETECTOR_URL=https://team-language-detector-languagedetector.hf.space/run/predict
-REQUEST_TIMEOUT_SECONDS=30
-
 DEFAULT_LABELS_CSV=news,sport,finance,politics
README.md
CHANGED
@@ -40,8 +40,8 @@ cp .env.example .env
 
 Key vars:
 - `CLASSIFIER_MODEL`
+- `ENABLE_MODEL_QUANTIZATION`
 - `HUGGINGFACE_TOKEN`
-- `LANGUAGE_DETECTOR_URL`
 - `DEFAULT_LABELS_CSV`
 
 ## Local Run
@@ -63,4 +63,6 @@ pytest -q
 ## Notes
 - OCR requires `tesseract-ocr` (installed in Dockerfile).
 - Supported extraction formats in this refactor: `.pdf`, `.docx`, `.xlsx`, image formats, and plain text files.
-- The classifier model is loaded directly from Hugging Face Hub
+- The classifier model is loaded directly from Hugging Face Hub and runs true zero-shot classification over runtime labels.
+- Language detection runs locally via `langdetect` (no remote language endpoint dependency).
+- `/classify` uses only the first PDF page for classification; `/api/transformer` still extracts full content.
app/core/config.py
CHANGED
@@ -15,12 +15,10 @@ class Settings(BaseSettings):
     static_dir: Path = Path("static")
     upload_subdir: str = "uploads"
 
-    classifier_model: str = "AyoubChLin/
+    classifier_model: str = "AyoubChLin/bert-base-uncased-zeroshot-nli"
+    enable_model_quantization: bool = True
     huggingface_token: str | None = None
 
-    language_detector_url: str = "https://team-language-detector-languagedetector.hf.space/run/predict"
-    request_timeout_seconds: float = 30.0
-
    default_labels_csv: str = Field(default="news,sport,finance,politics")
 
     @property
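The new `enable_model_quantization` flag gates PyTorch dynamic INT8 quantization of the model's `Linear` layers (applied in `classifier_service.py` below). A minimal standalone sketch of the underlying call, using a toy module rather than the app's model:

```python
# Toy illustration of dynamic INT8 quantization; the Sequential module is a
# stand-in for the real classifier, not the app's code.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
).eval()

quantized = torch.ao.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # only Linear layers are converted
    dtype=torch.qint8,  # weights stored as INT8; activations quantized on the fly
)
print(quantized)  # Linear layers are replaced by dynamic quantized equivalents
```

Dynamic quantization shrinks weights and can speed up CPU inference, which is why the service applies it only after `model.to("cpu")` and falls back (with a logged warning) if the backend is unavailable.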
app/pipelines/classification_pipeline.py
CHANGED
@@ -25,7 +25,9 @@ class ClassificationPipeline:
         return text
 
     def classify_file(self, original_filename: str, file_path: Path) -> dict:
-        text =
+        text = extraction_service.extract_text(original_filename, file_path, pdf_first_page_only=True)
+        if not text or not text.strip():
+            raise ExtractionError("No text extracted from file")
         preprocessed_text = preprocess_text(text)
 
         language = language_service.detect_language(preprocessed_text)
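A quick illustration of the new flag's effect, assuming the repo's package is importable and a multi-page `sample.pdf` exists locally (both are assumptions for this sketch):

```python
from pathlib import Path

from app.services.extraction_service import extraction_service

# /classify path: only the first PDF page feeds preprocessing and classification.
first_page = extraction_service.extract_text(
    "sample.pdf", Path("sample.pdf"), pdf_first_page_only=True
)

# /api/transformer path: full-document extraction is unchanged (flag defaults to False).
full_document = extraction_service.extract_text("sample.pdf", Path("sample.pdf"))
```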
app/services/classifier_service.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import Any
 
 import torch
@@ -6,8 +7,12 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from app.core.config import settings
 from app.core.exceptions import ClassificationError
 
+logger = logging.getLogger(__name__)
+
 
 class ClassifierService:
+    _HYPOTHESIS_TEMPLATE = "This text is about {}."
+
     def __init__(self) -> None:
         self._tokenizer: Any | None = None
         self._model: Any | None = None
@@ -15,21 +20,34 @@ class ClassifierService:
     def _load_model(self) -> tuple[Any, Any]:
         if self._tokenizer is None or self._model is None:
             try:
-                tokenizer = AutoTokenizer.from_pretrained(
-
-
+                tokenizer = AutoTokenizer.from_pretrained(
+                    settings.classifier_model,
+                    token=settings.huggingface_token,
+                )
+                model = AutoModelForSequenceClassification.from_pretrained(
+                    settings.classifier_model,
+                    token=settings.huggingface_token,
+                )
                 model.eval()
                 model.to("cpu")
 
-
-
-
-
-
-
+                if settings.enable_model_quantization:
+                    try:
+                        # Dynamic INT8 quantization for CPU inference.
+                        quantized_model = torch.ao.quantization.quantize_dynamic(
+                            model,
+                            {torch.nn.Linear},
+                            dtype=torch.qint8,
+                        )
+                        model = quantized_model
+                    except Exception:
+                        logger.warning(
+                            "Model quantization failed; using non-quantized model instead.",
+                            exc_info=True,
+                        )
 
                 self._tokenizer = tokenizer
-                self._model =
+                self._model = model
             except Exception as exc:
                 raise ClassificationError("Unable to initialize classifier model") from exc
 
@@ -38,31 +56,54 @@ class ClassifierService:
     def warmup(self) -> None:
         self._load_model()
 
+    @staticmethod
+    def _normalize_labels(labels: list[str]) -> list[str]:
+        cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
+        return list(dict.fromkeys(cleaned))
+
+    @staticmethod
+    def _resolve_entailment_id(model: Any) -> int:
+        label2id = getattr(model.config, "label2id", {}) or {}
+        for label, label_id in label2id.items():
+            if isinstance(label, str) and label.lower().startswith("entail"):
+                return int(label_id)
+
+        id2label = getattr(model.config, "id2label", {}) or {}
+        for label_id, label in id2label.items():
+            if isinstance(label, str) and label.lower().startswith("entail"):
+                return int(label_id)
+
+        raise ClassificationError("Classifier model is missing an entailment label mapping")
+
     def classify(self, text: str, labels: list[str]) -> str:
-        if not labels:
+        candidate_labels = self._normalize_labels(labels)
+        if not candidate_labels:
             raise ClassificationError("No labels configured")
 
         tokenizer, model = self._load_model()
+        entailment_id = self._resolve_entailment_id(model)
 
         try:
+            sequence_pairs = [[text, self._HYPOTHESIS_TEMPLATE.format(label)] for label in candidate_labels]
             inputs = tokenizer(
-
+                sequence_pairs,
                 padding=True,
-                truncation=
+                truncation="only_first",
                 return_tensors="pt",
             )
 
             with torch.no_grad():
                 logits = model(**inputs).logits
 
-
-
-            if
-
+            if logits.ndim != 2:
+                raise ClassificationError("Classifier returned unexpected logits shape")
+            if entailment_id < 0 or entailment_id >= logits.shape[-1]:
+                raise ClassificationError("Entailment label index is out of range for classifier output")
+
+            entailment_logits = logits[:, entailment_id]
+            best_index = int(torch.argmax(entailment_logits).item())
+            return candidate_labels[best_index]
         except Exception as exc:
             raise ClassificationError("Classifier prediction failed") from exc
 
-        raise ClassificationError("Classifier did not return a valid label")
-
-
 classifier_service = ClassifierService()
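For readers unfamiliar with the NLI approach used above: each candidate label becomes a hypothesis ("This text is about {label}.") paired with the input text as premise, and the label whose pair earns the strongest entailment logit wins. A minimal standalone sketch of that scoring loop; the `ENTAILMENT` key is an assumption here (it varies across NLI checkpoints, which is exactly why the service resolves it via `_resolve_entailment_id`):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "AyoubChLin/bert-base-uncased-zeroshot-nli"  # model from the diff above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

text = "The central bank raised interest rates again this quarter."
labels = ["news", "sport", "finance", "politics"]

# One (premise, hypothesis) pair per candidate label.
pairs = [[text, f"This text is about {label}."] for label in labels]
inputs = tokenizer(pairs, padding=True, truncation="only_first", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (len(labels), num_nli_classes)

entailment_id = model.config.label2id["ENTAILMENT"]  # assumed key name for this checkpoint
print(labels[int(torch.argmax(logits[:, entailment_id]))])  # e.g. "finance"
```

`truncation="only_first"` matters: when a pair exceeds the model's maximum length, only the document (premise) is truncated, never the short hypothesis.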
app/services/extraction_service.py
CHANGED
@@ -16,10 +16,11 @@ TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".json"}
 
 class ExtractionService:
     @staticmethod
-    def _extract_pdf(file_path: Path) -> str:
+    def _extract_pdf(file_path: Path, first_page_only: bool = False) -> str:
         reader = PdfReader(str(file_path))
         chunks: list[str] = []
-        for page in reader.pages:
+        pages = reader.pages[:1] if first_page_only else reader.pages
+        for page in pages:
             text = page.extract_text() or ""
             if text.strip():
                 chunks.append(text)
@@ -41,13 +42,13 @@
         workbook.close()
         return "\n".join(chunks)
 
-    def extract_text(self, file_name: str, file_path: Path) -> str:
+    def extract_text(self, file_name: str, file_path: Path, pdf_first_page_only: bool = False) -> str:
         extension = Path(file_name).suffix.lower()
 
         try:
             if extension in DOC_EXTENSIONS:
                 if extension == ".pdf":
-                    return self._extract_pdf(file_path)
+                    return self._extract_pdf(file_path, first_page_only=pdf_first_page_only)
                 if extension == ".docx":
                     return self._extract_docx(file_path)
                 if extension == ".xlsx":
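The first-page restriction is just a slice over `reader.pages`. A standalone sketch, assuming the `pypdf` package (a plausible source of the `PdfReader` used above) and a local `sample.pdf`:

```python
from pypdf import PdfReader  # assumed import; the service may use an equivalent package

reader = PdfReader("sample.pdf")
pages = reader.pages[:1]  # first_page_only=True; use reader.pages for the full document

chunks = []
for page in pages:
    text = page.extract_text() or ""  # extract_text() can return None for empty pages
    if text.strip():
        chunks.append(text)

print("\n".join(chunks)[:200])
```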
app/services/language_service.py
CHANGED
@@ -1,41 +1,25 @@
-import requests
+from langdetect import DetectorFactory, LangDetectException, detect
 
-from app.core.config import settings
 from app.core.exceptions import LanguageDetectionError
 
+# Ensure deterministic language detection outcomes across runs.
+DetectorFactory.seed = 0
 
+
-class LanguageService:
-    def __init__(self) -> None:
-        self._session = requests.Session()
-
+class LanguageService:
     def detect_language(self, text: str) -> str:
         try:
-
-
-
-
-            )
-
-
-
-            raise LanguageDetectionError("Language detection request failed") from exc
-        except ValueError as exc:
-            raise LanguageDetectionError("Language detector returned invalid JSON") from exc
-
-        data = payload.get("data") if isinstance(payload, dict) else None
-        if not data or not isinstance(data, list):
-            raise LanguageDetectionError("Language detector response missing 'data' field")
-
-        first = data[0]
-        if isinstance(first, dict):
-            label = first.get("label")
-        else:
-            label = first
-
-        if not isinstance(label, str) or not label.strip():
+            language = detect(text)
+        except LangDetectException as exc:
+            raise LanguageDetectionError("Language detection failed") from exc
+        except Exception as exc:
+            raise LanguageDetectionError("Language detector raised an unexpected error") from exc
+
+        normalized_language = language.split("-", 1)[0].strip().lower() if isinstance(language, str) else ""
+        if not normalized_language:
             raise LanguageDetectionError("Language detector did not return a valid label")
 
-        return
+        return normalized_language
 
 
 language_service = LanguageService()
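`langdetect` is probabilistic under the hood, which is why the module pins `DetectorFactory.seed = 0` for repeatable results. A quick sketch of the raw behavior the service wraps and normalizes:

```python
from langdetect import DetectorFactory, detect

DetectorFactory.seed = 0  # detection samples internally; seeding makes output deterministic

print(detect("This is a detailed English sentence about financial markets."))  # 'en'
print(detect("Ceci est une phrase française sur les marchés financiers."))  # 'fr'

# langdetect can emit region-tagged codes such as 'zh-cn'; the service keeps only
# the primary subtag, mirroring language.split("-", 1)[0].strip().lower():
print("zh-cn".split("-", 1)[0].strip().lower())  # 'zh'
```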
docker-compose.yml
CHANGED
@@ -10,11 +10,9 @@ services:
       DEBUG: ${DEBUG:-false}
       STATIC_DIR: ${STATIC_DIR:-static}
       UPLOAD_SUBDIR: ${UPLOAD_SUBDIR:-uploads}
-
-
+      CLASSIFIER_MODEL: ${CLASSIFIER_MODEL:-AyoubChLin/bert-base-uncased-zeroshot-nli}
+      ENABLE_MODEL_QUANTIZATION: ${ENABLE_MODEL_QUANTIZATION:-true}
       HUGGINGFACE_TOKEN: ${HUGGINGFACE_TOKEN:-}
-      LANGUAGE_DETECTOR_URL: ${LANGUAGE_DETECTOR_URL:-https://team-language-detector-languagedetector.hf.space/run/predict}
-      REQUEST_TIMEOUT_SECONDS: ${REQUEST_TIMEOUT_SECONDS:-30}
       DEFAULT_LABELS_CSV: ${DEFAULT_LABELS_CSV:-news,sport,finance,politics}
     ports:
       - "7860:7860"
docs/explanation/architecture.md
CHANGED
@@ -1,7 +1,7 @@
 # Architecture Explanation
 
 ## 1. Executive summary
-`classifier-general` is a single FastAPI service that classifies text and files
+`classifier-general` is a single FastAPI service that classifies text and files with local extraction/preprocessing, a local Hugging Face zero-shot NLI model, and local language detection.
 
 Evidence:
 - `app/main.py`
@@ -40,8 +40,7 @@
 ### Context view
 Actors/systems:
 - API client sending text/files.
-
-- External language detector endpoint (`LANGUAGE_DETECTOR_URL`).
+- Hugging Face model hub (model download/auth when needed).
 - Local filesystem for uploaded files.
 
 Evidence:
@@ -67,9 +66,10 @@
 1. `POST /classify` receives file.
 2. File saved to upload directory.
 3. Text extracted by extension-specific handlers.
+   - For `/classify`, PDF extraction is first-page only.
 4. Text preprocessed (regex cleanup + min words).
-5.
-6.
+5. Local language detector called.
+6. Zero-shot NLI classifier scores runtime labels and selects top label.
 7. Response returns `{label, language}` plus `type=not english` when applicable.
 
 Evidence:
@@ -82,7 +82,7 @@
 ## 4. Cross-cutting concerns
 ### Validation and error mapping
 - Input schemas use strict `extra=forbid`.
-- Error mapping explicitly separates validation/extraction (400) from
+- Error mapping explicitly separates validation/extraction (400) from classifier/language inference failures (502).
 
 Evidence:
 - `app/schemas/classification.py`
@@ -112,7 +112,7 @@
 - `tests/test_routes.py`
 
 ## 5. Risks, gaps, and technical debt
-
+- Local model initialization can fail if model/token/resources are invalid.
 - No upload retention/cleanup process.
 - Readiness check does not probe external AI services, only local label readiness.
 - No authentication/authorization layer on API endpoints.
docs/explanation/decisions.md
CHANGED
@@ -21,23 +21,22 @@
 - Rationale:
   - Keep clients functional while refactoring internals.
 
-## ADR-003: Use
+## ADR-003: Use local Hugging Face zero-shot NLI model for classification
 - Status: Accepted
 - Type: Explicit
 - Evidence:
   - `app/core/config.py`
   - `app/services/classifier_service.py`
 - Rationale:
-
+  - Perform true runtime-label zero-shot classification with local inference control.
 
-## ADR-004: Use
+## ADR-004: Use local `langdetect` library for language detection
 - Status: Accepted
 - Type: Explicit
 - Evidence:
   - `app/services/language_service.py`
-  - `app/core/config.py`
 - Rationale:
-
+  - Remove external dependency and keep language inference local.
 
 ## ADR-005: Keep labels in in-memory mutable config
 - Status: Accepted (current), Needs review
@@ -68,7 +67,7 @@
 - `app/routers/classification.py`
 - `app/core/exceptions.py`
 - Rationale:
-  - Differentiate local validation issues (`400`) from
+  - Differentiate local validation issues (`400`) from inference failures (`502`).
 
 ## ADR-008: No built-in auth layer for this API
 - Status: Accepted (current), Needs review
docs/how-to/deploy-with-docker-compose.md
CHANGED
@@ -22,7 +22,7 @@ curl -s http://localhost:4002/health/liveness
 
 ## Production hardening gaps
 - No reverse proxy/TLS config in this repo.
-
+- Initial model pull can require network access if the HF cache is cold.
 - No horizontal scaling coordination for in-memory labels (`/configlabel` mutates process-local state).
 
 Evidence:
docs/how-to/run-locally.md
CHANGED
@@ -17,9 +17,8 @@ cp .env.example .env
 ```
 
 Critical settings:
-- `
-- `
-- `LANGUAGE_DETECTOR_URL`
+- `CLASSIFIER_MODEL`
+- `ENABLE_MODEL_QUANTIZATION`
 - `DEFAULT_LABELS_CSV`
 
 Evidence:
@@ -54,7 +53,7 @@
 - `400 Text must contain at least 4 words`:
   - input failed preprocessing minimum-word rule.
 - `502 Classifier request failed`:
-
+  - local model load or prediction failed (model ID/token/resource issue).
 - OCR extraction quality is low:
   - verify tesseract install and image quality.
docs/reference/api.md
CHANGED
@@ -23,7 +23,7 @@
 
 ## Validation and errors
 - `400` for input validation and extraction problems.
-- `502` for
+- `502` for classifier/language inference failures.
 - `500` for unexpected failures.
 
 Evidence:
docs/reference/configuration.md
CHANGED
@@ -26,16 +26,15 @@ Evidence:
 
 | Variable | Default | Purpose |
 |---|---|---|
-| `
-| `
+| `CLASSIFIER_MODEL` | `AyoubChLin/bert-base-uncased-zeroshot-nli` | Hugging Face model ID used for local zero-shot NLI classification |
+| `ENABLE_MODEL_QUANTIZATION` | `true` | enable dynamic INT8 quantization with automatic fallback |
 | `HUGGINGFACE_TOKEN` | empty | optional auth token for client init |
 
 ## Language detector settings
 
 | Variable | Default | Purpose |
 |---|---|---|
-
-| `REQUEST_TIMEOUT_SECONDS` | `30` | HTTP timeout for language requests |
+| none | n/a | language detection now uses local `langdetect` |
 
 ## Label settings
docs/reference/runtime-state.md
CHANGED
@@ -33,8 +33,8 @@ Evidence:
 
 | Dependency | Usage |
 |---|---|
-
-
+| Hugging Face model hub | model download/auth for local classifier initialization |
+| `langdetect` library | local language inference |
 | Tesseract binary | OCR extraction for images |
 
 Evidence:
docs/tutorials/getting-started.md
CHANGED
@@ -4,7 +4,7 @@ This tutorial runs the classifier API and validates endpoint contracts.
 
 ## Prerequisites
 - Docker and Docker Compose
-- Internet access for
+- Internet access for initial Hugging Face model download (unless model is already cached)
 
 Evidence:
 - `docker-compose.yml`
requirements.txt
CHANGED
@@ -2,8 +2,8 @@ fastapi==0.115.8
 uvicorn[standard]==0.34.0
 pydantic==2.10.6
 pydantic-settings==2.7.1
-requests==2.32.3
 python-multipart==0.0.20
+langdetect==1.0.9
 
 transformers==4.46.0
 torch==2.5.1
tests/test_classification_pipeline_behavior.py
ADDED
@@ -0,0 +1,37 @@
+from pathlib import Path
+
+from app.pipelines.classification_pipeline import classification_pipeline
+import app.pipelines.classification_pipeline as pipeline_module
+
+
+def test_classify_file_uses_pdf_first_page_only(monkeypatch):
+    extraction_flags: list[bool] = []
+
+    def _fake_extract_text(file_name, file_path, pdf_first_page_only=False):
+        extraction_flags.append(pdf_first_page_only)
+        return "This is enough content for preprocessing and classification."
+
+    monkeypatch.setattr(pipeline_module.extraction_service, "extract_text", _fake_extract_text)
+    monkeypatch.setattr(pipeline_module.language_service, "detect_language", lambda text: "en")
+    monkeypatch.setattr(pipeline_module.label_service, "get_labels", lambda: ["news", "sport"])
+    monkeypatch.setattr(pipeline_module.classifier_service, "classify", lambda text, labels: "news")
+
+    result = classification_pipeline.classify_file("sample.pdf", Path("sample.pdf"))
+
+    assert extraction_flags == [True]
+    assert result == {"label": "news", "language": "en"}
+
+
+def test_transform_file_uses_full_extraction(monkeypatch):
+    extraction_flags: list[bool] = []
+
+    def _fake_extract_text(file_name, file_path, pdf_first_page_only=False):
+        extraction_flags.append(pdf_first_page_only)
+        return "This is full extracted content."
+
+    monkeypatch.setattr(pipeline_module.extraction_service, "extract_text", _fake_extract_text)
+
+    content = classification_pipeline.transform_file("sample.pdf", Path("sample.pdf"))
+
+    assert extraction_flags == [False]
+    assert content == "This is full extracted content."
tests/test_classifier_service.py
ADDED
@@ -0,0 +1,94 @@
+from types import SimpleNamespace
+
+import torch
+
+import app.services.classifier_service as classifier_module
+
+
+class _FakeTokenizer:
+    def __call__(self, sequence_pairs, padding, truncation, return_tensors):
+        batch_size = len(sequence_pairs)
+        return {
+            "input_ids": torch.ones((batch_size, 2), dtype=torch.long),
+            "attention_mask": torch.ones((batch_size, 2), dtype=torch.long),
+        }
+
+
+class _FakeInferenceModel:
+    def __init__(self, logits: torch.Tensor) -> None:
+        self._logits = logits
+        self.config = SimpleNamespace(
+            label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
+            id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
+        )
+
+    def __call__(self, **kwargs):
+        return SimpleNamespace(logits=self._logits)
+
+
+class _FakeLoadModel:
+    def __init__(self) -> None:
+        self.config = SimpleNamespace(
+            label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
+            id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
+        )
+
+    def eval(self):
+        return self
+
+    def to(self, device):
+        return self
+
+
+def test_classify_uses_runtime_candidate_labels(monkeypatch):
+    service = classifier_module.ClassifierService()
+    tokenizer = _FakeTokenizer()
+    model = _FakeInferenceModel(
+        logits=torch.tensor(
+            [
+                [3.2, 0.4],  # finance -> low entailment
+                [0.3, 4.1],  # sport -> highest entailment
+                [1.5, 1.9],  # politics -> second-best entailment
+            ]
+        )
+    )
+
+    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
+
+    predicted = service.classify(
+        "This article discusses the latest football transfer strategy.",
+        ["finance", "sport", "politics"],
+    )
+
+    assert predicted == "sport"
+
+
+def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
+    service = classifier_module.ClassifierService()
+    fake_model = _FakeLoadModel()
+    fake_tokenizer = object()
+
+    monkeypatch.setattr(
+        classifier_module.AutoTokenizer,
+        "from_pretrained",
+        lambda *args, **kwargs: fake_tokenizer,
+    )
+    monkeypatch.setattr(
+        classifier_module.AutoModelForSequenceClassification,
+        "from_pretrained",
+        lambda *args, **kwargs: fake_model,
+    )
+    monkeypatch.setattr(classifier_module.settings, "enable_model_quantization", True)
+
+    def _raise_quantization_error(*args, **kwargs):
+        raise RuntimeError("quantization backend unavailable")
+
+    monkeypatch.setattr(
+        classifier_module.torch.ao.quantization,
+        "quantize_dynamic",
+        _raise_quantization_error,
+    )
+
+    _, loaded_model = service._load_model()
+
+    assert loaded_model is fake_model
tests/test_language_service.py
ADDED
@@ -0,0 +1,34 @@
+import pytest
+
+from app.core.exceptions import LanguageDetectionError
+import app.services.language_service as language_module
+
+
+def test_detect_language_returns_en_for_english_and_non_en_for_french():
+    service = language_module.LanguageService()
+
+    english_text = "This is a detailed English sentence about technology trends and financial markets."
+    french_text = "Ceci est une phrase francaise detaillee sur les tendances technologiques et les marches financiers."
+
+    assert service.detect_language(english_text) == "en"
+    assert service.detect_language(french_text) != "en"
+
+
+def test_detect_language_raises_for_invalid_detector_output(monkeypatch):
+    service = language_module.LanguageService()
+    monkeypatch.setattr(language_module, "detect", lambda text: "")
+
+    with pytest.raises(LanguageDetectionError, match="did not return a valid label"):
+        service.detect_language("This text is long enough for language detection.")
+
+
+def test_detect_language_wraps_unexpected_detector_errors(monkeypatch):
+    service = language_module.LanguageService()
+
+    def _raise_error(_: str):
+        raise RuntimeError("unexpected detector failure")
+
+    monkeypatch.setattr(language_module, "detect", _raise_error)
+
+    with pytest.raises(LanguageDetectionError, match="unexpected error"):
+        service.detect_language("This text is long enough for language detection.")