| | from pydantic import BaseModel |
| | from io import BytesIO |
| | import requests |
| | from model import TransformerSeq2Seq,translate |
| | from utils import load_tokenizers_and_embeddings |
| |
|
| | import torch |
| |
|
| | |
| |
|
| | |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| | |
| | resources = load_tokenizers_and_embeddings() |
| | tokenizer_vi = resources["tokenizer_vi"] |
| | embedding_matrix_vi = resources["embedding_vi"] |
| | tokenizer_en = resources["tokenizer_en"] |
| | embedding_matrix_en = resources["embedding_en"] |
| | device = resources["device"] |
| |
|
| | print("✅ Tokenizers & embeddings loaded!") |
| | if isinstance(embedding_matrix_en, torch.Tensor): |
| | embed_dim = embedding_matrix_en.size(1) |
| | else: |
| | embed_dim = embedding_matrix_en.embedding_dim |
| | max_len = 128 |
| | batch_size = 32 |
| | |
| | model = TransformerSeq2Seq( |
| | embed_dim=embed_dim, |
| | vocab_size=tokenizer_vi.vocab_size, |
| | embedding_decoder=embedding_matrix_vi, |
| | num_heads=4, |
| | num_layers=2, |
| | dim_feedforward=256, |
| | dropout=0.1, |
| | freeze_decoder_emb=True, |
| | max_len=max_len |
| | ) |
| | MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt" |
| |
|
| | |
| | checkpoint_bytes = BytesIO(requests.get(MODEL_URL).content) |
| | checkpoint = torch.load(checkpoint_bytes, map_location=device) |
| |
|
| | |
| | model.load_state_dict(checkpoint["model_state_dict"]) |
| | model.to(device) |
| | model.eval() |
| |
|
| | print("✅ Model loaded from Hugging Face Hub") |
| | print("Model loaded") |
| |
|
| | def hf_inference_fn(inputs: str): |
| | return translate( |
| | model=model, |
| | src_sentence=inputs, |
| | tokenizer_src=tokenizer_en, |
| | tokenizer_tgt=tokenizer_vi, |
| | embedding_src=embedding_matrix_en, |
| | device=device, |
| | max_len=max_len |
| | ) |
| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| |
|
| | app = FastAPI() |
| |
|
| | class TranslateRequest(BaseModel): |
| | text: str |
| |
|
| | @app.post("/translate") |
| | def translate(req: TranslateRequest): |
| | return {"translation": hf_inference_fn(req.text)} |
| |
|