|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from pathlib import Path |
|
|
from services.transformer import TinyTransformer |
|
|
|
|
|
|
|
|
# Storage locations, resolved relative to the package root (one level up
# from this file's directory). Built with ``pathlib`` — ``Path`` was already
# imported but unused; every consumer here (``open``, ``os.makedirs``,
# ``torch.save``/``torch.load``, f-strings) accepts Path objects.
_BASE_DIR = Path(__file__).parent.parent

_MODEL_DIR = _BASE_DIR / "model"

_VOCAB_DIR = _BASE_DIR / "data"

_MODEL_PATH = _MODEL_DIR / "rellow-2.pt"

_VOCAB_PATH = _VOCAB_DIR / "vocab.json"


# Prefer the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
# NOTE(review): no CUDA branch — presumably this project targets macOS; confirm
# before running on other hardware.
_DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
|
|
|
|
|
def save_model(model, vocab):
    """Persist the model weights and the vocabulary to disk.

    Args:
        model: torch module whose ``state_dict`` is written to ``_MODEL_PATH``.
        vocab: JSON-serializable mapping written to ``_VOCAB_PATH``.
    """
    # Ensure both target directories exist before any write happens.
    for target_dir in (_MODEL_DIR, _VOCAB_DIR):
        os.makedirs(target_dir, exist_ok=True)

    # Weights first, then the vocabulary as pretty-printed UTF-8 JSON.
    torch.save(model.state_dict(), _MODEL_PATH)
    with open(_VOCAB_PATH, "w", encoding="utf-8") as vocab_file:
        json.dump(vocab, vocab_file, ensure_ascii=False, indent=2)

    print(f"Model saved to {_MODEL_PATH}")
    print(f"Vocabulary saved to {_VOCAB_PATH}")
|
|
|
|
|
|
|
|
def load_model():
    """Load the trained model and vocabulary from disk.

    Returns:
        A tuple ``(model, vocab, inv_vocab)`` where ``model`` is a
        ``TinyTransformer`` in eval mode on ``_DEVICE``, ``vocab`` is the
        token-to-id mapping read from ``_VOCAB_PATH``, and ``inv_vocab``
        maps integer ids back to tokens.
    """
    with open(_VOCAB_PATH, "r", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    # JSON object keys are always strings, so rebuild the reverse
    # lookup with proper integer ids as keys.
    inv_vocab = {}
    for token, token_id in vocab.items():
        inv_vocab[int(token_id)] = token

    # Instantiate, restore weights onto the selected device, and switch
    # to inference mode.
    model = TinyTransformer(vocab_size=len(vocab)).to(_DEVICE)
    state = torch.load(_MODEL_PATH, map_location=_DEVICE)
    model.load_state_dict(state)
    model.eval()

    return model, vocab, inv_vocab
|
|
|
|
|
|
|
|
def get_device():
    """Return the torch device chosen at import time (MPS or CPU)."""
    return _DEVICE
|
|
|