"""
export.py

Export a trained MiniTransformer PyTorch model to TorchScript for inference.

Steps:
1. Build the MiniTransformer model architecture.
2. Load the trained weights from 'models/phish_model.pt'.
3. Convert the model to TorchScript with torch.jit.script.
4. Save the TorchScript model as 'models/phish_model_ts.pt'.
5. Print the file size of the exported model.
"""

import sys
from pathlib import Path

import torch

# Resolve the repository root (two levels above the directory containing this
# file) and put its ``src`` directory first on the import path, so the
# ``model`` package resolves regardless of the current working directory.
project_root = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(project_root / "src"))

# Must come after the sys.path tweak above, hence not grouped with the imports.
from model.model import MiniTransformer  # noqa: E402


def export() -> Path:
    """
    Export the trained MiniTransformer to TorchScript.

    Reads the trained weights from ``<project_root>/models/phish_model.pt``,
    compiles the model with ``torch.jit.script`` and writes the result to
    ``<project_root>/models/phish_model_ts.pt``. The original weights file is
    left untouched. Prints the size of the exported file in KB.

    Returns:
        Path of the exported TorchScript file.

    Raises:
        FileNotFoundError: If the trained weights file does not exist.
    """
    models_dir = project_root / "models"
    model_path = models_dir / "phish_model.pt"
    # Use a distinct file name: writing back to model_path would overwrite the
    # trained state_dict with the TorchScript archive.
    output_path = models_dir / "phish_model_ts.pt"

    # Fail fast, before building the model; is_file() also rejects directories.
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    model = MiniTransformer()
    # map_location="cpu" allows weights saved from a GPU run to load on CPU.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # Switch to inference mode so the exported module has training=False.
    model.eval()

    # torch.jit.script compiles the module's code rather than running a forward
    # pass, so no torch.no_grad() context is needed here.
    scripted = torch.jit.script(model)
    scripted.save(str(output_path))

    # Report the size of the file actually written instead of serialising the
    # model a second time with save_to_buffer().
    size_kb = output_path.stat().st_size / 1024
    print(f"Exported TorchScript model to '{output_path}' | size: {size_kb:.1f} KB")
    return output_path


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    export()