diff --git a/README.md b/README.md index 449b8339c0f1d53e86c39d8dd173a7128b49d068..5e8b6e6c6483170e5b166a0cc72a2cd095adee6e 100644 --- a/README.md +++ b/README.md @@ -1,175 +1,67 @@ -# LexiMind: Multi-Task Transformer for Document Analysis +# LexiMind (Inference Edition) -A PyTorch-based multi-task learning system that performs abstractive summarization, emotion classification, and topic clustering on textual data using a shared Transformer encoder architecture. +LexiMind now ships as a focused inference sandbox for the custom multitask Transformer found in +`src/models`. Training, dataset downloaders, and legacy scripts have been removed so it is easy to +load a checkpoint, run the Streamlit demo, and experiment with summarization, emotion +classification, and topic cues on your own text. -## 🎯 Project Overview +## What Stays +- Transformer encoder/decoder and task heads under `src/models` +- Unit tests for the model stack (`tests/test_models`) +- Streamlit UI (`src/ui/streamlit_app.py`) wired to the inference helpers in `src/api/inference` -LexiMind demonstrates multi-task learning (MTL) by training a single model to simultaneously: -1. **Abstractive Summarization**: Generate concise summaries with user-defined compression levels -2. **Emotion Classification**: Detect multiple emotions present in text (multi-label classification) -3. **Topic Clustering**: Group documents by semantic similarity for topic discovery +## What Changed +- Hugging Face tokenizers provide all tokenization (see `TextPreprocessor`) +- Training, dataset downloaders, and CLI scripts have been removed +- Scikit-learn powers light text normalization (stop-word removal optional) +- Requirements trimmed to inference-only dependencies -### Key Features -- Custom encoder-decoder Transformer architecture with shared representations -- Multi-task loss function with learnable task weighting -- Attention weight visualization for model interpretability -- Interactive web interface for real-time inference -- Trained on diverse corpora: news articles (CNN/DailyMail, BBC) and literary texts (Project Gutenberg) - -## πŸ—οΈ Architecture - -``` -Input Text - ↓ -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ Shared Encoder β”‚ ← TransformerEncoder (6 layers) -β”‚ (Multi-head Attn) β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - ↓ ↓ ↓ - β”‚ β”‚ └──────────────┐ - β”‚ β”‚ β”‚ - β”‚ └─────────┐ β”‚ - β”‚ β”‚ β”‚ - ↓ ↓ ↓ -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ Decoder β”‚ β”‚Classifyβ”‚ β”‚ Project β”‚ -β”‚ Head β”‚ β”‚ Head β”‚ β”‚ Head β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - ↓ ↓ ↓ -Summary Emotions Embeddings - (for clustering) -``` - -## πŸ“Š Datasets - -- **CNN/DailyMail**: 300k+ news articles with human-written summaries -- **BBC News**: 2,225 articles across 5 categories -- **Project Gutenberg**: Classic literature for long-form text analysis - -## πŸš€ Quick Start - -### Installation +## Quick Start ```bash git clone https://github.com/OliverPerrin/LexiMind.git cd LexiMind pip install -r requirements.txt -``` -### Download Data -```bash -python src/download_datasets.py -``` +# Optional extras via setup.py packaging metadata +pip install .[web] # installs streamlit + plotly +pip install .[api] # installs fastapi +pip install .[all] # installs both groups -### Train Model -```bash -python src/train.py --config configs/default.yaml -``` - -### 
Launch Interface -```bash -python src/app.py +streamlit run src/ui/streamlit_app.py ``` -## πŸ“ Project Structure +Configure the Streamlit app via the sidebar to point at your tokenizer directory and model +checkpoint (defaults assume `artifacts/hf_tokenizer` and `checkpoints/best.pt`). +## Minimal Project Map ``` -LexiMind/ -β”œβ”€β”€ src/ -β”‚ β”œβ”€β”€ models/ -β”‚ β”‚ β”œβ”€β”€ encoder.py # Shared Transformer encoder -β”‚ β”‚ β”œβ”€β”€ summarization.py # Seq2seq decoder head -β”‚ β”‚ β”œβ”€β”€ emotion.py # Multi-label classification head -β”‚ β”‚ └── clustering.py # Projection head for embeddings -β”‚ β”œβ”€β”€ data/ -β”‚ β”‚ β”œβ”€β”€ download_datasets.py # Data acquisition -β”‚ β”‚ β”œβ”€β”€ preprocessing.py # Text cleaning & tokenization -β”‚ β”‚ └── dataset.py # PyTorch Dataset classes -β”‚ β”œβ”€β”€ training/ -β”‚ β”‚ β”œβ”€β”€ train.py # Training loop -β”‚ β”‚ β”œβ”€β”€ losses.py # Multi-task loss functions -β”‚ β”‚ └── metrics.py # ROUGE, F1, silhouette scores -β”‚ β”œβ”€β”€ inference/ -β”‚ β”‚ └── pipeline.py # End-to-end inference -β”‚ β”œβ”€β”€ visualization/ -β”‚ β”‚ └── attention.py # Attention heatmap generation -β”‚ └── app.py # Gradio/FastAPI interface -β”œβ”€β”€ configs/ -β”‚ └── default.yaml # Model & training hyperparameters -β”œβ”€β”€ tests/ -β”‚ └── test_*.py # Unit tests -β”œβ”€β”€ notebooks/ -β”‚ └── exploratory.ipynb # Data exploration & analysis -β”œβ”€β”€ requirements.txt -└── README.md +src/ +β”œβ”€β”€ api/ # load_models + helpers +β”œβ”€β”€ data/ # TextPreprocessor using Hugging Face + sklearn +β”œβ”€β”€ inference/ # thin summarizer facade +β”œβ”€β”€ models/ # core Transformer architecture (untouched) +└── ui/ # Streamlit interface ``` -## πŸ§ͺ Evaluation Metrics +Everything outside `src/` now holds optional assets such as checkpoints, tokenizer exports, and +documentation stubs. -| Task | Metric | Score | -|------|--------|-------| -| Summarization | ROUGE-1 / ROUGE-L | TBD | -| Emotion Classification | Macro F1 | TBD | -| Topic Clustering | Silhouette Score | TBD | +## Loading a Checkpoint Programmatically +```python +from src.api.inference import load_models, summarize_text -## πŸ”¬ Technical Details +models = load_models({ + "checkpoint_path": "checkpoints/best.pt", + "tokenizer_path": "artifacts/hf_tokenizer", + "hf_tokenizer_name": "facebook/bart-base", +}) -### Model Specifications -- **Encoder**: 6-layer Transformer (d_model=512, 8 attention heads) -- **Decoder**: 6-layer autoregressive Transformer -- **Vocab Size**: 32,000 (SentencePiece tokenizer) -- **Parameters**: ~60M total - -### Training -- **Optimizer**: AdamW (lr=1e-4, weight_decay=0.01) -- **Scheduler**: Linear warmup (5000 steps) + cosine decay -- **Loss**: Weighted sum of cross-entropy (summarization), BCE (emotions), triplet loss (clustering) -- **Hardware**: Trained on single NVIDIA RTX 3090 (24GB VRAM) -- **Time**: ~48 hours for 10 epochs - -### Multi-Task Learning Strategy -Uses uncertainty weighting ([Kendall et al., 2018](https://arxiv.org/abs/1705.07115)) to automatically balance task losses: - -``` -L_total = Ξ£ (1/2σ²_i * L_i + log(Οƒ_i)) +summary, _ = summarize_text("Paste any article here.", models=models) +print(summary) ``` -where Οƒ_i are learnable parameters representing task uncertainty. 
- -## 🎨 Interface Preview - -The web interface provides: -- Text input with real-time token count -- Compression level slider (20%-80%) -- Side-by-side original/summary comparison -- Emotion probability bars with color coding -- Interactive attention heatmap (click tokens to highlight attention) -- Downloadable results (JSON/CSV) - -## πŸ“ˆ Future Enhancements - -- [ ] Add multilingual support (mBART) -- [ ] Implement beam search for better summaries -- [ ] Fine-tune on domain-specific corpora (medical, legal) -- [ ] Add semantic search across document embeddings -- [ ] Deploy as REST API with Docker -- [ ] Implement model distillation for mobile deployment - -## πŸ“š References - -- Vaswani et al. (2017) - [Attention Is All You Need](https://arxiv.org/abs/1706.03762) -- Lewis et al. (2019) - [BART: Denoising Sequence-to-Sequence Pre-training](https://arxiv.org/abs/1910.13461) -- Caruana (1997) - [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734) -- Demszky et al. (2020) - [GoEmotions Dataset](https://arxiv.org/abs/2005.00547) - -## πŸ“„ License - -GNU General Public License v3.0 - -## πŸ‘€ Author - -**Oliver Perrin** -- Portfolio: [oliverperrin.com](https://oliverperrin.com) -- LinkedIn: [linkedin.com/in/oliverperrin](https://linkedin.com/in/oliverperrin) -- Email: oliver.t.perrin@gmail.com +## License +GPL-3.0 ---- +## Author +Oliver Perrin Β· oliver.t.perrin@gmail.com diff --git a/configs/data/datasets.yaml b/configs/data/datasets.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..efe7b8377693ab85f58512fdacdd0d0f6ee26497 100644 --- a/configs/data/datasets.yaml +++ b/configs/data/datasets.yaml @@ -0,0 +1,26 @@ +raw: + summarization: data/raw/summarization/cnn_dailymail + emotion: data/raw/emotion + topic: data/raw/topic + books: data/raw/books +processed: + summarization: data/processed/summarization + emotion: data/processed/emotion + topic: data/processed/topic + books: data/processed/books +tokenizer: + pretrained_model_name: facebook/bart-base + max_length: 512 + lower: false +downloads: + summarization: + dataset: gowrishankarp/newspaper-text-summarization-cnn-dailymail + output: data/raw/summarization/cnn_dailymail + books: + - name: pride_and_prejudice + url: https://www.gutenberg.org/cache/epub/1342/pg1342.txt + output: data/raw/books/pride_and_prejudice.txt + emotion: + dataset: dair-ai/emotion + topic: + dataset: ag_news diff --git a/configs/model/base.yaml b/configs/model/base.yaml index 01594b2779e9fa7aeff6c30ea7b8f87c4192e542..d4a918e37ca9e0260d75054d907144b4cfb72291 100644 --- a/configs/model/base.yaml +++ b/configs/model/base.yaml @@ -1,50 +1,6 @@ -model: - vocab_size: 32000 - d_model: 512 - num_encoder_layers: 6 - num_decoder_layers: 6 - num_heads: 8 - d_ff: 2048 - dropout: 0.1 - max_seq_length: 512 - -tasks: - summarization: - enabled: true - decoder_layers: 6 - - emotion: - enabled: true - num_classes: 27 - pool_strategy: "mean" # Options: mean, max, cls, attention - - clustering: - enabled: true - embedding_dim: 128 - normalize: true - -training: - batch_size: 16 - gradient_accumulation_steps: 2 # Effective batch = 32 - learning_rate: 1e-4 - weight_decay: 0.01 - num_epochs: 10 - warmup_steps: 1000 - max_grad_norm: 1.0 - - scheduler: - type: "cosine" # Options: linear, cosine, polynomial - - mixed_precision: true # Use AMP for faster training - -data: - max_length: 512 - summary_max_length: 128 - train_split: 0.8 - val_split: 0.1 - test_split: 0.1 - - preprocessing: - lowercase: true - remove_stopwords: false - min_token_length: 3 \ No 
newline at end of file +d_model: 512 +num_encoder_layers: 6 +num_decoder_layers: 6 +num_attention_heads: 8 +ffn_dim: 2048 +dropout: 0.1 diff --git a/configs/model/large.yaml b/configs/model/large.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..16bb9741555dd3ae9953fe444eb96d48e8d30d25 100644 --- a/configs/model/large.yaml +++ b/configs/model/large.yaml @@ -0,0 +1,6 @@ +d_model: 768 +num_encoder_layers: 12 +num_decoder_layers: 12 +num_attention_heads: 12 +ffn_dim: 3072 +dropout: 0.1 diff --git a/configs/model/small.yaml b/configs/model/small.yaml index 453f39b5f110dc1aaa844ed9c75ec958b440060c..9a1517fae8b806878a85538740ab774d2d7736f8 100644 --- a/configs/model/small.yaml +++ b/configs/model/small.yaml @@ -1,23 +1,6 @@ -# configs/model/small.yaml (for fast iteration) -model: - d_model: 256 - num_encoder_layers: 4 - num_decoder_layers: 4 - num_heads: 8 - -training: - batch_size: 32 # ~4GB VRAM - gradient_accumulation_steps: 1 - mixed_precision: true # Essential! - -# configs/model/base.yaml (production) -model: - d_model: 512 - num_encoder_layers: 6 - num_decoder_layers: 6 - num_heads: 8 - -training: - batch_size: 8 # ~8GB VRAM - gradient_accumulation_steps: 4 # Effective batch = 32 - mixed_precision: true \ No newline at end of file +d_model: 256 +num_encoder_layers: 4 +num_decoder_layers: 4 +num_attention_heads: 4 +ffn_dim: 1024 +dropout: 0.1 diff --git a/configs/training/default.yaml b/configs/training/default.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c13b8f964e41ef531ad0926cbe65f57cef1aab6e 100644 --- a/configs/training/default.yaml +++ b/configs/training/default.yaml @@ -0,0 +1,12 @@ +dataloader: + batch_size: 8 + shuffle: true +optimizer: + name: adamw + lr: 3.0e-5 +scheduler: + name: cosine + warmup_steps: 500 +trainer: + max_epochs: 5 + gradient_clip_norm: 1.0 diff --git a/configs/training/full.yaml b/configs/training/full.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..66433d65bb872d56d9740a5330925647f5751721 100644 --- a/configs/training/full.yaml +++ b/configs/training/full.yaml @@ -0,0 +1,12 @@ +dataloader: + batch_size: 16 + shuffle: true +optimizer: + name: adamw + lr: 2.0e-5 +scheduler: + name: cosine + warmup_steps: 1000 +trainer: + max_epochs: 15 + gradient_clip_norm: 1.0 diff --git a/configs/training/quick_test.yaml b/configs/training/quick_test.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a57bc32fc1400862027ee482c4553750969001e9 100644 --- a/configs/training/quick_test.yaml +++ b/configs/training/quick_test.yaml @@ -0,0 +1,9 @@ +dataloader: + batch_size: 2 + shuffle: false +optimizer: + name: adamw + lr: 1.0e-4 +trainer: + max_epochs: 1 + gradient_clip_norm: 0.5 diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/docs/api.md b/docs/api.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a61820babc3870f37a1dcb103caec1e973e70a60 100644 --- a/docs/api.md +++ b/docs/api.md @@ -0,0 +1,79 @@ +# API & CLI Documentation + +## FastAPI Service +The FastAPI application is defined in `src/api/app.py` and wires routes from +`src/api/routes.py`. All dependencies resolve through `src/api/dependencies.py`, which lazily constructs the shared inference pipeline. 
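+
+The snippet below is an illustrative sketch of that lazy wiring rather than a copy of
+`src/api/dependencies.py`; it assumes the `create_inference_pipeline` signature used by
+`scripts/evaluate.py` and the default artifact paths.
+
+```python
+# Hypothetical provider: build the multitask pipeline once and reuse it for every request.
+from functools import lru_cache
+
+from fastapi import HTTPException
+
+from src.inference.factory import create_inference_pipeline
+
+
+@lru_cache(maxsize=1)
+def get_pipeline():
+    try:
+        pipeline, _metadata = create_inference_pipeline(
+            checkpoint_path="checkpoints/best.pt",
+            labels_path="artifacts/labels.json",
+            tokenizer_config=None,
+            model_config_path="configs/model/base.yaml",
+            device="cpu",
+        )
+    except FileNotFoundError as exc:  # missing checkpoint or label metadata
+        raise HTTPException(status_code=503, detail=str(exc)) from exc
+    return pipeline
+```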
+ +### POST `/summarize` +- **Request Body** (`SummaryRequest`): + ```json + { + "text": "Your input document" + } + ``` +- **Response** (`SummaryResponse`): + ```json + { + "summary": "Generated abstractive summary", + "emotion_labels": ["joy", "surprise"], + "emotion_scores": [0.91, 0.63], + "topic": "news", + "topic_confidence": 0.82 + } + ``` +- **Behaviour:** + 1. Text is preprocessed through `TextPreprocessor` (with optional sklearn transformer if configured). + 2. The multitask model generates a summary via greedy decoding. + 3. Emotion and topic heads produce logits which are converted to probabilities and mapped to + human-readable labels using `artifacts/labels.json`. + 4. Results are returned as structured JSON suitable for a future Gradio interface. + +### Error Handling +- If the checkpoint or label metadata is missing, the dependency raises an HTTP 503 error with + an explanatory message. +- Validation errors (missing `text`) are handled automatically by FastAPI/Pydantic. + +## Command-Line Interface +`scripts/inference.py` provides a CLI that mirrors the API behaviour. + +### Usage +```bash +python scripts/inference.py "Document to analyse" \ + --checkpoint checkpoints/best.pt \ + --labels artifacts/labels.json \ + --tokenizer artifacts/hf_tokenizer \ + --model-config configs/model/base.yaml \ + --device cpu +``` + +Options: +- `text` – zero or more positional arguments. If omitted, use `--file` to point to a newline + delimited text file. +- `--file` – optional path containing one text per line. +- `--checkpoint` – path to the trained model weights. +- `--labels` – JSON containing emotion/topic vocabularies (defaults to `artifacts/labels.json`). +- `--tokenizer` – optional tokenizer directory; defaults to the exported artifact if present. +- `--model-config` – YAML describing the architecture. +- `--device` – `cpu` or `cuda`. Passing `cuda` attempts to run inference on GPU. +- `--summary-max-length` – overrides the default maximum generation length. + +### Output +The CLI prints a JSON array where each entry contains the original text, summary, emotion labels +with scores, and topic prediction. This format is identical to the REST response, facilitating +integration tests and future Gradio UI rendering. + +## Future Gradio UI +- The planned UI will call the same inference pipeline and display results interactively. +- Given the response schema, the UI can show: + - Generated summary text. + - Emotion chips with probability bars. + - Topic confidence gauges. + - Placeholder panel for attention heatmaps and explanations. +- Once implemented, documentation updates will add a `docs/ui.md` section and screenshots under + `docs/images/`. + +## Testing +- `tests/test_api/test_routes.py` stubs the pipeline to ensure response fields and dependency + overrides behave as expected. +- `tests/test_inference/test_pipeline.py` validates pipeline methods end-to-end with dummy models, + guaranteeing API and CLI consumers receive consistent payload shapes. diff --git a/docs/architecture.md b/docs/architecture.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..77b39b8328a9e49af8e65d7dadb551128e91a9e8 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -0,0 +1,57 @@ +# LexiMind Architecture + +## Overview +LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers: + +1. 
**Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn + primitives and a Hugging Face tokenizer wrapper with deterministic batching helpers. +2. **Model Composition** – the bespoke encoder/decoder stack with task heads assembled via + `MultiTaskModel`, plus `models.factory.build_multitask_model` to rebuild the network from + configuration files. +3. **Inference & Serving** – a multi-task pipeline capable of summarization, emotion, and topic classification; surfaced through a CLI and FastAPI service with plans for a Gradio UI. + +## Custom Transformer Stack +- `src/models/encoder.py` and `src/models/decoder.py` implement Pre-LayerNorm Transformer + blocks with explicit positional encoding, masking logic, and incremental decoding support. +- `src/models/heads.py` provides modular output heads. Summarization uses an `LMHead` tied to + the decoder embedding weights; emotion and topic tasks use `ClassificationHead` instances. +- `src/models/multitask.py` routes inputs to the correct head, computes task-specific losses, + and exposes a single forward API used by the trainer and inference pipeline. +- `src/models/factory.py` rebuilds the encoder, decoder, and heads directly from YAML config + and tokenizer metadata so inference rebuilds the exact architecture used in training. + +## Data, Tokenization, and Preprocessing +- `src/data/tokenization.py` wraps `AutoTokenizer` to provide tensor-aware batching and helper + utilities for decoder input shifting, BOS/EOS resolution, and vocab size retrieval. +- `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with + optional scikit-learn transformers (via `sklearn_transformer`) before tokenization. This keeps + the default cleaning minimal while allowing future reuse of `sklearn.preprocessing` utilities + without changing calling code. +- `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and + collators that encode inputs with the shared tokenizer and set up task-specific labels (multi-label + emotions, categorical topics, seq2seq summaries). + +## Training Pipeline +- `src/training/trainer.py` coordinates multi-task optimization with per-task loss functions, gradient clipping, and shared tokenizer decoding for metric computation. +- Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and a ROUGE-like overlap score for summarization. These metrics mirror the trainer outputs logged per task. +- Label vocabularies are serialized to `artifacts/labels.json` after training so inference can decode class indices consistently. + +## Inference & Serving +- `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic. It expects label vocabularies from the serialized metadata file. +- `src/inference/factory.py` rebuilds the full pipeline by loading the tokenizer (preferring the exported tokenizer artifact), reconstructing the model via the factory helpers, restoring checkpoints, and injecting label metadata. +- The CLI (`scripts/inference.py`) drives the pipeline from the command line. The FastAPI app (`src/api/routes.py`) exposes the `/summarize` endpoint that returns summaries, emotion labels + scores, and topic predictions. Test coverage in `tests/test_inference` and `tests/test_api` validates both layers with lightweight stubs. + +## Gradio UI Roadmap +- The inference pipeline returns structured outputs that are already suitable for a web UI. 
+- Planned steps for a Gradio demo: + 1. Wrap `InferencePipeline.batch_predict` inside Gradio callbacks for text input. + 2. Display summaries alongside emotion tag chips and topic confidence bars. + 3. Surface token-level attention visualizations by extending the pipeline to emit decoder attention maps (hooks already exist in the decoder). +- Documentation and code paths were structured to keep the Gradio integration isolated in a future `src/ui/gradio_app.py` module without altering core logic. + +## Key Decisions +- **Custom Transformer Preservation** – all modeling remains on the bespoke encoder/decoder, satisfying the constraint to avoid Hugging Face model classes while still leveraging their tokenizer implementation. +- **Tokenizer Artifact Preference** – inference automatically favors the exported tokenizer in `artifacts/hf_tokenizer`, guaranteeing consistent vocabularies between training and serving. +- **Sklearn-friendly Preprocessing** – the text preprocessor now accepts an optional + `TransformerMixin` so additional normalization (lemmatization, custom token filters, etc.) can be injected using familiar scikit-learn tooling without rewriting the batching code. +- **Documentation Alignment** – the `docs/` folder mirrors the structure requested, capturing design reasoning and paving the way for future diagrams in `docs/images`. diff --git a/docs/training.md b/docs/training.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..d1bd7d643c235d0a11bc95e91b23f5494aed9c69 100644 --- a/docs/training.md +++ b/docs/training.md @@ -0,0 +1,59 @@ +# Training Procedure + +## Data Sources +- **Summarization** – expects JSONL files with `source` and `summary` fields under + `data/processed/summarization`. +- **Emotion Classification** – multi-label samples loaded from JSONL files with + `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding. +- **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`. + +Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`facebook/bart-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`). + +## Dataloaders & Collators +- `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation. +- `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`. +- `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`. + +These collators keep all tokenization centralized, reducing duplication and making it easy to swap in additional sklearn transformations through `TextPreprocessor` should we wish to extend cleaning or normalization. + +## Model Assembly +- `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments. +- The model wraps: + - Transformer encoder/decoder stacks with shared positional encodings. + - LM head tied to decoder embeddings for summarization. + - Mean-pooled classification heads for emotion and topic tasks. 
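+
+As a minimal assembly sketch (mirroring the factory calls in `scripts/train.py` and
+`scripts/export_model.py`; the label counts below are placeholders, the real values come from
+the dataset label vocabularies):
+
+```python
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.models.factory import build_multitask_model, load_model_config
+
+# Tokenizer settings mirror configs/data/datasets.yaml.
+tokenizer = Tokenizer(TokenizerConfig(pretrained_model_name="facebook/bart-base", max_length=512))
+
+model = build_multitask_model(
+    tokenizer,
+    num_emotions=6,  # placeholder: size of the emotion label vocabulary
+    num_topics=4,    # placeholder: size of the topic label vocabulary
+    config=load_model_config("configs/model/base.yaml"),
+)
+```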
+ +## Optimisation Loop +- `src/training/trainer.Trainer` orchestrates multi-task training. + - Cross-entropy is used for summarization (seq2seq logits vs. shifted labels). + - `BCEWithLogitsLoss` handles multi-label emotions. + - `CrossEntropyLoss` handles topic classification. +- Gradient clipping ensures stability, and per-task weights can be configured via + `TrainerConfig.task_weights` to balance gradients if needed. +- Metrics tracked per task: + - **Summarization** – ROUGE-like overlap metric (`training.metrics.rouge_like`). + - **Emotion** – micro F1 score for multi-label predictions. + - **Topic** – categorical accuracy. + +## Checkpoints & Artifacts +- `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`. +- `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after + training. This file is required for inference so class indices map back to human-readable labels. +- The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies. + +## Running Training +1. Ensure processed datasets are available (see `data/processed/` structure). +2. Choose a configuration (e.g., `configs/training/default.yaml`) for hyperparameters and data splits. +3. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders. +4. Use `build_multitask_model` to construct the model, create an optimizer, and run + `Trainer.fit(train_loaders, val_loaders)`. +5. Save checkpoints and update `artifacts/labels.json` with the dataset label order. + +> **Note:** A full CLI for training is forthcoming. The scripts in `scripts/` currently act as +> scaffolding; once the Gradio UI is introduced we will extend these utilities to launch +> training jobs with configuration files directly. + +## Future Enhancements +- Integrate curriculum scheduling or task-balanced sampling once empirical results dictate. +- Capture attention maps during training to support visualization in the planned Gradio UI. +- Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it. 
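+
+A toy example of that hook, assuming `TextPreprocessor` accepts the `sklearn_transformer`
+keyword described in `docs/architecture.md` (verify the exact constructor arguments in
+`src/data/preprocessing.py` before relying on it):
+
+```python
+from sklearn.base import BaseEstimator, TransformerMixin
+
+
+class WhitespaceNormalizer(BaseEstimator, TransformerMixin):
+    """Toy normaliser: lowercase text and collapse repeated whitespace."""
+
+    def fit(self, X, y=None):
+        return self
+
+    def transform(self, X):
+        return [" ".join(str(text).lower().split()) for text in X]
+
+
+# Assumed wiring; any other TextPreprocessor arguments are omitted here:
+# preprocessor = TextPreprocessor(sklearn_transformer=WhitespaceNormalizer())
+```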
diff --git a/pyproject.toml b/pyproject.toml index ad9d70799970730bb0a213d8f6f70f1541ea3884..d666c8bbe06ed032880be45e9b3480cacdd462dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,19 +13,14 @@ license = {text = "GPL-3.0"} dependencies = [ "torch>=2.0.0", - "transformers>=4.30.0", - "datasets>=2.14.0", - "tokenizers>=0.13.0", + "scikit-learn>=1.4.0", "numpy>=1.24.0", "pandas>=2.0.0", - "scikit-learn>=1.3.0", - "matplotlib>=3.7.0", - "seaborn>=0.12.0", - "tqdm>=4.65.0", - "pyyaml>=6.0", - "omegaconf>=2.3.0", - "tensorboard>=2.13.0", - "gradio>=3.35.0", + "streamlit>=1.25.0", + "plotly>=5.18.0", + "transformers>=4.40.0", + "fastapi>=0.110.0", + "datasets>=4.4.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 1916909c8da824870966b316334f9460c2dfd69a..9d67b5b7cce0d5cf4d478ef6f688c61e66edbee9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,12 @@ # requirements.txt torch>=2.0.0 -transformers>=4.30.0 -datasets>=2.14.0 -tokenizers>=0.13.0 +transformers>=4.40.0 +scikit-learn>=1.4.0 numpy>=1.24.0 pandas>=2.0.0 -scikit-learn>=1.3.0 -matplotlib>=3.7.0 -seaborn>=0.12.0 -nltk>=3.8.0 -tqdm>=4.65.0 -pyyaml>=6.0 -omegaconf>=2.3.0 -tensorboard>=2.13.0 -gradio>=3.35.0 -requests>=2.31.0 -kaggle>=1.5.12 streamlit>=1.25.0 plotly>=5.18.0 -faiss-cpu==1.9.0; platform_system != "Windows" -faiss-cpu==1.9.0; platform_system == "Windows" \ No newline at end of file +fastapi>=0.110.0 +datasets>=4.4.0 +pytest +matplotlib \ No newline at end of file diff --git a/scripts/download_data.py b/scripts/download_data.py new file mode 100644 index 0000000000000000000000000000000000000000..587e6c8bdb15a8034210df1ca78e3ed8cdc76d19 --- /dev/null +++ b/scripts/download_data.py @@ -0,0 +1,182 @@ +"""Download datasets used by LexiMind.""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Iterable, Iterator, cast + +from datasets import ClassLabel, Dataset, DatasetDict, load_dataset + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from src.data.download import gutenberg_download, kaggle_download +from src.utils.config import load_yaml + + +DEFAULT_SUMMARIZATION_DATASET = "gowrishankarp/newspaper-text-summarization-cnn-dailymail" +DEFAULT_EMOTION_DATASET = "dair-ai/emotion" +DEFAULT_TOPIC_DATASET = "ag_news" +DEFAULT_BOOK_URL = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt" +DEFAULT_BOOK_OUTPUT = "data/raw/books/pride_and_prejudice.txt" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download datasets required for LexiMind training") + parser.add_argument( + "--config", + default="configs/data/datasets.yaml", + help="Path to the dataset configuration YAML.", + ) + parser.add_argument("--skip-kaggle", action="store_true", help="Skip downloading the Kaggle summarization dataset.") + parser.add_argument("--skip-book", action="store_true", help="Skip downloading Gutenberg book texts.") + return parser.parse_args() + + +def _safe_load_config(path: str | None) -> dict: + if not path: + return {} + config_path = Path(path) + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + return load_yaml(str(config_path)).data + + +def _write_jsonl(records: Iterable[dict[str, object]], destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + with destination.open("w", encoding="utf-8") as handle: + 
for record in records: + handle.write(json.dumps(record, ensure_ascii=False) + "\n") + + +def _emotion_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]: + for item in dataset_split: + data = dict(item) + text = data.get("text", "") + label_value = data.get("label") + def resolve_label(index: object) -> str: + if isinstance(index, int) and label_names and 0 <= index < len(label_names): + return label_names[index] + return str(index) + + if isinstance(label_value, list): + labels = [resolve_label(idx) for idx in label_value] + else: + labels = [resolve_label(label_value)] + yield {"text": text, "emotions": labels} + + +def _topic_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]: + for item in dataset_split: + data = dict(item) + text = data.get("text") or data.get("content") or "" + label_value = data.get("label") + def resolve_topic(raw: object) -> str: + if label_names: + idx: int | None = None + if isinstance(raw, int): + idx = raw + elif isinstance(raw, str): + try: + idx = int(raw) + except ValueError: + idx = None + if idx is not None and 0 <= idx < len(label_names): + return label_names[idx] + return str(raw) if raw is not None else "" + + if isinstance(label_value, list): + topic = resolve_topic(label_value[0]) if label_value else "" + else: + topic = resolve_topic(label_value) + yield {"text": text, "topic": topic} + + +def main() -> None: + args = parse_args() + config = _safe_load_config(args.config) + + raw_paths = config.get("raw", {}) if isinstance(config, dict) else {} + downloads_cfg = config.get("downloads", {}) if isinstance(config, dict) else {} + + summarization_cfg = downloads_cfg.get("summarization", {}) if isinstance(downloads_cfg, dict) else {} + summarization_dataset = summarization_cfg.get("dataset", DEFAULT_SUMMARIZATION_DATASET) + summarization_output = summarization_cfg.get("output", raw_paths.get("summarization", "data/raw/summarization")) + + if not args.skip_kaggle and summarization_dataset: + print(f"Downloading summarization dataset '{summarization_dataset}' -> {summarization_output}") + kaggle_download(summarization_dataset, summarization_output) + else: + print("Skipping Kaggle summarization download.") + + books_root = Path(raw_paths.get("books", "data/raw/books")) + books_root.mkdir(parents=True, exist_ok=True) + + books_entries: list[dict[str, object]] = [] + if isinstance(downloads_cfg, dict): + raw_entries = downloads_cfg.get("books") + if isinstance(raw_entries, list): + books_entries = [entry for entry in raw_entries if isinstance(entry, dict)] + + if not args.skip_book: + if not books_entries: + books_entries = [ + { + "name": "pride_and_prejudice", + "url": DEFAULT_BOOK_URL, + "output": DEFAULT_BOOK_OUTPUT, + } + ] + for entry in books_entries: + name = str(entry.get("name") or "gutenberg_text") + url = str(entry.get("url") or DEFAULT_BOOK_URL) + output_value = entry.get("output") + destination = Path(output_value) if isinstance(output_value, str) and output_value else books_root / f"{name}.txt" + destination.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading Gutenberg text '{name}' from {url} -> {destination}") + gutenberg_download(url, str(destination)) + else: + print("Skipping Gutenberg downloads.") + emotion_cfg = downloads_cfg.get("emotion", {}) if isinstance(downloads_cfg, dict) else {} + emotion_name = emotion_cfg.get("dataset", DEFAULT_EMOTION_DATASET) + emotion_dir = Path(raw_paths.get("emotion", "data/raw/emotion")) + 
emotion_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading emotion dataset '{emotion_name}' -> {emotion_dir}") + emotion_dataset = cast(DatasetDict, load_dataset(emotion_name)) + first_emotion_key = next(iter(emotion_dataset.keys()), None) if emotion_dataset else None + emotion_label_feature = ( + emotion_dataset[first_emotion_key].features.get("label") + if first_emotion_key is not None + else None + ) + emotion_label_names = emotion_label_feature.names if isinstance(emotion_label_feature, ClassLabel) else None + for split_name, split in emotion_dataset.items(): + output_path = emotion_dir / f"{str(split_name)}.jsonl" + _write_jsonl(_emotion_records(split, emotion_label_names), output_path) + + topic_cfg = downloads_cfg.get("topic", {}) if isinstance(downloads_cfg, dict) else {} + topic_name = topic_cfg.get("dataset", DEFAULT_TOPIC_DATASET) + topic_dir = Path(raw_paths.get("topic", "data/raw/topic")) + topic_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading topic dataset '{topic_name}' -> {topic_dir}") + topic_dataset = cast(DatasetDict, load_dataset(topic_name)) + first_topic_key = next(iter(topic_dataset.keys()), None) if topic_dataset else None + topic_label_feature = ( + topic_dataset[first_topic_key].features.get("label") + if first_topic_key is not None + else None + ) + topic_label_names = topic_label_feature.names if isinstance(topic_label_feature, ClassLabel) else None + for split_name, split in topic_dataset.items(): + output_path = topic_dir / f"{str(split_name)}.jsonl" + _write_jsonl(_topic_records(split, topic_label_names), output_path) + + print("Download routine finished.") + + +if __name__ == "__main__": + main() diff --git a/scripts/download_data.sh b/scripts/download_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..7805fa861f07eca6c5ec77a203eedf581345ece3 --- /dev/null +++ b/scripts/download_data.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "${SCRIPT_DIR}/download_data.py" diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..8313c6dd706f623191f4c05ec8a4c6f3bda7fea8 --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,134 @@ +"""Evaluate the multitask model on processed validation/test splits.""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import List + +import torch +from sklearn.preprocessing import MultiLabelBinarizer + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from src.data.dataset import ( + load_emotion_jsonl, + load_summarization_jsonl, + load_topic_jsonl, +) +from src.inference.factory import create_inference_pipeline +from src.training.metrics import accuracy, multilabel_f1, rouge_like +from src.utils.config import load_yaml + + +SPLIT_ALIASES = { + "train": ("train",), + "val": ("val", "validation"), + "test": ("test",), +} + + +def _read_split(root: Path, split: str, loader) -> list: + aliases = SPLIT_ALIASES.get(split, (split,)) + for alias in aliases: + for ext in ("jsonl", "json"): + candidate = root / f"{alias}.{ext}" + if candidate.exists(): + return loader(str(candidate)) + raise FileNotFoundError(f"Missing {split} split under {root}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Evaluate the LexiMind multitask model") + 
parser.add_argument("--split", default="val", choices=["train", "val", "test"], help="Dataset split to evaluate.") + parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.") + parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON.") + parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration YAML.") + parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture YAML.") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device for evaluation.") + parser.add_argument("--batch-size", type=int, default=16, help="Batch size for generation/classification during evaluation.") + return parser.parse_args() + + +def chunks(items: List, size: int): + for start in range(0, len(items), size): + yield items[start : start + size] + + +def main() -> None: + args = parse_args() + data_cfg = load_yaml(args.data_config).data + + pipeline, metadata = create_inference_pipeline( + checkpoint_path=args.checkpoint, + labels_path=args.labels, + tokenizer_config=None, + model_config_path=args.model_config, + device=args.device, + ) + + summarization_dir = Path(data_cfg["processed"]["summarization"]) + emotion_dir = Path(data_cfg["processed"]["emotion"]) + topic_dir = Path(data_cfg["processed"]["topic"]) + + summary_examples = _read_split(summarization_dir, args.split, load_summarization_jsonl) + emotion_examples = _read_split(emotion_dir, args.split, load_emotion_jsonl) + topic_examples = _read_split(topic_dir, args.split, load_topic_jsonl) + + emotion_binarizer = MultiLabelBinarizer(classes=metadata.emotion) + # Ensure scikit-learn initializes the attributes using metadata ordering. 
+ emotion_binarizer.fit([[label] for label in metadata.emotion]) + + # Summarization + summaries_pred = [] + summaries_ref = [] + for batch in chunks(summary_examples, args.batch_size): + inputs = [example.source for example in batch] + summaries_pred.extend(pipeline.summarize(inputs)) + summaries_ref.extend([example.summary for example in batch]) + rouge_score = rouge_like(summaries_pred, summaries_ref) + + # Emotion + emotion_preds_tensor = [] + emotion_target_tensor = [] + label_to_index = {label: idx for idx, label in enumerate(metadata.emotion)} + for batch in chunks(emotion_examples, args.batch_size): + inputs = [example.text for example in batch] + predictions = pipeline.predict_emotions(inputs) + target_matrix = emotion_binarizer.transform([list(example.emotions) for example in batch]) + for pred, target_row in zip(predictions, target_matrix): + vector = torch.zeros(len(metadata.emotion), dtype=torch.float32) + for label in pred.labels: + idx = label_to_index.get(label) + if idx is not None: + vector[idx] = 1.0 + emotion_preds_tensor.append(vector) + emotion_target_tensor.append(torch.tensor(target_row, dtype=torch.float32)) + emotion_f1 = multilabel_f1(torch.stack(emotion_preds_tensor), torch.stack(emotion_target_tensor)) + + # Topic + topic_preds = [] + topic_targets = [] + for batch in chunks(topic_examples, args.batch_size): + inputs = [example.text for example in batch] + predictions = pipeline.predict_topics(inputs) + topic_preds.extend([pred.label for pred in predictions]) + topic_targets.extend([example.topic for example in batch]) + topic_accuracy = accuracy(topic_preds, topic_targets) + + print(json.dumps( + { + "split": args.split, + "rouge_like": rouge_score, + "emotion_f1": emotion_f1, + "topic_accuracy": topic_accuracy, + }, + indent=2, + )) + + +if __name__ == "__main__": + main() diff --git a/scripts/export_model.py b/scripts/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..88b3b973352b316af12b9e37c565b1c45fec56f6 --- /dev/null +++ b/scripts/export_model.py @@ -0,0 +1,69 @@ +"""Rebuild and export the trained multitask model for downstream use.""" +from __future__ import annotations + +import argparse +from pathlib import Path + +import torch + +from src.data.tokenization import Tokenizer, TokenizerConfig +from src.models.factory import build_multitask_model, load_model_config +from src.utils.config import load_yaml +from src.utils.labels import load_label_metadata + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Export LexiMind model weights") + parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.") + parser.add_argument("--output", default="outputs/model.pt", help="Output path for the exported state dict.") + parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON produced after training.") + parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture configuration.") + parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration (for tokenizer settings).") + return parser.parse_args() + + +def main() -> None: + """Export multitask model weights from a training checkpoint to a standalone state dict.""" + args = parse_args() + + checkpoint = Path(args.checkpoint) + if not checkpoint.exists(): + raise FileNotFoundError(checkpoint) + + labels = load_label_metadata(args.labels) + data_cfg = load_yaml(args.data_config).data + 
tokenizer_section = data_cfg.get("tokenizer", {}) + tokenizer_config = TokenizerConfig( + pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"), + max_length=int(tokenizer_section.get("max_length", 512)), + lower=bool(tokenizer_section.get("lower", False)), + ) + tokenizer = Tokenizer(tokenizer_config) + + model = build_multitask_model( + tokenizer, + num_emotions=labels.emotion_size, + num_topics=labels.topic_size, + config=load_model_config(args.model_config), + ) + + raw_state = torch.load(checkpoint, map_location="cpu") + if isinstance(raw_state, dict): + if "model_state_dict" in raw_state and isinstance(raw_state["model_state_dict"], dict): + state_dict = raw_state["model_state_dict"] + elif "state_dict" in raw_state and isinstance(raw_state["state_dict"], dict): + state_dict = raw_state["state_dict"] + else: + state_dict = raw_state + else: + raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}") + model.load_state_dict(state_dict) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(model.state_dict(), output_path) + print(f"Model exported to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/inference.py b/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9ceb6af876eba9c9cf0aa48a477a09e185ecd879 --- /dev/null +++ b/scripts/inference.py @@ -0,0 +1,112 @@ +"""Run inference with the multitask model.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import List, cast + +from src.data.tokenization import TokenizerConfig +from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline + + +def _load_texts(positional: List[str], file_path: Path | None) -> List[str]: + texts = [text for text in positional if text] + if file_path is not None: + if not file_path.exists(): + raise FileNotFoundError(file_path) + with file_path.open("r", encoding="utf-8") as handle: + texts.extend([line.strip() for line in handle if line.strip()]) + if not texts: + raise ValueError("No input texts provided. 
Pass text arguments or use --file.") + return texts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run LexiMind multitask inference.") + parser.add_argument("text", nargs="*", help="Input text(s) to analyse.") + parser.add_argument("--file", type=Path, help="Path to a file containing one text per line.") + parser.add_argument( + "--checkpoint", + type=Path, + default=Path("checkpoints/best.pt"), + help="Path to the model checkpoint produced during training.", + ) + parser.add_argument( + "--labels", + type=Path, + default=Path("artifacts/labels.json"), + help="JSON file containing emotion/topic label vocabularies.", + ) + parser.add_argument( + "--tokenizer", + type=Path, + default=None, + help="Optional path to a tokenizer directory exported during training.", + ) + parser.add_argument( + "--model-config", + type=Path, + default=Path("configs/model/base.yaml"), + help="Model architecture config used to rebuild the transformer stack.", + ) + parser.add_argument("--device", default="cpu", help="Device to run inference on (cpu or cuda).") + parser.add_argument( + "--summary-max-length", + type=int, + default=None, + help="Optional maximum length for generated summaries.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + texts = _load_texts(args.text, args.file) + + tokenizer_config = None + if args.tokenizer is not None: + tokenizer_config = TokenizerConfig(pretrained_model_name=str(args.tokenizer)) + else: + local_dir = Path("artifacts/hf_tokenizer") + if local_dir.exists(): + tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_dir)) + + pipeline, _ = create_inference_pipeline( + checkpoint_path=args.checkpoint, + labels_path=args.labels, + tokenizer_config=tokenizer_config, + model_config_path=args.model_config, + device=args.device, + summary_max_length=args.summary_max_length, + ) + + results = pipeline.batch_predict(texts) + summaries = cast(List[str], results["summaries"]) + emotion_preds = cast(List[EmotionPrediction], results["emotion"]) + topic_preds = cast(List[TopicPrediction], results["topic"]) + + packaged = [] + for idx, text in enumerate(texts): + emotion = emotion_preds[idx] + topic = topic_preds[idx] + packaged.append( + { + "text": text, + "summary": summaries[idx], + "emotion": { + "labels": emotion.labels, + "scores": emotion.scores, + }, + "topic": { + "label": topic.label, + "confidence": topic.confidence, + }, + } + ) + + print(json.dumps(packaged, indent=2, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cb317871abff3897fb853fdb4e212e696d20a142 --- /dev/null +++ b/scripts/preprocess_data.py @@ -0,0 +1,321 @@ +"""Preprocess raw datasets into JSONL splits for LexiMind training.""" +from __future__ import annotations + +import argparse +import csv +import json +import sys +from pathlib import Path +from typing import Dict, Iterable, Iterator, Sequence, Tuple + +from sklearn.model_selection import train_test_split + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from src.data.preprocessing import BasicTextCleaner +from src.utils.config import load_yaml + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Preprocess datasets configured for LexiMind") + parser.add_argument( + "--config", + 
default="configs/data/datasets.yaml", + help="Path to data configuration YAML.", + ) + parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation split size for topic dataset when no validation split is present.") + parser.add_argument("--seed", type=int, default=17, help="Random seed for deterministic splitting.") + return parser.parse_args() + + +def _resolve_csv(base: Path, filename: str) -> Path | None: + primary = base / filename + if primary.exists(): + return primary + nested = base / "cnn_dailymail" / filename + if nested.exists(): + return nested + return None + + +def _write_jsonl(records: Iterable[Dict[str, object]], destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + with destination.open("w", encoding="utf-8") as handle: + for record in records: + handle.write(json.dumps(record, ensure_ascii=False) + "\n") + + +def _read_jsonl(path: Path) -> Iterator[Dict[str, object]]: + with path.open("r", encoding="utf-8") as handle: + for line in handle: + row = line.strip() + if not row: + continue + yield json.loads(row) + + +def preprocess_books( + raw_dir: Path, + processed_dir: Path, + cleaner: BasicTextCleaner, + *, + min_tokens: int = 30, +) -> None: + if not raw_dir.exists(): + print(f"Skipping book preprocessing (missing directory: {raw_dir})") + return + + processed_dir.mkdir(parents=True, exist_ok=True) + index: list[Dict[str, object]] = [] + + for book_path in sorted(raw_dir.glob("*.txt")): + text = book_path.read_text(encoding="utf-8").lstrip("\ufeff") + normalized = text.replace("\r\n", "\n") + paragraphs = [paragraph.strip() for paragraph in normalized.split("\n\n") if paragraph.strip()] + + records: list[Dict[str, object]] = [] + for paragraph_id, paragraph in enumerate(paragraphs): + cleaned = cleaner.transform([paragraph])[0] + tokens = cleaned.split() + if len(tokens) < min_tokens: + continue + record = { + "book": book_path.stem, + "title": book_path.stem.replace("_", " ").title(), + "paragraph_id": paragraph_id, + "text": paragraph, + "clean_text": cleaned, + "token_count": len(tokens), + "char_count": len(paragraph), + } + records.append(record) + + if not records: + print(f"No suitably sized paragraphs found in {book_path}; skipping.") + continue + + output_path = processed_dir / f"{book_path.stem}.jsonl" + print(f"Writing book segments for '{book_path.stem}' to {output_path}") + _write_jsonl(records, output_path) + index.append( + { + "book": book_path.stem, + "title": records[0]["title"], + "paragraphs": len(records), + "source": str(book_path), + "output": str(output_path), + } + ) + + if index: + index_path = processed_dir / "index.json" + with index_path.open("w", encoding="utf-8") as handle: + json.dump(index, handle, ensure_ascii=False, indent=2) + print(f"Book index written to {index_path}") + + +def preprocess_summarization(raw_dir: Path, processed_dir: Path) -> None: + if not raw_dir.exists(): + print(f"Skipping summarization preprocessing (missing directory: {raw_dir})") + return + + for split in ("train", "validation", "test"): + source_path = _resolve_csv(raw_dir, f"{split}.csv") + if source_path is None: + print(f"Skipping summarization split '{split}' (file not found)") + continue + + output_path = processed_dir / f"{split}.jsonl" + output_path.parent.mkdir(parents=True, exist_ok=True) + print(f"Writing summarization split '{split}' to {output_path}") + with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open("w", encoding="utf-8") as sink: + reader = 
csv.DictReader(source_handle) + for row in reader: + article = row.get("article") or row.get("Article") or "" + highlights = row.get("highlights") or row.get("summary") or "" + payload = {"source": article.strip(), "summary": highlights.strip()} + sink.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def preprocess_emotion(raw_dir: Path, processed_dir: Path, cleaner: BasicTextCleaner) -> None: + if not raw_dir.exists(): + print(f"Skipping emotion preprocessing (missing directory: {raw_dir})") + return + + split_aliases: Dict[str, Sequence[str]] = { + "train": ("train",), + "val": ("val", "validation"), + "test": ("test",), + } + + for split, aliases in split_aliases.items(): + source_path: Path | None = None + for alias in aliases: + for extension in ("jsonl", "txt", "csv"): + candidate = raw_dir / f"{alias}.{extension}" + if candidate.exists(): + source_path = candidate + break + if source_path is not None: + break + if source_path is None: + print(f"Skipping emotion split '{split}' (file not found)") + continue + + assert source_path is not None + path = source_path + + def iter_records() -> Iterator[Dict[str, object]]: + if path.suffix == ".jsonl": + for row in _read_jsonl(path): + raw_text = str(row.get("text", "")) + text = cleaner.transform([raw_text])[0] + labels = row.get("emotions") or row.get("labels") or [] + if isinstance(labels, str): + labels = [label.strip() for label in labels.split(",") if label.strip()] + elif isinstance(labels, Sequence): + labels = [str(label) for label in labels] + else: + labels = [str(labels)] if labels else [] + if not labels: + labels = ["neutral"] + yield {"text": text, "emotions": labels} + else: + delimiter = ";" if path.suffix == ".txt" else "," + with path.open("r", encoding="utf-8", newline="") as handle: + reader = csv.reader(handle, delimiter=delimiter) + for row in reader: + if not row: + continue + raw_text = str(row[0]) + text = cleaner.transform([raw_text])[0] + raw_labels = row[1] if len(row) > 1 else "" + labels = [label.strip() for label in raw_labels.split(",") if label.strip()] + if not labels: + labels = ["neutral"] + yield {"text": text, "emotions": labels} + + output_path = processed_dir / f"{split}.jsonl" + print(f"Writing emotion split '{split}' to {output_path}") + _write_jsonl(iter_records(), output_path) + + +def preprocess_topic( + raw_dir: Path, + processed_dir: Path, + cleaner: BasicTextCleaner, + val_ratio: float, + seed: int, +) -> None: + if not raw_dir.exists(): + print(f"Skipping topic preprocessing (missing directory: {raw_dir})") + return + + def locate(*names: str) -> Path | None: + for name in names: + candidate = raw_dir / name + if candidate.exists(): + return candidate + return None + + train_path = locate("train.jsonl", "train.csv") + if train_path is None: + print(f"Skipping topic preprocessing (missing train split in {raw_dir})") + return + + assert train_path is not None + + def load_topic_rows(path: Path) -> list[Tuple[str, str]]: + rows: list[Tuple[str, str]] = [] + if path.suffix == ".jsonl": + for record in _read_jsonl(path): + text = str(record.get("text") or record.get("content") or "") + topic = record.get("topic") or record.get("label") + cleaned_text = cleaner.transform([text])[0] + rows.append((cleaned_text, str(topic).strip())) + else: + with path.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + for row in reader: + topic = row.get("Class Index") or row.get("topic") or row.get("label") + title = str(row.get("Title") or "") + description = 
str(row.get("Description") or row.get("text") or "") + text = " ".join(filter(None, (title, description))) + cleaned_text = cleaner.transform([text])[0] + rows.append((cleaned_text, str(topic).strip())) + return rows + + train_rows = load_topic_rows(train_path) + if not train_rows: + print("No topic training rows found; skipping topic preprocessing.") + return + + texts = [row[0] for row in train_rows] + topics = [row[1] for row in train_rows] + + validation_path = locate("val.jsonl", "validation.jsonl", "val.csv", "validation.csv") + has_validation = validation_path is not None + + if has_validation and validation_path: + val_rows = load_topic_rows(validation_path) + train_records = train_rows + else: + train_texts, val_texts, train_topics, val_topics = train_test_split( + texts, + topics, + test_size=val_ratio, + random_state=seed, + stratify=topics, + ) + train_records = list(zip(train_texts, train_topics)) + val_rows = list(zip(val_texts, val_topics)) + + def to_records(pairs: Sequence[Tuple[str, str]]) -> Iterator[Dict[str, object]]: + for text, topic in pairs: + yield {"text": text, "topic": topic} + + print(f"Writing topic train split to {processed_dir / 'train.jsonl'}") + _write_jsonl(to_records(train_records), processed_dir / "train.jsonl") + print(f"Writing topic val split to {processed_dir / 'val.jsonl'}") + _write_jsonl(to_records(val_rows), processed_dir / "val.jsonl") + + test_path = locate("test.jsonl", "test.csv") + if test_path is not None: + test_rows = load_topic_rows(test_path) + print(f"Writing topic test split to {processed_dir / 'test.jsonl'}") + _write_jsonl(to_records(test_rows), processed_dir / "test.jsonl") + else: + print(f"Skipping topic test split (missing test split in {raw_dir})") + + +def main() -> None: + args = parse_args() + config = load_yaml(args.config).data + + raw_cfg = config.get("raw", {}) + processed_cfg = config.get("processed", {}) + + books_raw = Path(raw_cfg.get("books", "data/raw/books")) + summarization_raw = Path(raw_cfg.get("summarization", "data/raw/summarization")) + emotion_raw = Path(raw_cfg.get("emotion", "data/raw/emotion")) + topic_raw = Path(raw_cfg.get("topic", "data/raw/topic")) + + books_processed = Path(processed_cfg.get("books", "data/processed/books")) + summarization_processed = Path(processed_cfg.get("summarization", "data/processed/summarization")) + emotion_processed = Path(processed_cfg.get("emotion", "data/processed/emotion")) + topic_processed = Path(processed_cfg.get("topic", "data/processed/topic")) + + cleaner = BasicTextCleaner() + + preprocess_books(books_raw, books_processed, cleaner) + preprocess_summarization(summarization_raw, summarization_processed) + preprocess_emotion(emotion_raw, emotion_processed, cleaner) + preprocess_topic(topic_raw, topic_processed, cleaner, val_ratio=args.val_ratio, seed=args.seed) + + print("Preprocessing complete.") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_gpu.py b/scripts/test_gpu.py deleted file mode 100644 index 7ac2be0b9140c834818656cfb693001ed422cc2f..0000000000000000000000000000000000000000 --- a/scripts/test_gpu.py +++ /dev/null @@ -1,27 +0,0 @@ -# test_gpu.py -import torch - -print("=" * 50) -print("GPU Information") -print("=" * 50) - -if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(0) - gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 - - print(f"βœ… GPU: {gpu_name}") - print(f"βœ… Memory: {gpu_memory:.2f} GB") - - # Test tensor creation - x = torch.randn(1000, 1000, device='cuda') - y = torch.randn(1000, 
1000, device='cuda') - z = x @ y - - print(f"βœ… CUDA operations working!") - print(f"βœ… Current memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB") - print(f"βœ… Max memory allocated: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB") -else: - print("❌ CUDA not available!") - print("Using CPU - training will be slow!") - -print("=" * 50) \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 8c2cc35e4b199f96f8daeefc997a3d8d0fab97ac..e02f8e5cad96d03e625f23d3557490c9b50ce622 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,8 +1,219 @@ -# scripts/train.py -from src.training.trainer import Trainer -from src.utils.config import load_config +"""End-to-end training entrypoint for the LexiMind multitask model.""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Dict, Sequence + +import torch + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from src.data.dataloader import ( + build_emotion_dataloader, + build_summarization_dataloader, + build_topic_dataloader, +) +from src.data.dataset import ( + EmotionDataset, + SummarizationDataset, + TopicDataset, + load_emotion_jsonl, + load_summarization_jsonl, + load_topic_jsonl, +) +from src.data.tokenization import Tokenizer, TokenizerConfig +from src.models.factory import build_multitask_model, load_model_config +from src.training.trainer import Trainer, TrainerConfig +from src.training.utils import set_seed +from src.utils.config import load_yaml +from src.utils.io import save_state +from src.utils.labels import LabelMetadata, save_label_metadata + + +SplitExamples = Dict[str, list] + + +SPLIT_ALIASES: Dict[str, Sequence[str]] = { + "train": ("train",), + "val": ("val", "validation"), + "test": ("test",), +} + + +def _read_examples(data_dir: Path, loader) -> SplitExamples: + splits: SplitExamples = {} + for canonical, aliases in SPLIT_ALIASES.items(): + found = False + for alias in aliases: + for extension in ("jsonl", "json"): + candidate = data_dir / f"{alias}.{extension}" + if candidate.exists(): + splits[canonical] = loader(str(candidate)) + found = True + break + if found: + break + if not found: + raise FileNotFoundError(f"Missing {canonical} split under {data_dir}") + return splits + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train the LexiMind multitask transformer") + parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Path to data configuration YAML.") + parser.add_argument("--training-config", default="configs/training/default.yaml", help="Path to training hyperparameter YAML.") + parser.add_argument("--model-config", default="configs/model/base.yaml", help="Path to model architecture YAML.") + parser.add_argument("--checkpoint-out", default="checkpoints/best.pt", help="Where to store the trained checkpoint.") + parser.add_argument("--labels-out", default="artifacts/labels.json", help="Where to persist label vocabularies.") + parser.add_argument("--history-out", default="outputs/training_history.json", help="Where to write training history.") + parser.add_argument("--device", default="cpu", help="Training device identifier (cpu or cuda).") + parser.add_argument("--seed", type=int, default=17, help="Random seed for reproducibility.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + set_seed(args.seed) + + data_cfg = 
load_yaml(args.data_config).data + training_cfg = load_yaml(args.training_config).data + model_cfg = load_model_config(args.model_config) + + summarization_dir = Path(data_cfg["processed"]["summarization"]) + emotion_dir = Path(data_cfg["processed"]["emotion"]) + topic_dir = Path(data_cfg["processed"]["topic"]) + + summarization_splits = _read_examples(summarization_dir, load_summarization_jsonl) + emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl) + topic_splits = _read_examples(topic_dir, load_topic_jsonl) + + tokenizer_section = data_cfg.get("tokenizer", {}) + tokenizer_config = TokenizerConfig( + pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"), + max_length=int(tokenizer_section.get("max_length", 512)), + lower=bool(tokenizer_section.get("lower", False)), + ) + tokenizer = Tokenizer(tokenizer_config) + + summarization_train = SummarizationDataset(summarization_splits["train"]) + summarization_val = SummarizationDataset(summarization_splits["val"]) + + emotion_train = EmotionDataset(emotion_splits["train"]) + emotion_val = EmotionDataset(emotion_splits["val"], binarizer=emotion_train.binarizer) + + topic_train = TopicDataset(topic_splits["train"]) + topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder) + + dataloader_args = training_cfg.get("dataloader", {}) + batch_size = int(dataloader_args.get("batch_size", 8)) + shuffle = bool(dataloader_args.get("shuffle", True)) + max_length = tokenizer.config.max_length + + train_loaders = { + "summarization": build_summarization_dataloader( + summarization_train, + tokenizer, + batch_size=batch_size, + shuffle=shuffle, + max_source_length=max_length, + max_target_length=max_length, + ), + "emotion": build_emotion_dataloader( + emotion_train, + tokenizer, + batch_size=batch_size, + shuffle=shuffle, + max_length=max_length, + ), + "topic": build_topic_dataloader( + topic_train, + tokenizer, + batch_size=batch_size, + shuffle=shuffle, + max_length=max_length, + ), + } + + val_loaders = { + "summarization": build_summarization_dataloader( + summarization_val, + tokenizer, + batch_size=batch_size, + shuffle=False, + max_source_length=max_length, + max_target_length=max_length, + ), + "emotion": build_emotion_dataloader( + emotion_val, + tokenizer, + batch_size=batch_size, + shuffle=False, + max_length=max_length, + ), + "topic": build_topic_dataloader( + topic_val, + tokenizer, + batch_size=batch_size, + shuffle=False, + max_length=max_length, + ), + } + + device = torch.device(args.device) + model = build_multitask_model( + tokenizer, + num_emotions=len(emotion_train.emotion_classes), + num_topics=len(topic_train.topic_classes), + config=model_cfg, + ).to(device) + + optimizer_cfg = training_cfg.get("optimizer", {}) + lr = float(optimizer_cfg.get("lr", 3.0e-5)) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + trainer_cfg = training_cfg.get("trainer", {}) + trainer = Trainer( + model=model, + optimizer=optimizer, + config=TrainerConfig( + max_epochs=int(trainer_cfg.get("max_epochs", 1)), + gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)), + logging_interval=int(trainer_cfg.get("logging_interval", 50)), + task_weights=trainer_cfg.get("task_weights"), + ), + device=device, + tokenizer=tokenizer, + ) + + history = trainer.fit(train_loaders, val_loaders) + + checkpoint_path = Path(args.checkpoint_out) + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + save_state(model, str(checkpoint_path)) + + labels_path = Path(args.labels_out) + 
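    # The two artifacts written below form the contract with inference: create_inference_pipeline()
    # in src/inference/factory.py reloads this checkpoint via load_state() and reads the label JSON
    # via load_label_metadata(), using labels.emotion_size / labels.topic_size to size the
    # classification heads, so the checkpoint and label file need to be kept in sync.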
save_label_metadata( + LabelMetadata( + emotion=emotion_train.emotion_classes, + topic=topic_train.topic_classes, + ), + labels_path, + ) + + history_path = Path(args.history_out) + history_path.parent.mkdir(parents=True, exist_ok=True) + with history_path.open("w", encoding="utf-8") as handle: + json.dump(history, handle, indent=2) + + print(f"Training complete. Checkpoint saved to {checkpoint_path}") + print(f"Label metadata saved to {labels_path}") + print(f"History saved to {history_path}") + if __name__ == "__main__": - config = load_config("configs/training/default.yaml") - trainer = Trainer(config) - trainer.train() \ No newline at end of file + main() diff --git a/setup.py b/setup.py index 0c774b5bac7ca69834c0e203546fa5fd4567786f..04f1cdefad82342c8c751fdc1d17ea3a5af0473f 100644 --- a/setup.py +++ b/setup.py @@ -7,13 +7,23 @@ setup( package_dir={"": "src"}, install_requires=[ "torch>=2.0.0", - "transformers>=4.30.0", - # ... (or read from requirements.txt) + "transformers>=4.40.0", + "scikit-learn>=1.4.0", + "numpy>=1.24.0", + "pandas>=2.0.0", ], - entry_points={ - "console_scripts": [ - "leximind-train=scripts.train:main", - "leximind-infer=scripts.inference:main", + extras_require={ + "web": [ + "streamlit>=1.25.0", + "plotly>=5.18.0", + ], + "api": [ + "fastapi>=0.110.0", + ], + "all": [ + "streamlit>=1.25.0", + "plotly>=5.18.0", + "fastapi>=0.110.0", ], }, ) \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..22c157cb851bb4c3986e1e7901d9ec29e8e66c4c 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1 @@ +"""LexiMind core package.""" diff --git a/src/api/__init__.py b/src/api/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9b47afc64051a5e91b7ac9198b18d087f708b901 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -0,0 +1 @@ +"""API surface for LexiMind.""" diff --git a/src/api/app.py b/src/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9e6fce1920522a457a7c1d36654243bddd4d54 --- /dev/null +++ b/src/api/app.py @@ -0,0 +1,10 @@ +"""FastAPI application entrypoint.""" +from fastapi import FastAPI + +from .routes import router + + +def create_app() -> FastAPI: + app = FastAPI(title="LexiMind") + app.include_router(router) + return app diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc43217efa6f056290f602c3d8e6fafd737a67d --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,42 @@ +"""Dependency providers for the FastAPI application.""" +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path + +from fastapi import HTTPException, status + +from ..utils.logging import get_logger +logger = get_logger(__name__) + +from ..inference.factory import create_inference_pipeline +from ..inference.pipeline import InferencePipeline + + +@lru_cache(maxsize=1) +def get_pipeline() -> InferencePipeline: + """Lazily construct and cache the inference pipeline for the API.""" + + checkpoint = Path("checkpoints/best.pt") + labels = Path("artifacts/labels.json") + model_config = Path("configs/model/base.yaml") + + try: + pipeline, _ = create_inference_pipeline( + checkpoint_path=checkpoint, + labels_path=labels, + model_config_path=model_config, + ) + except FileNotFoundError as exc: + logger.exception("Pipeline initialization failed: missing artifact") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Service 
temporarily unavailable", + ) from exc + except Exception as exc: # noqa: BLE001 - surface initialization issues to the caller + logger.exception("Pipeline initialization failed") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Service temporarily unavailable", + ) from exc + return pipeline diff --git a/src/api/inference/__init__.py b/src/api/inference/__init__.py deleted file mode 100644 index 57ddc42c07815c4a7547da8bd6852c8585aff2dd..0000000000000000000000000000000000000000 --- a/src/api/inference/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -API inference module for LexiMind. -""" - -from .inference import load_models, summarize_text, classify_emotion, topic_for_text - -__all__ = ["load_models", "summarize_text", "classify_emotion", "topic_for_text"] diff --git a/src/api/inference/inference.py b/src/api/inference/inference.py deleted file mode 100644 index 13fe9468a0cf5df4f57f7cc16b61c96a199778d3..0000000000000000000000000000000000000000 --- a/src/api/inference/inference.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Minimal inference helpers that rely on the custom transformer stack.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import torch - -from ...data.preprocessing import TextPreprocessor, TransformerTokenizer -from ...models.multitask import MultiTaskModel - - -def _load_tokenizer(tokenizer_path: Path) -> TransformerTokenizer: - if not tokenizer_path.exists(): - raise FileNotFoundError(f"tokenizer file '{tokenizer_path}' not found") - return TransformerTokenizer.load(tokenizer_path) - - -def load_models(config: Dict[str, Any]) -> Dict[str, Any]: - """Load MultiTaskModel together with the tokenizer-driven preprocessor.""" - - device = torch.device(config.get("device", "cpu")) - tokenizer_path = config.get("tokenizer_path") - if tokenizer_path is None: - raise ValueError("'tokenizer_path' missing in config") - - tokenizer = _load_tokenizer(Path(tokenizer_path)) - preprocessor = TextPreprocessor( - max_length=int(config.get("max_length", 512)), - tokenizer=tokenizer, - min_freq=int(config.get("min_freq", 1)), - lowercase=bool(config.get("lowercase", True)), - ) - - encoder_kwargs = dict(config.get("encoder", {})) - decoder_kwargs = dict(config.get("decoder", {})) - - encoder = preprocessor.build_encoder(**encoder_kwargs) - decoder = preprocessor.build_decoder(**decoder_kwargs) - model = MultiTaskModel(encoder=encoder, decoder=decoder) - - checkpoint_path = config.get("checkpoint_path") - if checkpoint_path: - state = torch.load(checkpoint_path, map_location=device) - if isinstance(state, dict) and "state_dict" in state: - state = state["state_dict"] - model.load_state_dict(state, strict=False) - - model.to(device) - - return { - "loaded": True, - "device": device, - "mt": model, - "preprocessor": preprocessor, - } - - -def summarize_text( - text: str, - compression: float = 0.25, - collect_attn: bool = False, - models: Optional[Dict[str, Any]] = None, -) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]: - if models is None or not models.get("loaded"): - raise RuntimeError("Models must be loaded via load_models before summarize_text is called") - - model: MultiTaskModel = models["mt"] - preprocessor: TextPreprocessor = models["preprocessor"] - device: torch.device = models["device"] - - batch = preprocessor.batch_encode([text]) - tokenizer = preprocessor.tokenizer - encoder = model.encoder - decoder = model.decoder - if tokenizer is None or encoder is None or decoder is 
None: - raise RuntimeError("Encoder, decoder, and tokenizer must be configured before summarization") - input_ids = batch.input_ids.to(device) - memory = encoder(input_ids) - src_len = batch.lengths[0] - max_tgt = max(4, int(src_len * compression)) - generated = decoder.greedy_decode( - memory, - max_len=min(preprocessor.max_length, max_tgt), - start_token_id=tokenizer.bos_id, - end_token_id=tokenizer.eos_id, - ) - summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True) - return summary.strip(), None if not collect_attn else {} - - -def classify_emotion(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[List[float], List[str]]: - if models is None or not models.get("loaded"): - raise RuntimeError("Models must be loaded via load_models before classify_emotion is called") - - model: MultiTaskModel = models["mt"] - preprocessor: TextPreprocessor = models["preprocessor"] - device: torch.device = models["device"] - - batch = preprocessor.batch_encode([text]) - input_ids = batch.input_ids.to(device) - result = model.forward("emotion", {"input_ids": input_ids}) - logits = result[1] if isinstance(result, tuple) else result - scores = torch.sigmoid(logits).squeeze(0).detach().cpu().tolist() - labels = models.get("emotion_labels") or [ - "joy", - "sadness", - "anger", - "fear", - "surprise", - "disgust", - ] - return scores, labels[: len(scores)] - - -def topic_for_text(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[int, List[str]]: - if models is None or not models.get("loaded"): - raise RuntimeError("Models must be loaded via load_models before topic_for_text is called") - - model: MultiTaskModel = models["mt"] - preprocessor: TextPreprocessor = models["preprocessor"] - device: torch.device = models["device"] - - batch = preprocessor.batch_encode([text]) - input_ids = batch.input_ids.to(device) - encoder = model.encoder - if encoder is None: - raise RuntimeError("Encoder must be configured before topic_for_text is called") - memory = encoder(input_ids) - embedding = memory.mean(dim=1).detach().cpu() - _ = embedding # placeholder for downstream clustering hook - return 0, ["topic_stub"] \ No newline at end of file diff --git a/src/api/routes.py b/src/api/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c0d814cfbd50b7287d7a052c840e601acc03ca --- /dev/null +++ b/src/api/routes.py @@ -0,0 +1,34 @@ +"""API routes.""" +from typing import cast + +from fastapi import APIRouter, Depends, HTTPException, status + +from ..inference import EmotionPrediction, InferencePipeline, TopicPrediction +from .dependencies import get_pipeline +from .schemas import SummaryRequest, SummaryResponse + +router = APIRouter() + + +@router.post("/summarize", response_model=SummaryResponse) +def summarize(payload: SummaryRequest, pipeline: InferencePipeline = Depends(get_pipeline)) -> SummaryResponse: + try: + outputs = pipeline.batch_predict([payload.text]) + except Exception as exc: # noqa: BLE001 - surface inference error to client + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(exc), + ) from exc + summaries = cast(list[str], outputs["summaries"]) + emotion_preds = cast(list[EmotionPrediction], outputs["emotion"]) + topic_preds = cast(list[TopicPrediction], outputs["topic"]) + + emotion = emotion_preds[0] + topic = topic_preds[0] + return SummaryResponse( + summary=summaries[0], + emotion_labels=emotion.labels, + emotion_scores=emotion.scores, + topic=topic.label, + topic_confidence=topic.confidence, + ) diff 
--git a/src/api/schemas.py b/src/api/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..d59cf1f285a79f9374958ec4928dbd1c977cb773 --- /dev/null +++ b/src/api/schemas.py @@ -0,0 +1,14 @@ +"""API schemas.""" +from pydantic import BaseModel + + +class SummaryRequest(BaseModel): + text: str + + +class SummaryResponse(BaseModel): + summary: str + emotion_labels: list[str] + emotion_scores: list[float] + topic: str + topic_confidence: float diff --git a/src/data/__init__.py b/src/data/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..68f7f9856bbb6d04b50ae32f56d8f4525e1170f6 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -0,0 +1 @@ +"""Data utilities for LexiMind.""" diff --git a/src/data/dataloader.py b/src/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..339aa38e2aaeeb7b0830e589eb79144571d2b37b --- /dev/null +++ b/src/data/dataloader.py @@ -0,0 +1,117 @@ +"""Task-aware DataLoader builders for the LexiMind multitask suite.""" +from __future__ import annotations + +from typing import Iterable, List + +import torch +from torch.utils.data import DataLoader + +from .dataset import EmotionDataset, EmotionExample, SummarizationDataset, SummarizationExample, TopicDataset, TopicExample +from .tokenization import Tokenizer + + +class SummarizationCollator: + """Prepare encoder-decoder batches for abstractive summarization.""" + + def __init__(self, tokenizer: Tokenizer, *, max_source_length: int | None = None, max_target_length: int | None = None) -> None: + self.tokenizer = tokenizer + self.max_source_length = max_source_length + self.max_target_length = max_target_length + + def __call__(self, batch: List[SummarizationExample]) -> dict[str, torch.Tensor]: + sources = [example.source for example in batch] + targets = [example.summary for example in batch] + + source_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length) + target_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length) + + labels = target_enc["input_ids"].clone() + decoder_input_ids = self.tokenizer.prepare_decoder_inputs(target_enc["input_ids"]) + labels[target_enc["attention_mask"] == 0] = -100 + + return { + "src_ids": source_enc["input_ids"], + "src_mask": source_enc["attention_mask"], + "tgt_ids": decoder_input_ids, + "labels": labels, + } + + +class EmotionCollator: + """Prepare batches for multi-label emotion classification.""" + + def __init__(self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None) -> None: + self.tokenizer = tokenizer + self.binarizer = dataset.binarizer + self.max_length = max_length + + def __call__(self, batch: List[EmotionExample]) -> dict[str, torch.Tensor]: + texts = [example.text for example in batch] + encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length) + label_array = self.binarizer.transform([example.emotions for example in batch]) + labels = torch.as_tensor(label_array, dtype=torch.float32) + return { + "input_ids": encoded["input_ids"], + "attention_mask": encoded["attention_mask"], + "labels": labels, + } + + +class TopicCollator: + """Prepare batches for topic classification using the projection head.""" + + def __init__(self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None) -> None: + self.tokenizer = tokenizer + self.encoder = dataset.encoder + self.max_length = max_length + + def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]: + texts = [example.text for 
example in batch] + encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length) + labels = torch.as_tensor(self.encoder.transform([example.topic for example in batch]), dtype=torch.long) + return { + "input_ids": encoded["input_ids"], + "attention_mask": encoded["attention_mask"], + "labels": labels, + } + + +def build_summarization_dataloader( + dataset: SummarizationDataset, + tokenizer: Tokenizer, + *, + batch_size: int, + shuffle: bool = True, + max_source_length: int | None = None, + max_target_length: int | None = None, +) -> DataLoader: + collator = SummarizationCollator( + tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + ) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator) + + +def build_emotion_dataloader( + dataset: EmotionDataset, + tokenizer: Tokenizer, + *, + batch_size: int, + shuffle: bool = True, + max_length: int | None = None, +) -> DataLoader: + collator = EmotionCollator(tokenizer, dataset, max_length=max_length) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator) + + +def build_topic_dataloader( + dataset: TopicDataset, + tokenizer: Tokenizer, + *, + batch_size: int, + shuffle: bool = True, + max_length: int | None = None, +) -> DataLoader: + collator = TopicCollator(tokenizer, dataset, max_length=max_length) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator) diff --git a/src/data/dataset.py b/src/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf0c794a9b18aeea7f69877f3c96d61dff6fcde --- /dev/null +++ b/src/data/dataset.py @@ -0,0 +1,229 @@ +"""Dataset definitions for the LexiMind multitask training pipeline.""" +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Iterable, List, Sequence, TypeVar + +from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer +from torch.utils.data import Dataset + + +@dataclass(slots=True) +class SummarizationExample: + """Container for abstractive summarization samples.""" + + source: str + summary: str + + +@dataclass(slots=True) +class EmotionExample: + """Container for multi-label emotion classification samples.""" + + text: str + emotions: Sequence[str] + + +@dataclass(slots=True) +class TopicExample: + """Container for topic clustering / classification samples.""" + + text: str + topic: str + + +class SummarizationDataset(Dataset[SummarizationExample]): + """Dataset yielding encoder-decoder training pairs.""" + + def __init__(self, examples: Iterable[SummarizationExample]) -> None: + self._examples = list(examples) + + def __len__(self) -> int: + return len(self._examples) + + def __getitem__(self, index: int) -> SummarizationExample: + return self._examples[index] + + +class EmotionDataset(Dataset[EmotionExample]): + """Dataset that owns a scikit-learn MultiLabelBinarizer for emotions.""" + + def __init__( + self, + examples: Iterable[EmotionExample], + *, + binarizer: MultiLabelBinarizer | None = None, + ) -> None: + self._examples = list(examples) + all_labels = [example.emotions for example in self._examples] + if binarizer is None: + self._binarizer = MultiLabelBinarizer() + self._binarizer.fit(all_labels) + else: + self._binarizer = binarizer + if not hasattr(self._binarizer, "classes_"): + raise ValueError( + "Provided MultiLabelBinarizer must be pre-fitted with 'classes_' attribute."
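            # A pre-fitted binarizer is passed in when a split must share label ordering with
            # another one; scripts/train.py in this change does exactly that for validation,
            #   EmotionDataset(emotion_splits["val"], binarizer=emotion_train.binarizer)
            # so the emotion columns line up between train and validation batches.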
+ ) + + def __len__(self) -> int: + return len(self._examples) + + def __getitem__(self, index: int) -> EmotionExample: + return self._examples[index] + + @property + def binarizer(self) -> MultiLabelBinarizer: + return self._binarizer + + @property + def emotion_classes(self) -> List[str]: + return list(self._binarizer.classes_) + + +class TopicDataset(Dataset[TopicExample]): + """Dataset that owns a LabelEncoder for topic ids.""" + + def __init__( + self, + examples: Iterable[TopicExample], + *, + encoder: LabelEncoder | None = None, + ) -> None: + self._examples = list(examples) + topics = [example.topic for example in self._examples] + if encoder is None: + self._encoder = LabelEncoder().fit(topics) + else: + self._encoder = encoder + if not hasattr(self._encoder, "classes_"): + raise ValueError( + "Provided LabelEncoder must be pre-fitted with 'classes_' attribute." + ) + + def __len__(self) -> int: + return len(self._examples) + + def __getitem__(self, index: int) -> TopicExample: + return self._examples[index] + + @property + def encoder(self) -> LabelEncoder: + return self._encoder + + @property + def topic_classes(self) -> List[str]: + return list(self._encoder.classes_) + + +T = TypeVar("T") + + +def _safe_json_load(handle, path: Path) -> object: + try: + return json.load(handle) + except json.JSONDecodeError as exc: + raise ValueError(f"Failed to parse JSON in '{path}': {exc}") from exc + + +def _safe_json_loads(data: str, path: Path, line_number: int) -> object: + try: + return json.loads(data) + except json.JSONDecodeError as exc: + raise ValueError(f"Failed to parse JSON in '{path}' at line {line_number}: {exc}") from exc + + +def _validate_keys( + payload: dict, + required_keys: Sequence[str], + position: int, + *, + path: Path, + is_array: bool = False, +) -> None: + missing = [key for key in required_keys if key not in payload] + if missing: + keys = ", ".join(sorted(missing)) + location = "index" if is_array else "line" + raise KeyError(f"Missing required keys ({keys}) at {location} {position} of '{path}'") + + +def _load_jsonl_generic( + path: str, + constructor: Callable[[dict], T], + required_keys: Sequence[str], +) -> List[T]: + data_path = Path(path) + if not data_path.exists(): + raise FileNotFoundError(f"Dataset file '{data_path}' does not exist") + if not data_path.is_file(): + raise ValueError(f"Dataset path '{data_path}' is not a file") + + items: List[T] = [] + with data_path.open("r", encoding="utf-8") as handle: + first_non_ws = "" + while True: + pos = handle.tell() + char = handle.read(1) + if not char: + break + if not char.isspace(): + first_non_ws = char + handle.seek(pos) + break + if not first_non_ws: + raise ValueError(f"Dataset file '{data_path}' is empty or contains only whitespace") + + if first_non_ws == "[": + payloads = _safe_json_load(handle, data_path) + if not isinstance(payloads, list): + raise ValueError(f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}") + for idx, payload in enumerate(payloads): + if not isinstance(payload, dict): + raise ValueError( + f"Expected objects in array for '{data_path}', found {type(payload).__name__} at index {idx}" + ) + _validate_keys(payload, required_keys, idx, path=data_path, is_array=True) + items.append(constructor(payload)) + else: + handle.seek(0) + line_number = 0 + for line in handle: + line_number += 1 + if not line.strip(): + continue + payload = _safe_json_loads(line, data_path, line_number) + if not isinstance(payload, dict): + raise ValueError( + f"Expected JSON 
object per line in '{data_path}', found {type(payload).__name__} at line {line_number}" + ) + _validate_keys(payload, required_keys, line_number, path=data_path) + items.append(constructor(payload)) + + return items + + +def load_summarization_jsonl(path: str) -> List[SummarizationExample]: + return _load_jsonl_generic( + path, + lambda payload: SummarizationExample(source=payload["source"], summary=payload["summary"]), + required_keys=("source", "summary"), + ) + + +def load_emotion_jsonl(path: str) -> List[EmotionExample]: + return _load_jsonl_generic( + path, + lambda payload: EmotionExample(text=payload["text"], emotions=payload.get("emotions", [])), + required_keys=("text",), + ) + + +def load_topic_jsonl(path: str) -> List[TopicExample]: + return _load_jsonl_generic( + path, + lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]), + required_keys=("text", "topic"), + ) diff --git a/src/data/download.py b/src/data/download.py index 7cf4f64e3214d8e40644ced28491da75ec740915..9257e036e02b1a07171a19d89db2f456fc8bffb9 100644 --- a/src/data/download.py +++ b/src/data/download.py @@ -1,66 +1,45 @@ -""" -Download helpers for datasets. +"""Dataset download helpers.""" -This version: -- Adds robust error handling when Kaggle API is not configured. -- Stores files under data/raw/ subfolders. -- Keeps the Gutenberg direct download example. +import socket +from pathlib import Path +from subprocess import CalledProcessError, run +from urllib.error import URLError +from urllib.request import urlopen -Make sure you have Kaggle credentials configured if you call Kaggle downloads. -""" -import os -import requests -def download_gutenberg(out_dir="data/raw/books", gutenberg_id: int = 1342, filename: str = "pride_and_prejudice.txt"): - """Download a Gutenberg text file by direct URL template (best-effort).""" - url = f"https://www.gutenberg.org/files/{gutenberg_id}/{gutenberg_id}-0.txt" - os.makedirs(out_dir, exist_ok=True) - out_path = os.path.join(out_dir, filename) - if os.path.exists(out_path): - print("Already downloaded:", out_path) - return out_path - try: - r = requests.get(url, timeout=30) - r.raise_for_status() - with open(out_path, "wb") as f: - f.write(r.content) - print("Downloaded:", out_path) - return out_path - except Exception as e: - print("Failed to download Gutenberg file:", e) - return None +DOWNLOAD_TIMEOUT = 60 + -# Kaggle helpers: optional, wrapped to avoid hard failure when Kaggle isn't configured. -def _safe_kaggle_download(dataset: str, path: str): +def kaggle_download(dataset: str, output_dir: str) -> None: + target = Path(output_dir) + target.mkdir(parents=True, exist_ok=True) try: - import kaggle - except Exception as e: - print("Kaggle package not available or not configured. Please install 'kaggle' and configure API token. Error:", e) - return False + run([ + "kaggle", + "datasets", + "download", + "-d", + dataset, + "-p", + str(target), + "--unzip", + ], check=True) + except CalledProcessError as error: + raise RuntimeError( + "Kaggle download failed. Verify that the Kaggle CLI is authenticated," + " you have accepted the dataset terms on kaggle.com, and your kaggle.json" + " credentials are located in %USERPROFILE%/.kaggle." 
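            # Note for non-Windows setups: the Kaggle CLI reads ~/.kaggle/kaggle.json by default
            # (or the directory pointed to by KAGGLE_CONFIG_DIR), so the %USERPROFILE% hint above
            # applies only on Windows.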
+ ) from error + + +def gutenberg_download(url: str, output_path: str) -> None: + target = Path(output_path) + target.parent.mkdir(parents=True, exist_ok=True) try: - os.makedirs(path, exist_ok=True) - kaggle.api.authenticate() - kaggle.api.dataset_download_files(dataset, path=path, unzip=True) - print(f"Downloaded Kaggle dataset {dataset} to {path}") - return True - except Exception as e: - print("Failed to download Kaggle dataset:", e) - return False - -def download_emotion_dataset(): - target_dir = "data/raw/emotion" - return _safe_kaggle_download('praveengovi/emotions-dataset-for-nlp', target_dir) - -def download_cnn_dailymail(): - target_dir = "data/raw/summarization" - return _safe_kaggle_download('gowrishankarp/newspaper-text-summarization-cnn-dailymail', target_dir) - -def download_ag_news(): - target_dir = "data/raw/topic" - return _safe_kaggle_download('amananandrai/ag-news-classification-dataset', target_dir) - -if __name__ == "__main__": - download_gutenberg() - download_emotion_dataset() - download_cnn_dailymail() - download_ag_news() \ No newline at end of file + with urlopen(url, timeout=DOWNLOAD_TIMEOUT) as response, target.open("wb") as handle: + chunk = response.read(8192) + while chunk: + handle.write(chunk) + chunk = response.read(8192) + except (URLError, socket.timeout, OSError) as error: + raise RuntimeError(f"Failed to download '{url}' to '{target}': {error}") from error diff --git a/src/data/preprocessing.py b/src/data/preprocessing.py index ac1e45f2aaf44cd8f3a2bf7f8ced088e1a925f6e..753bcee038985b4854f21413f23e5a319ebfe5cd 100644 --- a/src/data/preprocessing.py +++ b/src/data/preprocessing.py @@ -1,260 +1,130 @@ -"""Lightweight preprocessing utilities built around the in-repo transformer.""" - +"""Text preprocessing utilities built around Hugging Face tokenizers.""" from __future__ import annotations -from collections import Counter -from dataclasses import dataclass -import json -from pathlib import Path import re -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from dataclasses import dataclass, replace +from typing import Iterable, List, Sequence import torch +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS -from ..models.decoder import TransformerDecoder -from ..models.encoder import TransformerEncoder - -SPECIAL_TOKENS: Tuple[str, str, str, str] = ("", "", "", "") - - -def _normalize(text: str, lowercase: bool) -> str: - text = text.strip() - text = re.sub(r"\s+", " ", text) - if lowercase: - text = text.lower() - return text - +from .tokenization import Tokenizer, TokenizerConfig -def _basic_tokenize(text: str) -> List[str]: - return re.findall(r"\b\w+\b|[.,;:?!]", text) +class BasicTextCleaner(BaseEstimator, TransformerMixin): + """Minimal text cleaner following scikit-learn conventions.""" -class TransformerTokenizer: - """Minimal tokenizer that keeps vocabulary aligned with the custom transformer.""" - - def __init__( - self, - stoi: Dict[str, int], - itos: List[str], - specials: Sequence[str] = SPECIAL_TOKENS, - lowercase: bool = True, - ) -> None: - self.stoi = stoi - self.itos = itos - self.specials = tuple(specials) + def __init__(self, lowercase: bool = True, strip: bool = True) -> None: self.lowercase = lowercase - self.pad_id = self._lookup(self.specials[0]) - self.bos_id = self._lookup(self.specials[1]) - self.eos_id = self._lookup(self.specials[2]) - self.unk_id = self._lookup(self.specials[3]) - - @classmethod - def build( - cls, - texts: Iterable[str], - 
min_freq: int = 1, - lowercase: bool = True, - specials: Sequence[str] = SPECIAL_TOKENS, - ) -> "TransformerTokenizer": - counter: Counter[str] = Counter() - for text in texts: - normalized = _normalize(text, lowercase) - counter.update(_basic_tokenize(normalized)) - - ordered_specials = list(dict.fromkeys(specials)) - itos: List[str] = ordered_specials.copy() - for token, freq in counter.most_common(): - if freq < min_freq: - continue - if token in itos: - continue - itos.append(token) + self.strip = strip - stoi = {token: idx for idx, token in enumerate(itos)} - return cls(stoi=stoi, itos=itos, specials=ordered_specials, lowercase=lowercase) + def fit(self, texts: Iterable[str], y: Iterable[str] | None = None): + return self - @property - def vocab_size(self) -> int: - return len(self.itos) + def transform(self, texts: Iterable[str]) -> List[str]: + return [self._clean_text(text) for text in texts] - def tokenize(self, text: str) -> List[str]: - normalized = _normalize(text, self.lowercase) - return _basic_tokenize(normalized) + def _clean_text(self, text: str) -> str: + item = text.strip() if self.strip else text + if self.lowercase: + item = item.lower() + return " ".join(item.split()) - def encode( - self, - text: str, - add_special_tokens: bool = True, - max_length: Optional[int] = None, - ) -> List[int]: - tokens = self.tokenize(text) - pieces = [self.stoi.get(tok, self.unk_id) for tok in tokens] - if add_special_tokens: - pieces = [self.bos_id] + pieces + [self.eos_id] - - if max_length is not None and len(pieces) > max_length: - if add_special_tokens and max_length >= 2: - inner_max = max_length - 2 - trimmed = pieces[1:-1][:inner_max] - pieces = [self.bos_id] + trimmed + [self.eos_id] - else: - pieces = pieces[:max_length] - return pieces - def decode(self, ids: Sequence[int], skip_special_tokens: bool = True) -> str: - tokens: List[str] = [] - for idx in ids: - if idx < 0 or idx >= len(self.itos): - continue - token = self.itos[idx] - if skip_special_tokens and token in self.specials: - continue - tokens.append(token) - return " ".join(tokens).strip() - - def pad_batch( - self, - sequences: Sequence[Sequence[int]], - pad_to_length: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if not sequences: - raise ValueError("pad_batch requires at least one sequence") - if pad_to_length is None: - pad_to_length = max(len(seq) for seq in sequences) - padded: List[List[int]] = [] - mask: List[List[int]] = [] - for seq in sequences: - trimmed = list(seq[:pad_to_length]) - pad_len = pad_to_length - len(trimmed) - padded.append(trimmed + [self.pad_id] * pad_len) - mask.append([1] * len(trimmed) + [0] * pad_len) - return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool) - - def save(self, path: Path) -> None: - payload = { - "itos": self.itos, - "specials": list(self.specials), - "lowercase": self.lowercase, - } - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - - @classmethod - def load(cls, path: Path) -> "TransformerTokenizer": - data = json.loads(path.read_text(encoding="utf-8")) - itos = list(data["itos"]) - stoi = {token: idx for idx, token in enumerate(itos)} - specials = data.get("specials", list(SPECIAL_TOKENS)) - lowercase = bool(data.get("lowercase", True)) - return cls(stoi=stoi, itos=itos, specials=specials, lowercase=lowercase) - - def _lookup(self, token: str) -> int: - if token not in self.stoi: - raise ValueError(f"token '{token}' missing from 
vocabulary") - return self.stoi[token] - - -@dataclass +@dataclass(slots=True) class Batch: + """Bundle of tensors returned by the text preprocessor.""" + input_ids: torch.Tensor attention_mask: torch.Tensor lengths: List[int] class TextPreprocessor: - """Prepares text so it can flow directly into the custom transformer stack.""" + """Coordinate lightweight text cleaning and tokenization. + + When supplying an already-initialized tokenizer instance, its configuration is left + untouched. If a differing ``max_length`` is requested, a ``ValueError`` is raised to + avoid mutating shared tokenizer state. + """ def __init__( self, - max_length: int = 512, - tokenizer: Optional[TransformerTokenizer] = None, + tokenizer: Tokenizer | None = None, *, - min_freq: int = 1, + tokenizer_config: TokenizerConfig | None = None, + tokenizer_name: str = "facebook/bart-base", + max_length: int | None = None, lowercase: bool = True, + remove_stopwords: bool = False, + sklearn_transformer: TransformerMixin | None = None, ) -> None: - self.max_length = max_length - self.min_freq = min_freq + self.cleaner = BasicTextCleaner(lowercase=lowercase, strip=True) self.lowercase = lowercase - self.tokenizer = tokenizer + if remove_stopwords: + raise ValueError( + "Stop-word removal is not supported because it conflicts with subword tokenizers; " + "clean the text externally before initializing TextPreprocessor." + ) + self._stop_words = None + self._sklearn_transformer = sklearn_transformer + + if tokenizer is None: + cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name) + if max_length is not None: + cfg = replace(cfg, max_length=max_length) + self.tokenizer = Tokenizer(cfg) + else: + self.tokenizer = tokenizer + if max_length is not None and max_length != tokenizer.config.max_length: + raise ValueError( + "Provided tokenizer config.max_length does not match requested max_length; " + "initialise the tokenizer with desired settings before passing it in." + ) + + self.max_length = max_length or self.tokenizer.config.max_length def clean_text(self, text: str) -> str: - return _normalize(text, self.lowercase) - - def fit_tokenizer(self, texts: Iterable[str]) -> TransformerTokenizer: - cleaned = [self.clean_text(text) for text in texts] - self.tokenizer = TransformerTokenizer.build( - cleaned, - min_freq=self.min_freq, - lowercase=False, - ) - return self.tokenizer - - def encode(self, text: str, *, add_special_tokens: bool = True) -> List[int]: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not fitted") - cleaned = self.clean_text(text) - return self.tokenizer.encode(cleaned, add_special_tokens=add_special_tokens, max_length=self.max_length) + item = self.cleaner.transform([text])[0] + return self._normalize_tokens(item) + + def _normalize_tokens(self, text: str) -> str: + """Apply token-level normalization and optional stop-word filtering.""" + # Note: Pre-tokenization word-splitting is incompatible with subword tokenizers. + # Stop-word filtering should be done post-tokenization or not at all for transformers. 
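        # Callers that do want stop-word pruning can apply it to the raw strings before this
        # preprocessor is constructed, e.g. (illustrative sketch only):
        #   " ".join(w for w in text.split() if w.lower() not in ENGLISH_STOP_WORDS)
        # which is presumably why ENGLISH_STOP_WORDS remains imported above even though this
        # class never filters tokens itself.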
+ return text + + def _apply_sklearn_transform(self, texts: List[str]) -> List[str]: + if self._sklearn_transformer is None: + return texts + + transform = getattr(self._sklearn_transformer, "transform", None) + if transform is None: + raise AttributeError("Provided sklearn transformer must implement a 'transform' method") + transformed = transform(texts) + if isinstance(transformed, list): + return transformed # assume downstream type is already list[str] + if hasattr(transformed, "tolist"): + transformed = transformed.tolist() + + result = list(transformed) + if not all(isinstance(item, str) for item in result): + result = [str(item) for item in result] + return result + + def _prepare_texts(self, texts: Sequence[str]) -> List[str]: + cleaned = self.cleaner.transform(texts) + normalized = [self._normalize_tokens(text) for text in cleaned] + return self._apply_sklearn_transform(normalized) def batch_encode(self, texts: Sequence[str]) -> Batch: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not fitted") - sequences = [self.encode(text) for text in texts] - lengths = [len(seq) for seq in sequences] - input_ids, attention_mask = self.tokenizer.pad_batch(sequences, pad_to_length=self.max_length) + cleaned = self._prepare_texts(texts) + encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length) + input_ids: torch.Tensor = encoded["input_ids"] + attention_mask: torch.Tensor = encoded["attention_mask"].to(dtype=torch.bool) + lengths = attention_mask.sum(dim=1).tolist() return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths) - def build_encoder(self, **encoder_kwargs) -> TransformerEncoder: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not fitted") - return TransformerEncoder( - vocab_size=self.tokenizer.vocab_size, - max_len=self.max_length, - pad_token_id=self.tokenizer.pad_id, - **encoder_kwargs, - ) - - def build_decoder(self, **decoder_kwargs) -> TransformerDecoder: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not fitted") - return TransformerDecoder( - vocab_size=self.tokenizer.vocab_size, - max_len=self.max_length, - pad_token_id=self.tokenizer.pad_id, - **decoder_kwargs, - ) - - def save_tokenizer(self, path: Path) -> None: - if self.tokenizer is None: - raise RuntimeError("Tokenizer not fitted") - self.tokenizer.save(path) - - def load_tokenizer(self, path: Path) -> TransformerTokenizer: - self.tokenizer = TransformerTokenizer.load(path) - return self.tokenizer - - def chunk_text(self, text: str, *, chunk_size: int = 1000, overlap: int = 100) -> List[str]: - if chunk_size <= overlap: - raise ValueError("chunk_size must be larger than overlap") - words = self.clean_text(text).split() - chunks: List[str] = [] - start = 0 - while start < len(words): - end = min(start + chunk_size, len(words)) - chunks.append(" ".join(words[start:end])) - start += chunk_size - overlap - return chunks - - def save_book_chunks( - self, - input_path: Path, - out_dir: Path, - *, - chunk_size: int = 1000, - overlap: int = 100, - ) -> Path: - out_dir.mkdir(parents=True, exist_ok=True) - raw_text = input_path.read_text(encoding="utf-8", errors="ignore") - chunks = self.chunk_text(raw_text, chunk_size=chunk_size, overlap=overlap) - out_file = out_dir / f"{input_path.stem}.json" - out_file.write_text(json.dumps(chunks, ensure_ascii=False, indent=2), encoding="utf-8") - return out_file \ No newline at end of file + def __call__(self, texts: Sequence[str]) -> Batch: + return self.batch_encode(texts) diff --git 
a/src/data/tokenization.py b/src/data/tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8a529bea239d5e73dca2702678bc4f955106b6 --- /dev/null +++ b/src/data/tokenization.py @@ -0,0 +1,122 @@ +"""Tokenizer wrapper around HuggingFace models used across LexiMind.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, List, Sequence, cast + +import torch +from transformers import AutoTokenizer, PreTrainedTokenizerBase + + +@dataclass(slots=True) +class TokenizerConfig: + pretrained_model_name: str = "facebook/bart-base" + max_length: int = 512 + padding: str = "longest" + truncation: bool = True + lower: bool = False + + +class Tokenizer: + """Lightweight façade over a HuggingFace tokenizer.""" + + def __init__(self, config: TokenizerConfig | None = None) -> None: + cfg = config or TokenizerConfig() + self.config = cfg + self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(cfg.pretrained_model_name) + self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id) + self._bos_token_id = self._resolve_id( + self._tokenizer.bos_token_id if self._tokenizer.bos_token_id is not None else self._tokenizer.cls_token_id + ) + self._eos_token_id = self._resolve_id( + self._tokenizer.eos_token_id if self._tokenizer.eos_token_id is not None else self._tokenizer.sep_token_id + ) + + @property + def tokenizer(self) -> PreTrainedTokenizerBase: + return self._tokenizer + + @property + def pad_token_id(self) -> int: + return self._pad_token_id + + @property + def bos_token_id(self) -> int: + return self._bos_token_id + + @property + def eos_token_id(self) -> int: + return self._eos_token_id + + @property + def vocab_size(self) -> int: + vocab = getattr(self._tokenizer, "vocab_size", None) + if vocab is None: + raise RuntimeError("Tokenizer must expose vocab_size") + return int(vocab) + + @staticmethod + def _resolve_id(value) -> int: + if value is None: + raise ValueError("Tokenizer is missing required special token ids") + if isinstance(value, (list, tuple)): + value = value[0] + return int(value) + + def encode(self, text: str) -> List[int]: + content = text.lower() if self.config.lower else text + return self._tokenizer.encode( + content, + max_length=self.config.max_length, + truncation=self.config.truncation, + padding=self.config.padding, + ) + + def encode_batch(self, texts: Sequence[str]) -> List[List[int]]: + normalized = (text.lower() if self.config.lower else text for text in texts) + encoded = self._tokenizer.batch_encode_plus( + list(normalized), + max_length=self.config.max_length, + padding=self.config.padding, + truncation=self.config.truncation, + return_attention_mask=False, + return_tensors=None, + ) + return cast(List[List[int]], encoded["input_ids"]) + + def batch_encode(self, texts: Sequence[str], *, max_length: int | None = None) -> dict[str, torch.Tensor]: + normalized = [text.lower() if self.config.lower else text for text in texts] + encoded = self._tokenizer( + normalized, + padding=self.config.padding, + truncation=self.config.truncation, + max_length=max_length or self.config.max_length, + return_tensors="pt", + ) + input_ids = cast(torch.Tensor, encoded["input_ids"]) + attention_mask = cast(torch.Tensor, encoded["attention_mask"]) + if input_ids.dtype != torch.long: + input_ids = input_ids.to(dtype=torch.long) + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask.to(dtype=torch.bool) + return { + "input_ids": input_ids, + "attention_mask":
attention_mask, + } + + def decode(self, token_ids: Iterable[int]) -> str: + return self._tokenizer.decode(list(token_ids), skip_special_tokens=True) + + def decode_batch(self, sequences: Sequence[Sequence[int]]) -> List[str]: + prepared = [list(seq) for seq in sequences] + return self._tokenizer.batch_decode(prepared, skip_special_tokens=True) + + def prepare_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor: + """Shift decoder labels to create input ids prefixed by BOS.""" + + bos = self.bos_token_id + pad = self.pad_token_id + decoder_inputs = torch.full_like(labels, pad) + decoder_inputs[:, 0] = bos + decoder_inputs[:, 1:] = labels[:, :-1] + return decoder_inputs diff --git a/src/inference/__init__.py b/src/inference/__init__.py index a5359aeff70f14038cb26ca610568348de2c3117..3b97d7f5019861d5f86ef467a535ccf23519733c 100644 --- a/src/inference/__init__.py +++ b/src/inference/__init__.py @@ -1,7 +1,12 @@ -""" -Inference utilities for LexiMind. -""" +"""Inference tools for LexiMind.""" -from .baseline_summarizer import Summarizer, TransformerSummarizer +from .factory import create_inference_pipeline +from .pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction -__all__ = ["Summarizer", "TransformerSummarizer"] +__all__ = [ + "InferencePipeline", + "InferenceConfig", + "EmotionPrediction", + "TopicPrediction", + "create_inference_pipeline", +] diff --git a/src/inference/baseline_summarizer.py b/src/inference/baseline_summarizer.py deleted file mode 100644 index b32d8ed145850a8176b17aee5c34cb1e5b07111c..0000000000000000000000000000000000000000 --- a/src/inference/baseline_summarizer.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Thin wrapper around the custom transformer summarizer.""" - -from __future__ import annotations -from typing import Any, Dict, Optional, Tuple -import torch -from ..api.inference import load_models - - -class TransformerSummarizer: - def __init__(self, config: Optional[Dict[str, Any]] = None) -> None: - models = load_models(config or {}) - if not models.get("loaded"): - raise RuntimeError("load_models returned an unloaded model; check configuration") - self.model = models["mt"] - self.preprocessor = models["preprocessor"] - self.device = models["device"] - - def summarize( - self, - text: str, - compression: float = 0.25, - collect_attn: bool = False, - ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]: - batch = self.preprocessor.batch_encode([text]) - tokenizer = self.preprocessor.tokenizer - encoder = self.model.encoder - decoder = self.model.decoder - if tokenizer is None or encoder is None or decoder is None: - raise RuntimeError("Model components are missing; ensure encoder, decoder, and tokenizer are set") - input_ids = batch.input_ids.to(self.device) - memory = encoder(input_ids) - src_len = batch.lengths[0] - target_len = max(4, int(src_len * compression)) - generated = decoder.greedy_decode( - memory, - max_len=min(self.preprocessor.max_length, target_len), - start_token_id=tokenizer.bos_id, - end_token_id=tokenizer.eos_id, - ) - summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True) - return summary.strip(), None if not collect_attn else {} diff --git a/src/inference/factory.py b/src/inference/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..09cc45baaa256d25045676c8767524ab815a01ee --- /dev/null +++ b/src/inference/factory.py @@ -0,0 +1,75 @@ +"""Helpers to assemble an inference pipeline from saved artifacts.""" +from __future__ import annotations + +from pathlib import 
Path +from typing import Tuple + +import torch + +from ..data.tokenization import Tokenizer, TokenizerConfig +from ..models.factory import ModelConfig, build_multitask_model, load_model_config +from ..utils.io import load_state +from ..utils.labels import LabelMetadata, load_label_metadata +from .pipeline import InferenceConfig, InferencePipeline + + +def create_inference_pipeline( + checkpoint_path: str | Path, + labels_path: str | Path, + *, + tokenizer_config: TokenizerConfig | None = None, + tokenizer_dir: str | Path | None = None, + model_config_path: str | Path | None = None, + device: str | torch.device = "cpu", + summary_max_length: int | None = None, +) -> Tuple[InferencePipeline, LabelMetadata]: + """Build an :class:`InferencePipeline` from saved model and label metadata.""" + + checkpoint = Path(checkpoint_path) + if not checkpoint.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") + + labels = load_label_metadata(labels_path) + + resolved_tokenizer_config = tokenizer_config + if resolved_tokenizer_config is None: + default_dir = Path(__file__).resolve().parent.parent.parent / "artifacts" / "hf_tokenizer" + chosen_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir + local_tokenizer_dir = chosen_dir + if local_tokenizer_dir.exists(): + resolved_tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_tokenizer_dir)) + else: + raise ValueError( + "No tokenizer configuration provided and default tokenizer directory " + f"'{local_tokenizer_dir}' not found. Please provide tokenizer_config parameter or set tokenizer_dir." + ) + + tokenizer = Tokenizer(resolved_tokenizer_config) + model_config = load_model_config(model_config_path) + model = build_multitask_model( + tokenizer, + num_emotions=labels.emotion_size, + num_topics=labels.topic_size, + config=model_config, + ) + load_state(model, str(checkpoint)) + + if isinstance(device, torch.device): + device_str = str(device) + else: + device_str = device + + if summary_max_length is not None: + pipeline_config = InferenceConfig(summary_max_length=summary_max_length, device=device_str) + else: + pipeline_config = InferenceConfig(device=device_str) + + pipeline = InferencePipeline( + model=model, + tokenizer=tokenizer, + config=pipeline_config, + emotion_labels=labels.emotion, + topic_labels=labels.topic, + device=device, + ) + return pipeline, labels diff --git a/src/inference/generation.py b/src/inference/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..974003d2afc367cc791386066459246a7dfc8774 --- /dev/null +++ b/src/inference/generation.py @@ -0,0 +1,14 @@ +"""Generation helpers.""" + +import torch + + +def greedy_decode(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int) -> torch.Tensor: + """Run greedy decoding with ``model.generate`` and return generated token ids.""" + + return model.generate( + input_ids, + max_length=max_length, + do_sample=False, + num_beams=1, + ) diff --git a/src/inference/pipeline.py b/src/inference/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5031ed5e2309636565860abe369dc3f3b8693f --- /dev/null +++ b/src/inference/pipeline.py @@ -0,0 +1,166 @@ +"""Inference helpers for multitask LexiMind models.""" +from __future__ import annotations + +from dataclasses import dataclass, fields, replace +from typing import Iterable, List, Sequence + +import torch +import torch.nn.functional as F + +from ..data.preprocessing import Batch, TextPreprocessor +from ..data.tokenization import 
Tokenizer + + +@dataclass(slots=True) +class InferenceConfig: + """Configuration knobs for the inference pipeline.""" + + summary_max_length: int = 128 + emotion_threshold: float = 0.5 + device: str | None = None + + +@dataclass(slots=True) +class EmotionPrediction: + labels: List[str] + scores: List[float] + + +@dataclass(slots=True) +class TopicPrediction: + label: str + confidence: float + + +class InferencePipeline: + """Run summarization, emotion, and topic heads through a unified interface.""" + + def __init__( + self, + model: torch.nn.Module, + tokenizer: Tokenizer, + *, + preprocessor: TextPreprocessor | None = None, + emotion_labels: Sequence[str] | None = None, + topic_labels: Sequence[str] | None = None, + config: InferenceConfig | None = None, + device: torch.device | str | None = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.config = config or InferenceConfig() + chosen_device = device or self.config.device + if chosen_device is None: + first_param = next(model.parameters(), None) + chosen_device = first_param.device if first_param is not None else "cpu" + self.device = torch.device(chosen_device) + self.model.to(self.device) + self.model.eval() + + self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer) + self.emotion_labels = list(emotion_labels) if emotion_labels is not None else None + self.topic_labels = list(topic_labels) if topic_labels is not None else None + + def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]: + if not texts: + return [] + batch = self._batch_to_device(self.preprocessor.batch_encode(texts)) + src_ids = batch.input_ids + src_mask = batch.attention_mask + max_len = max_length or self.config.summary_max_length + + if not hasattr(self.model, "encoder") or not hasattr(self.model, "decoder"): + raise RuntimeError("Model must expose encoder and decoder attributes for summarization.") + + with torch.inference_mode(): + encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None + memory = self.model.encoder(src_ids, mask=encoder_mask) + generated = self.model.decoder.greedy_decode( + memory=memory, + max_len=max_len, + start_token_id=self.tokenizer.bos_token_id, + end_token_id=self.tokenizer.eos_token_id, + device=self.device, + ) + + return self.tokenizer.decode_batch(generated.tolist()) + + def predict_emotions( + self, + texts: Sequence[str], + *, + threshold: float | None = None, + ) -> List[EmotionPrediction]: + if not texts: + return [] + if self.emotion_labels is None or not self.emotion_labels: + raise RuntimeError("emotion_labels must be provided to decode emotion predictions") + + batch = self._batch_to_device(self.preprocessor.batch_encode(texts)) + model_inputs = self._batch_to_model_inputs(batch) + decision_threshold = threshold or self.config.emotion_threshold + + with torch.inference_mode(): + logits = self.model.forward("emotion", model_inputs) + probs = torch.sigmoid(logits) + + predictions: List[EmotionPrediction] = [] + for row in probs.cpu(): + pairs = [ + (label, score) + for label, score in zip(self.emotion_labels, row.tolist()) + if score >= decision_threshold + ] + labels = [label for label, _ in pairs] + scores = [score for _, score in pairs] + predictions.append(EmotionPrediction(labels=labels, scores=scores)) + return predictions + + def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]: + if not texts: + return [] + if self.topic_labels is None or not self.topic_labels: + raise 
RuntimeError("topic_labels must be provided to decode topic predictions") + + batch = self._batch_to_device(self.preprocessor.batch_encode(texts)) + model_inputs = self._batch_to_model_inputs(batch) + + with torch.inference_mode(): + logits = self.model.forward("topic", model_inputs) + probs = F.softmax(logits, dim=-1) + + results: List[TopicPrediction] = [] + for row in probs.cpu(): + scores = row.tolist() + best_index = int(row.argmax().item()) + results.append(TopicPrediction(label=self.topic_labels[best_index], confidence=scores[best_index])) + return results + + def batch_predict(self, texts: Iterable[str]) -> dict[str, object]: + text_list = list(texts) + if self.emotion_labels is None or not self.emotion_labels: + raise RuntimeError("emotion_labels must be provided for batch predictions") + if self.topic_labels is None or not self.topic_labels: + raise RuntimeError("topic_labels must be provided for batch predictions") + return { + "summaries": self.summarize(text_list), + "emotion": self.predict_emotions(text_list), + "topic": self.predict_topics(text_list), + } + + def _batch_to_device(self, batch: Batch) -> Batch: + tensor_updates: dict[str, torch.Tensor] = {} + for item in fields(batch): + value = getattr(batch, item.name) + if torch.is_tensor(value): + tensor_updates[item.name] = value.to(self.device) + if not tensor_updates: + return batch + return replace(batch, **tensor_updates) + + @staticmethod + def _batch_to_model_inputs(batch: Batch) -> dict[str, torch.Tensor]: + inputs: dict[str, torch.Tensor] = {"input_ids": batch.input_ids} + if batch.attention_mask is not None: + inputs["attention_mask"] = batch.attention_mask + return inputs diff --git a/src/inference/postprocessing.py b/src/inference/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..684941e5c21aa6cc09fda372140eaf2132060d9e --- /dev/null +++ b/src/inference/postprocessing.py @@ -0,0 +1,6 @@ +"""Output cleaning helpers.""" +from typing import List + + +def strip_whitespace(texts: List[str]) -> List[str]: + return [text.strip() for text in texts] diff --git a/src/models/factory.py b/src/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..25f18d9fecdc34cebe615cb48ad9a77c6736961c --- /dev/null +++ b/src/models/factory.py @@ -0,0 +1,105 @@ +"""Factory helpers to assemble multitask models for inference/training.""" +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from ..data.tokenization import Tokenizer +from ..utils.config import load_yaml +from .decoder import TransformerDecoder +from .encoder import TransformerEncoder +from .heads import ClassificationHead, LMHead +from .multitask import MultiTaskModel + + +@dataclass(slots=True) +class ModelConfig: + """Configuration describing the transformer architecture.""" + + d_model: int = 512 + num_encoder_layers: int = 6 + num_decoder_layers: int = 6 + num_attention_heads: int = 8 + ffn_dim: int = 2048 + dropout: float = 0.1 + + def __post_init__(self): + if self.d_model % self.num_attention_heads != 0: + raise ValueError( + f"d_model ({self.d_model}) must be divisible by num_attention_heads ({self.num_attention_heads})" + ) + if not 0 <= self.dropout <= 1: + raise ValueError(f"dropout must be in [0, 1], got {self.dropout}") + if self.d_model <= 0 or self.num_encoder_layers <= 0 or self.num_decoder_layers <= 0: + raise ValueError("Model dimensions must be positive") + if self.num_attention_heads <= 0 or self.ffn_dim <= 0: + 
raise ValueError("Model dimensions must be positive") + + +def load_model_config(path: Optional[str | Path]) -> ModelConfig: + """Load a model configuration from YAML with sane defaults.""" + + if path is None: + return ModelConfig() + + data = load_yaml(str(path)).data + return ModelConfig( + d_model=int(data.get("d_model", 512)), + num_encoder_layers=int(data.get("num_encoder_layers", 6)), + num_decoder_layers=int(data.get("num_decoder_layers", 6)), + num_attention_heads=int(data.get("num_attention_heads", 8)), + ffn_dim=int(data.get("ffn_dim", 2048)), + dropout=float(data.get("dropout", 0.1)), + ) + + +def build_multitask_model( + tokenizer: Tokenizer, + *, + num_emotions: int, + num_topics: int, + config: ModelConfig | None = None, +) -> MultiTaskModel: + """Construct the multitask transformer with heads for the three tasks.""" + + cfg = config or ModelConfig() + if not isinstance(num_emotions, int) or num_emotions <= 0: + raise ValueError("num_emotions must be a positive integer") + if not isinstance(num_topics, int) or num_topics <= 0: + raise ValueError("num_topics must be a positive integer") + encoder = TransformerEncoder( + vocab_size=tokenizer.vocab_size, + d_model=cfg.d_model, + num_layers=cfg.num_encoder_layers, + num_heads=cfg.num_attention_heads, + d_ff=cfg.ffn_dim, + dropout=cfg.dropout, + max_len=tokenizer.config.max_length, + pad_token_id=tokenizer.pad_token_id, + ) + decoder = TransformerDecoder( + vocab_size=tokenizer.vocab_size, + d_model=cfg.d_model, + num_layers=cfg.num_decoder_layers, + num_heads=cfg.num_attention_heads, + d_ff=cfg.ffn_dim, + dropout=cfg.dropout, + max_len=tokenizer.config.max_length, + pad_token_id=tokenizer.pad_token_id, + ) + + model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True) + model.add_head( + "summarization", + LMHead(d_model=cfg.d_model, vocab_size=tokenizer.vocab_size, tie_embedding=decoder.embedding), + ) + model.add_head( + "emotion", + ClassificationHead(d_model=cfg.d_model, num_labels=num_emotions, pooler="mean", dropout=cfg.dropout), + ) + model.add_head( + "topic", + ClassificationHead(d_model=cfg.d_model, num_labels=num_topics, pooler="mean", dropout=cfg.dropout), + ) + return model diff --git a/src/models/multitask.py b/src/models/multitask.py index 01e49c1c5d79a9f5092a94f60cf974c22258a6a5..d29b0c247a3c4e3694364892d14784eefd39bdc6 100644 --- a/src/models/multitask.py +++ b/src/models/multitask.py @@ -39,17 +39,28 @@ class MultiTaskModel(nn.Module): mt = MultiTaskModel(encoder=enc, decoder=dec) mt.add_head("summarize", LMHead(...)) logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids}) + + Args: + encoder: optional encoder backbone. + decoder: optional decoder backbone. + decoder_outputs_logits: set True when ``decoder.forward`` already returns vocabulary logits; + set False if the decoder produces hidden states that must be projected by the LM head. """ def __init__( self, encoder: Optional[TransformerEncoder] = None, decoder: Optional[TransformerDecoder] = None, + *, + decoder_outputs_logits: bool = True, ): super().__init__() self.encoder = encoder self.decoder = decoder self.heads: Dict[str, nn.Module] = {} + # When True, decoder.forward(...) is expected to return logits already projected to the vocabulary space. + # When False, decoder outputs hidden states that must be passed through the registered LM head. 
+ self.decoder_outputs_logits = decoder_outputs_logits def add_head(self, name: str, module: nn.Module) -> None: """Register a head under a task name.""" @@ -99,9 +110,15 @@ class MultiTaskModel(nn.Module): raise RuntimeError("Encoder is required for encoder-side heads") # accept either input_ids or embeddings if "input_ids" in inputs: - enc_out = self.encoder(inputs["input_ids"]) + encoder_mask = None + if "attention_mask" in inputs: + encoder_mask = self._expand_attention_mask(inputs["attention_mask"], inputs["input_ids"].device) + enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask) elif "embeddings" in inputs: - enc_out = self.encoder(inputs["embeddings"]) + encoder_mask = inputs.get("attention_mask") + if encoder_mask is not None: + encoder_mask = self._expand_attention_mask(encoder_mask, inputs["embeddings"].device) + enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask) else: raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks") logits = head(enc_out) @@ -120,10 +137,20 @@ class MultiTaskModel(nn.Module): raise RuntimeError("Both encoder and decoder are required for LM-style heads") # Build encoder memory + src_mask = inputs.get("src_mask") + if src_mask is None: + src_mask = inputs.get("attention_mask") + encoder_mask = None + reference_tensor = inputs.get("src_ids") + if reference_tensor is None: + reference_tensor = inputs.get("src_embeddings") + if src_mask is not None and reference_tensor is not None: + encoder_mask = self._expand_attention_mask(src_mask, reference_tensor.device) + if "src_ids" in inputs: - memory = self.encoder(inputs["src_ids"]) + memory = self.encoder(inputs["src_ids"], mask=encoder_mask) elif "src_embeddings" in inputs: - memory = self.encoder(inputs["src_embeddings"]) + memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask) else: raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks") @@ -137,12 +164,13 @@ class MultiTaskModel(nn.Module): # Here we don't attempt to generate when labels not provided. raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward") - # Run decoder. Decoder returns logits shaped (B, T, vocab) in this codebase. decoder_out = self.decoder(decoder_inputs, memory) - # If decoder already returned logits matching the head vocab size, use them directly. - # Otherwise, assume decoder returned hidden states and let the head project them. - if isinstance(decoder_out, torch.Tensor) and decoder_out.shape[-1] == head.vocab_size: + if self.decoder_outputs_logits: + if not isinstance(decoder_out, torch.Tensor): + raise TypeError( + "Decoder is configured to return logits, but forward returned a non-tensor value." 
+ ) logits = decoder_out else: logits = head(decoder_out) @@ -195,4 +223,15 @@ class MultiTaskModel(nn.Module): return F.cross_entropy(logits, labels.long()) # If we can't determine, raise - raise RuntimeError("Cannot compute loss for unknown head type") \ No newline at end of file + raise RuntimeError("Cannot compute loss for unknown head type") + + @staticmethod + def _expand_attention_mask(mask: torch.Tensor, device: torch.device) -> torch.Tensor: + if mask is None: + return None # type: ignore[return-value] + bool_mask = mask.to(device=device, dtype=torch.bool) + if bool_mask.dim() == 2: + return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2) + if bool_mask.dim() in (3, 4): + return bool_mask + raise ValueError("Attention mask must be 2D, 3D, or 4D tensor") \ No newline at end of file diff --git a/src/training/__init__.py b/src/training/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..8dcb2957bf90acdccd778715615989a4489894d8 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -0,0 +1 @@ +"""Training utilities for LexiMind.""" diff --git a/src/training/callbacks.py b/src/training/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..ee260be490adb2c647126c9a4785fb66aa6802e3 --- /dev/null +++ b/src/training/callbacks.py @@ -0,0 +1,37 @@ +"""Callback hooks for training.""" + +from pathlib import Path +from typing import Any, Dict, Optional + +import torch + + +def save_checkpoint( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + epoch: int, + output_path: str, + *, + metrics: Optional[Dict[str, Any]] = None, +) -> None: + """Persist model and optimizer state for resuming training.""" + + checkpoint = { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "epoch": int(epoch), + } + if metrics: + checkpoint["metrics"] = metrics + + target = Path(output_path) + target.parent.mkdir(parents=True, exist_ok=True) + temp_path = target.parent / f"{target.name}.tmp" + try: + torch.save(checkpoint, temp_path) + temp_path.replace(target) + except Exception: + raise + finally: + if temp_path.exists(): + temp_path.unlink(missing_ok=True) diff --git a/src/training/losses.py b/src/training/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e6fa49333d4951a34782cc0eef14cad978e67573 --- /dev/null +++ b/src/training/losses.py @@ -0,0 +1,13 @@ +"""Loss helpers.""" +import torch + + +def multitask_loss(losses: dict[str, torch.Tensor]) -> torch.Tensor: + iterator = iter(losses.values()) + try: + total = next(iterator).clone() + except StopIteration: + raise ValueError("losses is empty") + for value in iterator: + total = total + value + return total / len(losses) diff --git a/src/training/metrics.py b/src/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..3c38895b9e896e7afb5e4054bf7725b2eb76f27a --- /dev/null +++ b/src/training/metrics.py @@ -0,0 +1,36 @@ +"""Metric helpers used during training and evaluation.""" +from __future__ import annotations + +from typing import Sequence + +import torch + + +def accuracy(predictions: Sequence[int], targets: Sequence[int]) -> float: + matches = sum(int(pred == target) for pred, target in zip(predictions, targets)) + return matches / max(1, len(predictions)) + + +def multilabel_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float: + preds = predictions.float() + gold = targets.float() + true_positive = (preds * gold).sum(dim=1) + precision = true_positive / 
(preds.sum(dim=1).clamp(min=1.0)) + recall = true_positive / (gold.sum(dim=1).clamp(min=1.0)) + f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8) + return float(f1.mean().item()) + + +def rouge_like(predictions: Sequence[str], references: Sequence[str]) -> float: + if not predictions or not references: + return 0.0 + scores = [] + for pred, ref in zip(predictions, references): + pred_tokens = pred.split() + ref_tokens = ref.split() + if not ref_tokens: + scores.append(0.0) + continue + overlap = len(set(pred_tokens) & set(ref_tokens)) + scores.append(overlap / len(ref_tokens)) + return sum(scores) / len(scores) diff --git a/src/training/trainer.py b/src/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..52eead03a97469d9ec78918901b704aeb133ed2f --- /dev/null +++ b/src/training/trainer.py @@ -0,0 +1,327 @@ +"""Multi-task trainer coordinating summarization, emotion, and topic heads.""" +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Iterator, List +import time +import shutil +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from ..data.tokenization import Tokenizer +from .metrics import accuracy, multilabel_f1, rouge_like + + +@dataclass(slots=True) +class TrainerConfig: + max_epochs: int = 1 + gradient_clip_norm: float = 1.0 + logging_interval: int = 50 + task_weights: Dict[str, float] | None = None + + +class Trainer: + """Coordinates multi-task optimisation across task-specific dataloaders.""" + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + config: TrainerConfig, + device: torch.device, + tokenizer: Tokenizer, + ) -> None: + self.model = model.to(device) + self.optimizer = optimizer + self.config = config + self.device = device + self.tokenizer = tokenizer + self.emotion_loss = torch.nn.BCEWithLogitsLoss() + self.topic_loss = torch.nn.CrossEntropyLoss() + self._progress_last_len = 0 + + def fit( + self, + train_loaders: Dict[str, DataLoader], + val_loaders: Dict[str, DataLoader] | None = None, + ) -> Dict[str, Dict[str, float]]: + history: Dict[str, Dict[str, float]] = {} + total_epochs = max(1, self.config.max_epochs) + start_time = time.perf_counter() + for epoch in range(1, total_epochs + 1): + epoch_start = time.perf_counter() + train_metrics = self._run_epoch( + train_loaders, + train=True, + epoch=epoch, + total_epochs=total_epochs, + epoch_start=epoch_start, + global_start=start_time, + ) + history[f"train_epoch_{epoch}"] = train_metrics + if val_loaders: + val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch) + history[f"val_epoch_{epoch}"] = val_metrics + epoch_duration = time.perf_counter() - epoch_start + total_elapsed = time.perf_counter() - start_time + self._print_epoch_progress(epoch, total_epochs, epoch_duration, total_elapsed) + return history + + def _run_epoch( + self, + loaders: Dict[str, DataLoader], + *, + train: bool, + epoch: int, + total_epochs: int | None = None, + epoch_start: float | None = None, + global_start: float | None = None, + ) -> Dict[str, float]: + phase = "train" if train else "eval" + self.model.train(train) + metrics_accumulator: Dict[str, list[float]] = defaultdict(list) + iterator_map: Dict[str, Iterator[Dict[str, torch.Tensor]]] = { + task: iter(loader) for task, loader in loaders.items() + } + max_batches = max(len(loader) for loader in loaders.values()) + progress_enabled = ( + train + and max_batches > 0 + 
and total_epochs is not None + and epoch_start is not None + and global_start is not None + ) + + def emit_progress(step: int, final: bool = False) -> None: + if not progress_enabled: + return + total_epochs_value = total_epochs + epoch_start_value = epoch_start + global_start_value = global_start + assert total_epochs_value is not None + assert epoch_start_value is not None + assert global_start_value is not None + self._update_epoch_progress( + epoch=epoch, + total_epochs=total_epochs_value, + step=step, + total_steps=max_batches, + epoch_start=epoch_start_value, + global_start=global_start_value, + final=final, + ) + + emit_progress(0) + + context = torch.enable_grad() if train else torch.no_grad() + with context: + for step in range(max_batches): + backward_performed = False + for task, loader in loaders.items(): + batch = self._next_batch(iterator_map, loader, task) + if batch is None: + continue + loss, task_metrics = self._forward_task(task, batch, train) + weight = self._task_weight(task) + metrics_accumulator[f"{task}_loss"].append(loss.item()) + for metric_name, metric_value in task_metrics.items(): + metrics_accumulator[f"{task}_{metric_name}"].append(metric_value) + if train: + scaled_loss = loss * weight + scaled_loss.backward() + backward_performed = True + if train and backward_performed: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm) + self.optimizer.step() + self.optimizer.zero_grad() + if train and self.config.logging_interval and (step + 1) % self.config.logging_interval == 0: + if torch.cuda.is_available() and self.device.type == "cuda": + torch.cuda.empty_cache() + emit_progress(step + 1) + emit_progress(max_batches, final=True) + + averaged = {name: sum(values) / len(values) for name, values in metrics_accumulator.items() if values} + averaged["epoch"] = float(epoch) + metric_str = ", ".join( + f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch" + ) + print(f"[{phase}] epoch {epoch}: {metric_str}") + return averaged + + def _next_batch( + self, + iterator_map: Dict[str, Iterator[Dict[str, torch.Tensor]]], + loader: DataLoader, + task: str, + ) -> Dict[str, torch.Tensor] | None: + try: + batch = next(iterator_map[task]) + except StopIteration: + iterator_map[task] = iter(loader) + try: + batch = next(iterator_map[task]) + except StopIteration: + return None + return {key: value.to(self.device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()} + + def _forward_task(self, task: str, batch: Dict[str, torch.Tensor], train: bool) -> tuple[torch.Tensor, Dict[str, float]]: + if task == "summarization": + summarization_inputs = { + "src_ids": batch["src_ids"], + "tgt_ids": batch["tgt_ids"], + } + if "src_mask" in batch: + summarization_inputs["src_mask"] = batch["src_mask"] + logits = self.model.forward("summarization", summarization_inputs) + vocab_size = logits.size(-1) + loss = F.cross_entropy( + logits.view(-1, vocab_size), + batch["labels"].view(-1), + ignore_index=-100, + ) + summaries = self._decode_predictions(logits) + references = self._decode_labels(batch["labels"]) + rouge = rouge_like(summaries, references) + return loss, {"rouge_like": rouge} + + if task == "emotion": + emotion_inputs = {"input_ids": batch["input_ids"]} + if "attention_mask" in batch: + emotion_inputs["attention_mask"] = batch["attention_mask"] + logits = self.model.forward("emotion", emotion_inputs) + loss = self.emotion_loss(logits, batch["labels"].float()) + probs = torch.sigmoid(logits) + preds = (probs > 
0.5).int() + labels = batch["labels"].int() + f1 = multilabel_f1(preds, labels) + return loss, {"f1": f1} + + if task == "topic": + topic_inputs = {"input_ids": batch["input_ids"]} + if "attention_mask" in batch: + topic_inputs["attention_mask"] = batch["attention_mask"] + logits = self.model.forward("topic", topic_inputs) + loss = self.topic_loss(logits, batch["labels"]) + preds = logits.argmax(dim=-1) + acc = accuracy(preds.tolist(), batch["labels"].tolist()) + return loss, {"accuracy": acc} + + raise ValueError(f"Unknown task '{task}'") + + def _task_weight(self, task: str) -> float: + if not self.config.task_weights: + return 1.0 + return self.config.task_weights.get(task, 1.0) + + def _decode_predictions(self, logits: torch.Tensor) -> List[str]: + generated = logits.argmax(dim=-1) + return self.tokenizer.decode_batch(generated.tolist()) + + def _decode_labels(self, labels: torch.Tensor) -> List[str]: + valid = labels.clone() + valid[valid == -100] = self.tokenizer.pad_token_id + return self.tokenizer.decode_batch(valid.tolist()) + + def _print_epoch_progress( + self, + epoch: int, + total_epochs: int, + epoch_duration: float, + total_elapsed: float, + ) -> None: + progress = epoch / total_epochs + percent = progress * 100 + remaining_epochs = total_epochs - epoch + eta = (total_elapsed / epoch) * remaining_epochs if epoch > 0 else 0.0 + bar = self._format_progress_bar(progress) + message = ( + f"[progress] {bar} {percent:5.1f}% | epoch {epoch}/{total_epochs} " + f"| last {epoch_duration:6.2f}s | total {total_elapsed:6.2f}s | ETA {eta:6.2f}s" + ) + print(message, flush=True) + + @staticmethod + def _format_progress_bar(progress: float, width: int = 20) -> str: + clamped = max(0.0, min(1.0, progress)) + filled = int(round(clamped * width)) + bar = "#" * filled + "-" * (width - filled) + return f"[{bar}]" + + def _update_epoch_progress( + self, + *, + epoch: int, + total_epochs: int, + step: int, + total_steps: int, + epoch_start: float, + global_start: float, + final: bool = False, + ) -> None: + if total_steps <= 0 or total_epochs <= 0: + return + bounded_step = max(0, min(step, total_steps)) + step_fraction = bounded_step / total_steps + epochs_completed = (epoch - 1) + step_fraction + overall_progress = epochs_completed / total_epochs + percent = overall_progress * 100.0 + epoch_elapsed = time.perf_counter() - epoch_start + total_elapsed = time.perf_counter() - global_start + if epochs_completed > 0: + remaining_epochs = max(total_epochs - epochs_completed, 0.0) + eta = (total_elapsed / epochs_completed) * remaining_epochs if total_elapsed > 0 else 0.0 + else: + eta = 0.0 + bar = self._format_progress_bar(overall_progress, width=self._progress_bar_width()) + message = ( + f"[progress] {bar} {percent:5.1f}% " + f"e {epoch}/{total_epochs} " + f"s {bounded_step}/{total_steps} " + f"ep {self._format_duration(epoch_elapsed)} " + f"tot {self._format_duration(total_elapsed)} " + f"eta {self._format_duration(eta)}" + ) + display = self._truncate_to_terminal(message) + padding = " " * max(self._progress_last_len - len(display), 0) + print(f"\r{display}{padding}", end="", flush=True) + if final: + print() + self._progress_last_len = 0 + else: + self._progress_last_len = len(display) + + def _truncate_to_terminal(self, text: str) -> str: + columns = self._terminal_width() + if columns <= 0: + return text + if len(text) >= columns: + return text[: max(columns - 1, 1)] + return text + + def _progress_bar_width(self) -> int: + columns = self._terminal_width() + reserved = 60 + if columns <= 
reserved: + return 10 + return max(10, min(30, columns - reserved)) + + @staticmethod + def _terminal_width() -> int: + try: + return shutil.get_terminal_size(fallback=(120, 20)).columns + except OSError: + return 120 + + @staticmethod + def _format_duration(seconds: float) -> str: + seconds = max(0.0, seconds) + if seconds >= 3600: + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + return f"{hours}h{minutes:02}m" + if seconds >= 60: + minutes = int(seconds // 60) + secs = int(seconds % 60) + return f"{minutes}m{secs:02}s" + return f"{seconds:4.1f}s" diff --git a/src/training/utils.py b/src/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71f50d563da89ca982fc08772588519a25630ad8 --- /dev/null +++ b/src/training/utils.py @@ -0,0 +1,55 @@ +"""Small training helpers.""" + +from __future__ import annotations + +import random +import threading +from typing import Optional + +import numpy as np +import torch + + +_seed_sequence: Optional[np.random.SeedSequence] = None +_seed_lock = threading.Lock() +_spawn_counter = 0 +_thread_local = threading.local() + + +def set_seed(seed: int) -> np.random.Generator: + """Seed stdlib/Torch RNGs and initialise this thread's NumPy generator.""" + + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + base_seq = np.random.SeedSequence(seed) + child = base_seq.spawn(1)[0] + rng = np.random.default_rng(child) + + global _seed_sequence, _spawn_counter + with _seed_lock: + _seed_sequence = base_seq + _spawn_counter = 1 + _thread_local.rng = rng + return rng + + +def numpy_generator() -> np.random.Generator: + """Return the calling thread's NumPy generator, creating one if needed.""" + + rng = getattr(_thread_local, "rng", None) + if rng is not None: + return rng + + global _seed_sequence, _spawn_counter + with _seed_lock: + if _seed_sequence is None: + _seed_sequence = np.random.SeedSequence() + _spawn_counter = 0 + child_seq = _seed_sequence.spawn(1)[0] + _spawn_counter += 1 + + rng = np.random.default_rng(child_seq) + _thread_local.rng = rng + return rng diff --git a/src/ui/streamlit_app.py b/src/ui/streamlit_app.py deleted file mode 100644 index ad8b9255f5d304628e05d1b0e303820a7d4e17bd..0000000000000000000000000000000000000000 --- a/src/ui/streamlit_app.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Streamlit prototype for LexiMind (summarization, emotion, topic). 
-Run from repo root: streamlit run streamlit_app.py -""" -import streamlit as st -import numpy as np -import pandas as pd -import plotly.express as px -import plotly.figure_factory as ff - -# Stable absolute import; ensure repo root is on PYTHONPATH (running from repo root is standard) -try: - from ..api.inference import load_models, summarize_text, classify_emotion, topic_for_text -except Exception as e: - st.error(f"Failed to import inference helpers: {e}") - raise - -st.set_page_config(page_title="LexiMind demo", layout="wide") - -MODEL_CONFIG = { - "checkpoint_path": "checkpoints/best.pt", # change to your trained checkpoint - "tokenizer_path": "artifacts/tokenizer.json", # JSON produced by TextPreprocessor.save_tokenizer - "device": "cpu", -} -try: - models = load_models(MODEL_CONFIG) -except Exception as exc: - st.error(f"Failed to load models: {exc}") - st.stop() - -st.sidebar.title("LexiMind") -task = st.sidebar.selectbox("Task", ["Summarize", "Emotion", "Topic", "Search demo"]) -compression = st.sidebar.slider("Compression (summary length)", 0.1, 1.0, 0.25) -show_attn = st.sidebar.checkbox("Show attention heatmap (collect_attn)", value=False) - -st.sidebar.markdown("Demo controls") -sample_choice = st.sidebar.selectbox("Use sample text", ["None", "Gutenberg sample", "News sample"]) - -SAMPLES = { - "Gutenberg sample": ( - "It was the best of times, it was the worst of times, it was the age of wisdom, " - "it was the age of foolishness..." - ), - "News sample": ( - "Markets rallied today as tech stocks posted gains amid broad optimism over earnings..." - ), -} - -st.title("LexiMind β€” Summarization, Emotion, Topic (Prototype)") - -if sample_choice != "None": - input_text = st.text_area("Input text", value=SAMPLES[sample_choice], height=280) -else: - input_text = st.text_area("Input text", value="", height=280) - -col1, col2 = st.columns([2, 1]) - -with col1: - st.subheader("Output") - if st.button("Run"): - if not input_text.strip(): - st.warning("Enter some text or select a sample to run the model.") - else: - if task == "Summarize": - summary, attn_data = summarize_text(input_text, compression=compression, collect_attn=show_attn, models=models) - st.markdown("**Summary**") - st.write(summary) - if show_attn and attn_data is not None: - st.markdown("**Attention heatmap (averaged heads)**") - src_tokens = attn_data.get("src_tokens", None) - tgt_tokens = attn_data.get("tgt_tokens", None) - weights = attn_data.get("weights", None) - if weights is not None: - arr = np.array(weights) - if arr.ndim == 4: - arr = arr.mean(axis=(0,1)) - elif arr.ndim == 3: - arr = arr.mean(axis=0) - fig = ff.create_annotated_heatmap( - z=arr.tolist(), - x=src_tokens if src_tokens else [f"tok{i}" for i in range(arr.shape[1])], - y=tgt_tokens if tgt_tokens else [f"tok{i}" for i in range(arr.shape[0])], - colorscale="Viridis", - ) - st.plotly_chart(fig, use_container_width=True) - else: - st.info("Attention data not available from the model.") - elif task == "Emotion": - probs, labels = classify_emotion(input_text, models=models) - st.markdown("**Emotion predictions (multi-label probabilities)**") - df = pd.DataFrame({"emotion": labels, "prob": probs}) - fig = px.bar(df, x="emotion", y="prob", color="prob", range_y=[0,1]) - st.plotly_chart(fig, use_container_width=True) - elif task == "Topic": - topic_id, topic_terms = topic_for_text(input_text, models=models) - st.markdown("**Topic cluster**") - st.write(f"Cluster ID: {topic_id}") - st.write("Top terms:", ", ".join(topic_terms)) - elif task == "Search 
demo": - st.info("Search demo will be available when ingestion is run (see scripts).") - -with col2: - st.subheader("Model & Info") - st.markdown(f"*Model loaded:* {'yes' if models.get('loaded', False) else 'no'}") - st.markdown(f"*Device:* {models.get('device', MODEL_CONFIG['device'])}") - st.markdown("**Notes**") - st.markdown("- Attention visualization depends on model support to return attention.") - st.markdown("- For long inputs the UI truncates tokens for heatmap clarity.") \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a1337311995fcdeefcc72e07223506e7baa65f1f 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +"""General utilities for LexiMind.""" diff --git a/src/utils/config.py b/src/utils/config.py index c9996a81637086de49ca8eff3216cd63ac660307..8c553b247353f2b438c12e9d43e93d728e360d6f 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -1,47 +1,19 @@ +"""YAML config loader.""" from dataclasses import dataclass from pathlib import Path -from typing import Optional, Dict, Any -import yaml -from omegaconf import OmegaConf +from typing import Any, Dict -@dataclass -class ModelConfig: - vocab_size: int - d_model: int - num_encoder_layers: int - num_decoder_layers: int - num_heads: int - d_ff: int - dropout: float - max_seq_length: int +import yaml -@dataclass -class TrainingConfig: - batch_size: int - learning_rate: float - num_epochs: int - warmup_steps: int - max_grad_norm: float - mixed_precision: bool -@dataclass +@dataclass(slots=True) class Config: - model: ModelConfig - training: TrainingConfig data: Dict[str, Any] - tasks: Dict[str, Any] -def load_config(config_path: str) -> Config: - """Load config from YAML and convert to structured dataclass.""" - cfg = OmegaConf.load(config_path) - - # Convert to dataclass for type safety - model_cfg = ModelConfig(**cfg.model) - training_cfg = TrainingConfig(**cfg.training) - - return Config( - model=model_cfg, - training=training_cfg, - data=dict(cfg.data), - tasks=dict(cfg.tasks) - ) + +def load_yaml(path: str) -> Config: + with Path(path).open("r", encoding="utf-8") as handle: + content = yaml.safe_load(handle) + if not isinstance(content, dict): + raise ValueError(f"YAML configuration '{path}' must contain a mapping at the root") + return Config(data=content) diff --git a/src/utils/io.py b/src/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..4aca5373cc2152274926dc85290b5dcb80e095ce --- /dev/null +++ b/src/utils/io.py @@ -0,0 +1,15 @@ +"""Checkpoint IO helpers.""" +from pathlib import Path + +import torch + + +def save_state(model: torch.nn.Module, path: str) -> None: + destination = Path(path) + destination.parent.mkdir(parents=True, exist_ok=True) + torch.save(model.state_dict(), destination) + + +def load_state(model: torch.nn.Module, path: str) -> None: + state = torch.load(path, map_location="cpu", weights_only=True) + model.load_state_dict(state) \ No newline at end of file diff --git a/src/utils/labels.py b/src/utils/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e532eee22d78fa60d0a1b92fb42eddb725d865 --- /dev/null +++ b/src/utils/labels.py @@ -0,0 +1,56 @@ +"""Label metadata helpers for multitask inference.""" +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import List + + +@dataclass(slots=True) +class LabelMetadata: + """Container for label vocabularies persisted 
after training.""" + + emotion: List[str] + topic: List[str] + + @property + def emotion_size(self) -> int: + return len(self.emotion) + + @property + def topic_size(self) -> int: + return len(self.topic) + + +def load_label_metadata(path: str | Path) -> LabelMetadata: + """Load label vocabularies from a JSON file.""" + + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Label metadata file not found: {path}") + + with path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + + emotion = payload.get("emotion") if "emotion" in payload else payload.get("emotions") + topic = payload.get("topic") if "topic" in payload else payload.get("topics") + if not isinstance(emotion, list) or not all(isinstance(item, str) for item in emotion): + raise ValueError("Label metadata missing 'emotion' list of strings") + if not isinstance(topic, list) or not all(isinstance(item, str) for item in topic): + raise ValueError("Label metadata missing 'topic' list of strings") + + return LabelMetadata(emotion=emotion, topic=topic) + + +def save_label_metadata(metadata: LabelMetadata, path: str | Path) -> None: + """Persist label vocabularies to JSON.""" + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "emotion": metadata.emotion, + "topic": metadata.topic, + } + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, ensure_ascii=False, indent=2) diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..b19917125a3d4ad0cb4cf46c73fd718c3d31bb35 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,12 @@ +"""Logging setup.""" +import logging + + +def configure_logging(level: int = logging.INFO) -> None: + """Configure root logging. 
Call once during application setup.""" + + logging.basicConfig(level=level) + + +def get_logger(name: str) -> logging.Logger: + return logging.getLogger(name) diff --git a/src/utils/random.py b/src/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..0817d945d41e645306a6c5f2d50e1760c9258d23 --- /dev/null +++ b/src/utils/random.py @@ -0,0 +1,9 @@ +"""Randomness helpers.""" +import random + +import numpy as np + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) diff --git a/src/visualization/__init__.py b/src/visualization/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..76f41afd754edefca783989027bb720959e72bf9 100644 --- a/src/visualization/__init__.py +++ b/src/visualization/__init__.py @@ -0,0 +1 @@ +"""Visualization helpers for LexiMind.""" diff --git a/src/visualization/attention.py b/src/visualization/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4e17bf6eb470e3718361e8abf6cbce2c667c6bcb --- /dev/null +++ b/src/visualization/attention.py @@ -0,0 +1,28 @@ +"""Attention plotting utilities.""" +from typing import Sequence + +import matplotlib.pyplot as plt +import numpy as np + + +def plot_attention(matrix: np.ndarray, tokens: Sequence[str]) -> None: + if matrix.ndim != 2: + raise ValueError("Attention matrix must be 2-dimensional") + token_count = len(tokens) + if token_count == 0: + raise ValueError("tokens must contain at least one item") + if matrix.shape != (token_count, token_count): + raise ValueError( + f"Attention matrix shape {matrix.shape} must match (len(tokens), len(tokens)) = ({token_count}, {token_count})" + ) + + fig, ax = plt.subplots() + heatmap = ax.imshow(matrix, cmap="viridis") + ax.set_xticks(range(token_count)) + ax.set_xticklabels(tokens, rotation=90) + ax.set_yticks(range(token_count)) + ax.set_yticklabels(tokens) + cbar = fig.colorbar(heatmap, ax=ax) + cbar.set_label("Attention Weight") + fig.tight_layout() + plt.show() diff --git a/src/visualization/embeddings.py b/src/visualization/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac0706d4198ed992ae8464437caae1655efa306 --- /dev/null +++ b/src/visualization/embeddings.py @@ -0,0 +1,32 @@ +"""Embedding visualization helpers.""" + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import numpy as np +from sklearn.manifold import TSNE + + +def plot_tsne(embeddings: np.ndarray, labels: list[str]) -> None: + if embeddings.size == 0 or embeddings.ndim != 2: + raise ValueError("embeddings must be a non-empty 2D array") + if not labels: + raise ValueError("labels must be a non-empty list") + if embeddings.shape[0] != len(labels): + raise ValueError("number of samples in embeddings must equal length of labels") + if embeddings.shape[1] < 2: + raise ValueError("embeddings must have at least 2 features for t-SNE visualization") + + reducer = TSNE(n_components=2, init="pca", learning_rate="auto") + projection = reducer.fit_transform(embeddings) + + df = pd.DataFrame({ + "x": projection[:, 0], + "y": projection[:, 1], + "label": labels, + }) + plt.figure() + sns.scatterplot(data=df, x="x", y="y", hue="label", palette="tab10", s=50) + plt.legend(title="Labels", loc="best") + plt.tight_layout() + plt.show() diff --git a/src/visualization/metrics.py b/src/visualization/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5ad240648d74977da8e248a7dd5765aecc8a45 --- /dev/null +++ b/src/visualization/metrics.py @@ -0,0 +1,27 @@ 
+"""Metric plotting helpers.""" +import matplotlib.pyplot as plt + + +def plot_curve( + values: list[float], + title: str, + *, + save_path: str | None = None, + show: bool = True, +) -> None: + fig, ax = plt.subplots() + ax.plot(values) + ax.set_title(title) + ax.set_xlabel("Step") + ax.set_ylabel("Value") + fig.tight_layout() + + if save_path is not None: + fig.savefig(save_path) + plt.close(fig) + return + + if show: + plt.show() + else: + plt.close(fig) diff --git a/tests/test_api/test_routes.py b/tests/test_api/test_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..a1fa59416064324ff578a858691294a06538873c --- /dev/null +++ b/tests/test_api/test_routes.py @@ -0,0 +1,34 @@ +"""API integration tests for the inference endpoint.""" +from __future__ import annotations + +from fastapi.testclient import TestClient + +from src.api.app import create_app +from src.api.dependencies import get_pipeline +from src.inference.pipeline import EmotionPrediction, TopicPrediction + + +class StubPipeline: + def batch_predict(self, texts): # pragma: no cover - simple stub + return { + "summaries": [f"summary:{text}" for text in texts], + "emotion": [EmotionPrediction(labels=["joy"], scores=[0.9]) for _ in texts], + "topic": [TopicPrediction(label="news", confidence=0.8) for _ in texts], + } + + +def test_summarize_route_returns_pipeline_outputs() -> None: + app = create_app() + app.dependency_overrides[get_pipeline] = lambda: StubPipeline() + client = TestClient(app) + + try: + response = client.post("/summarize", json={"text": "hello world"}) + assert response.status_code == 200 + payload = response.json() + assert payload["summary"] == "summary:hello world" + assert payload["emotion_labels"] == ["joy"] + assert payload["topic"] == "news" + assert payload["topic_confidence"] == 0.8 + finally: + app.dependency_overrides.clear() \ No newline at end of file diff --git a/tests/test_data/test_download_records.py b/tests/test_data/test_download_records.py new file mode 100644 index 0000000000000000000000000000000000000000..87aa151515952460ea1a693538e1f2376830fa74 --- /dev/null +++ b/tests/test_data/test_download_records.py @@ -0,0 +1,70 @@ +"""Unit tests for dataset record helpers in scripts.download_data.""" +from __future__ import annotations + +import importlib.util +import unittest +from pathlib import Path +from typing import Any, Dict, Iterator, List, cast + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +DOWNLOAD_SCRIPT = PROJECT_ROOT / "scripts" / "download_data.py" + +spec = importlib.util.spec_from_file_location("download_data", DOWNLOAD_SCRIPT) +if spec is None or spec.loader is None: + raise RuntimeError("Unable to load scripts/download_data.py for testing") +download_data = importlib.util.module_from_spec(spec) +spec.loader.exec_module(download_data) + + +class DummyDataset: + def __init__(self, records: List[Dict[str, object]]) -> None: + self._records = records + + def __iter__(self) -> Iterator[Dict[str, object]]: + return iter(self._records) + + +class DownloadDataRecordTests(unittest.TestCase): + def test_emotion_records_handles_out_of_range_labels(self) -> None: + dataset_split = DummyDataset([ + {"text": "sample", "label": 1}, + {"text": "multi", "label": [0, 5]}, + {"text": "string", "label": "2"}, + ]) + label_names = ["sadness", "joy", "love"] + records = list( + download_data._emotion_records( + cast(Any, dataset_split), + label_names, + ) + ) + self.assertEqual(records[0]["emotions"], ["joy"]) + # Out-of-range index falls back to string representation + 
self.assertEqual(records[1]["emotions"], ["sadness", "5"])
+        # Non-int values fall back to string
+        self.assertEqual(records[2]["emotions"], ["2"])
+
+    def test_topic_records_handles_varied_label_inputs(self) -> None:
+        dataset_split = DummyDataset([
+            {"text": "news", "label": 3},
+            {"text": "list", "label": [1]},
+            {"text": "unknown", "label": "5"},
+            {"text": "missing", "label": []},
+        ])
+        label_names = ["World", "Sports", "Business", "Sci/Tech"]
+        records = list(
+            download_data._topic_records(
+                cast(Any, dataset_split),
+                label_names,
+            )
+        )
+        self.assertEqual(records[0]["topic"], "Sci/Tech")
+        self.assertEqual(records[1]["topic"], "Sports")
+        # Out-of-range string label falls back to original string value
+        self.assertEqual(records[2]["topic"], "5")
+        # Empty list yields empty string
+        self.assertEqual(records[3]["topic"], "")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_data/test_preprocessing.py b/tests/test_data/test_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..df71f13b60c7cf428352fa0237dafe63562ca90b
--- /dev/null
+++ b/tests/test_data/test_preprocessing.py
@@ -0,0 +1,29 @@
+import unittest
+
+from src.data.preprocessing import TextPreprocessor
+from src.data.tokenization import Tokenizer, TokenizerConfig
+
+
+class _StubTokenizer(Tokenizer):
+    def __init__(self, max_length: int) -> None:
+        # Avoid expensive huggingface initialisation by skipping super().__init__
+        self.config = TokenizerConfig(max_length=max_length)
+
+    def batch_encode(self, texts, *, max_length=None):
+        raise NotImplementedError
+
+
+class TextPreprocessorTests(unittest.TestCase):
+    def test_matching_max_length_leaves_tokenizer_unchanged(self) -> None:
+        tokenizer = _StubTokenizer(max_length=128)
+        TextPreprocessor(tokenizer=tokenizer, max_length=128)
+        self.assertEqual(tokenizer.config.max_length, 128)
+
+    def test_conflicting_max_length_raises_value_error(self) -> None:
+        tokenizer = _StubTokenizer(max_length=256)
+        with self.assertRaises(ValueError):
+            TextPreprocessor(tokenizer=tokenizer, max_length=128)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_inference/test_pipeline.py b/tests/test_inference/test_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d04dd7a140133740f2866bbe932ef3e5068ba9e
--- /dev/null
+++ b/tests/test_inference/test_pipeline.py
@@ -0,0 +1,106 @@
+"""Integration tests for the inference pipeline."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import cast
+
+import torch
+
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.inference.pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction
+from src.utils.labels import LabelMetadata
+
+
+def _local_tokenizer_config() -> TokenizerConfig:
+    root = Path(__file__).resolve().parents[2]
+    hf_path = root / "artifacts" / "hf_tokenizer"
+    return TokenizerConfig(pretrained_model_name=str(hf_path))
+
+
+class DummyEncoder(torch.nn.Module):
+    def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:  # pragma: no cover - trivial
+        batch, seq_len = input_ids.shape
+        return torch.zeros(batch, seq_len, 8, device=input_ids.device)
+
+
+class DummyDecoder(torch.nn.Module):
+    def __init__(self, tokenizer: Tokenizer) -> None:
+        super().__init__()
+        tokens = tokenizer.tokenizer.encode("dummy summary", add_special_tokens=False)
+        sequence = [tokenizer.bos_token_id, *tokens, tokenizer.eos_token_id]
+        
self.register_buffer("sequence", torch.tensor(sequence, dtype=torch.long)) + + def greedy_decode( + self, + *, + memory: torch.Tensor, + max_len: int, + start_token_id: int, + end_token_id: int | None, + device: torch.device, + ) -> torch.Tensor: + seq = self.sequence.to(device) + if seq.numel() > max_len: + seq = seq[:max_len] + batch = memory.size(0) + return seq.unsqueeze(0).repeat(batch, 1) + + +class DummyModel(torch.nn.Module): + def __init__(self, tokenizer: Tokenizer, metadata: LabelMetadata) -> None: + super().__init__() + self.encoder = DummyEncoder() + self.decoder = DummyDecoder(tokenizer) + emotion_logits = torch.tensor([-2.0, 3.0, -1.0], dtype=torch.float32) + topic_logits = torch.tensor([0.25, 2.5, 0.1], dtype=torch.float32) + self.register_buffer("_emotion_logits", emotion_logits) + self.register_buffer("_topic_logits", topic_logits) + + def forward(self, task: str, inputs: dict[str, torch.Tensor]) -> torch.Tensor: # pragma: no cover - simple dispatch + batch = inputs["input_ids"].size(0) + if task == "emotion": + return self._emotion_logits.unsqueeze(0).repeat(batch, 1) + if task == "topic": + return self._topic_logits.unsqueeze(0).repeat(batch, 1) + raise KeyError(task) + + +def _build_pipeline() -> InferencePipeline: + tokenizer = Tokenizer(_local_tokenizer_config()) + metadata = LabelMetadata(emotion=["anger", "joy", "sadness"], topic=["news", "sports", "tech"]) + model = DummyModel(tokenizer, metadata) + return InferencePipeline( + model=model, + tokenizer=tokenizer, + emotion_labels=metadata.emotion, + topic_labels=metadata.topic, + config=InferenceConfig(summary_max_length=12), + ) + + +def test_pipeline_predictions_across_tasks() -> None: + pipeline = _build_pipeline() + text = "A quick unit test input." + + summaries = pipeline.summarize([text]) + assert summaries == ["dummy summary"], "Summaries should be decoded from dummy decoder sequence" + + emotions = pipeline.predict_emotions([text]) + assert len(emotions) == 1 + emotion = emotions[0] + assert isinstance(emotion, EmotionPrediction) + assert emotion.labels == ["joy"], "Only the positive logit should pass the threshold" + + topics = pipeline.predict_topics([text]) + assert len(topics) == 1 + topic = topics[0] + assert isinstance(topic, TopicPrediction) + assert topic.label == "sports" + assert topic.confidence > 0.0 + + combined = pipeline.batch_predict([text]) + assert combined["summaries"] == summaries + combined_emotions = cast(list[EmotionPrediction], combined["emotion"]) + combined_topics = cast(list[TopicPrediction], combined["topic"]) + assert combined_emotions[0].labels == emotion.labels + assert combined_topics[0].label == topic.label \ No newline at end of file diff --git a/tests/test_models/test_encoder_layer.py b/tests/test_models/test_encoder_layer.py index 551e05326c8e55a87cc3e778ffc04d4ae25db85a..5d2b37805a3cb198b850be9812096e4efc2b9315 100644 --- a/tests/test_models/test_encoder_layer.py +++ b/tests/test_models/test_encoder_layer.py @@ -3,6 +3,14 @@ import pytest from src.models.encoder import TransformerEncoderLayer +def _take_tensor(output): + """Return the tensor component regardless of (tensor, attn) tuple output.""" + + if isinstance(output, tuple): # modern layers return (output, attention) + return output[0] + return output + + def test_output_shape_and_grad(): """ The encoder layer should preserve the input shape (batch, seq_len, d_model) @@ -14,7 +22,7 @@ def test_output_shape_and_grad(): layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0) x = 
torch.randn(batch_size, seq_len, d_model, requires_grad=True) - out = layer(x) # should accept mask=None by default + out = _take_tensor(layer(x)) # should accept mask=None by default assert out.shape == (batch_size, seq_len, d_model) # simple backward to ensure gradients propagate @@ -38,14 +46,14 @@ def test_dropout_behavior_train_vs_eval(): x = torch.randn(batch_size, seq_len, d_model) layer.train() - out1 = layer(x) - out2 = layer(x) + out1 = _take_tensor(layer(x)) + out2 = _take_tensor(layer(x)) # Training mode with dropout: outputs usually differ assert not torch.allclose(out1, out2), "Outputs identical in train mode despite dropout" layer.eval() - out3 = layer(x) - out4 = layer(x) + out3 = _take_tensor(layer(x)) + out4 = _take_tensor(layer(x)) # Eval mode deterministic: outputs should be identical assert torch.allclose(out3, out4), "Outputs differ in eval mode" @@ -64,12 +72,12 @@ def test_mask_broadcasting_accepts_3d_and_4d_mask(): # 3D mask: (batch, seq, seq) mask3 = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool) mask3[:, :, -2:] = False # mask out last two key positions - out3 = layer(x, mask=mask3) # should not raise + out3 = _take_tensor(layer(x, mask=mask3)) # should not raise assert out3.shape == (batch_size, seq_len, d_model) # 4D mask: (batch, 1, seq, seq) already including head dim for broadcasting mask4 = mask3.unsqueeze(1) - out4 = layer(x, mask=mask4) + out4 = _take_tensor(layer(x, mask=mask4)) assert out4.shape == (batch_size, seq_len, d_model) diff --git a/tests/test_models/test_positional_encoding.py b/tests/test_models/test_positional_encoding.py index 129acf3030e8c59c21855451cb8a9f4359755cc5..bc4797b000d4a56cb6ffc3945fd0a10e1a76068a 100644 --- a/tests/test_models/test_positional_encoding.py +++ b/tests/test_models/test_positional_encoding.py @@ -4,8 +4,13 @@ Tests for positional encoding. """ +import os + import pytest import torch +import matplotlib + +matplotlib.use("Agg") # use non-interactive backend for test environments import matplotlib.pyplot as plt import seaborn as sns from src.models.positional_encoding import PositionalEncoding @@ -93,6 +98,7 @@ def test_visualize_positional_encoding(): plt.ylabel('Embedding Dimension') plt.title('Positional Encoding Pattern\n(Notice the wave patterns with different frequencies)') plt.tight_layout() + os.makedirs('outputs', exist_ok=True) plt.savefig('outputs/positional_encoding_heatmap.png', dpi=150) print("βœ… Saved to outputs/positional_encoding_heatmap.png")
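For reference, a minimal end-to-end sketch of the pipeline factory introduced above. The module path `src.inference.factory` and every file path below are illustrative assumptions rather than part of the changeset; point them at your own checkpoint, label metadata, and exported tokenizer.

```python
# Sketch only -- not part of the diff. The factory module path and all file
# paths are placeholders; substitute your local artifacts.
import torch

from src.inference.factory import create_inference_pipeline  # assumed location of the factory shown above

pipeline, labels = create_inference_pipeline(
    checkpoint_path="checkpoints/best.pt",      # raw state_dict written by utils.io.save_state
    labels_path="artifacts/labels.json",        # JSON written by utils.labels.save_label_metadata
    tokenizer_dir="artifacts/hf_tokenizer",     # exported Hugging Face tokenizer directory
    device="cuda" if torch.cuda.is_available() else "cpu",
    summary_max_length=96,
)

texts = ["Paste any article here."]
print(pipeline.summarize(texts))             # list[str]
print(pipeline.predict_emotions(texts))      # list[EmotionPrediction]
print(pipeline.predict_topics(texts))        # list[TopicPrediction]
print(pipeline.batch_predict(texts).keys())  # "summaries", "emotion", "topic"
print(labels.emotion_size, labels.topic_size)
```

The checkpoint is expected to hold a plain `state_dict` (as produced by `save_state`), and the label metadata a JSON mapping with `emotion` and `topic` lists of strings, which `load_label_metadata` uses to size the classification heads.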