Spaces:
Running
Running
| import os | |
| import torch | |
| import torch.nn as nn | |
| import sentencepiece as spm | |
| from fastapi import FastAPI, Query | |
| from pydantic import BaseModel | |
| from fastapi.responses import JSONResponse | |
# =========================================================
# CONFIG
# =========================================================
# Directory holding the trained artifacts (model weights + tokenizer model).
ARTIFACTS_DIR = "artifacts"
# PyTorch state_dict checkpoint for ResponderModel.
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "responder.pt")
# SentencePiece tokenizer model file.
SPM_MODEL_PATH = os.path.join(ARTIFACTS_DIR, "spm.model")
# Fixed sequence length: inputs are padded/truncated to exactly this many tokens.
MAX_LEN = 64
# Inference runs on CPU only.
DEVICE = "cpu"
app = FastAPI(title="Responder API", version="FINAL-GRU")
| # ========================================================= | |
| # TOKENIZER | |
| # ========================================================= | |
class HybridTokenizer:
    """Thin wrapper around a SentencePiece model.

    Provides text<->id conversion plus fixed-length padding. When the
    SentencePiece model defines no pad piece (pad_id() == -1), token id 0
    is used as the pad id instead.
    """

    def __init__(self, path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(path)
        self.vocab_size = self.sp.get_piece_size()
        # Fall back to 0 when the model has no dedicated pad token.
        raw_pad = self.sp.pad_id()
        self.pad_id = raw_pad if raw_pad >= 0 else 0

    def encode(self, text):
        """Return the list of integer token ids for `text`."""
        return self.sp.encode(text, out_type=int)

    def decode(self, ids):
        """Return the text decoded from a list of integer token ids."""
        return self.sp.decode(ids)

    def pad(self, ids, max_len):
        """Right-pad `ids` with pad_id (or truncate) to exactly `max_len`."""
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_id] * (max_len - len(ids))
| # ========================================================= | |
| # MODEL (CORRECT: GRU) | |
| # ========================================================= | |
class ResponderModel(nn.Module):
    """GRU encoder/decoder with self-attention over the encoder states.

    forward() only uses embedding -> encoder GRU -> self-attention ->
    decoder GRU -> brain_head. The domain_* and projection layers are not
    used by forward() but are kept so existing checkpoints (loaded with
    strict=False) still map onto this module.
    """

    def __init__(self, vocab_size, hidden_dim=512, num_heads=16, num_domains=6):
        """Build the network.

        Args:
            vocab_size: size of the token embedding table.
            hidden_dim: embedding/GRU/attention width (default 512, matching
                the shipped checkpoint).
            num_heads: attention heads; must divide hidden_dim (default 16).
            num_domains: size of the auxiliary domain embedding/head (default 6).
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # GRU (not LSTM) — matches the trained checkpoint.
        self.encoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        # Self-attention over the encoder outputs.
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True)
        # Domain system (unused by forward(); kept for checkpoint compatibility).
        self.domain_embed = nn.Embedding(num_domains, hidden_dim)
        self.domain_head = nn.Linear(hidden_dim, num_domains)
        # Brain system: final per-token projection.
        # NOTE(review): output width is hidden_dim, NOT vocab_size — a caller
        # that argmaxes over this dim gets ids in [0, hidden_dim). Confirm
        # against the training code before changing.
        self.brain_head = nn.Linear(hidden_dim, hidden_dim)
        # Extra projections (unused by forward(); checkpoint compatibility).
        self.brain_to_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.combined_to_embed = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Map (batch, seq) int token ids to (batch, seq, hidden_dim) scores."""
        emb = self.embedding(x)
        enc_out, _ = self.encoder(emb)
        attn_out, _ = self.attn(enc_out, enc_out, enc_out)
        dec_out, _ = self.decoder(attn_out)
        return self.brain_head(dec_out)
| # ========================================================= | |
| # LOAD MODEL | |
| # ========================================================= | |
def load_model(path, vocab_size):
    """Build a ResponderModel, load its weights from `path`, set eval mode.

    strict=False tolerates checkpoints with extra or missing keys.
    NOTE(review): torch.load unpickles arbitrary objects — only point this
    at trusted checkpoint files.
    """
    weights = torch.load(path, map_location=DEVICE)
    net = ResponderModel(vocab_size).to(DEVICE)
    net.load_state_dict(weights, strict=False)
    net.eval()
    return net
| # ========================================================= | |
| # INIT | |
| # ========================================================= | |
# Instantiated once at import time so all request handlers share a single
# tokenizer/model pair; fails fast if the artifact files are missing.
tokenizer = HybridTokenizer(SPM_MODEL_PATH)
model = load_model(MODEL_PATH, tokenizer.vocab_size)
| # ========================================================= | |
| # API | |
| # ========================================================= | |
class PredictRequest(BaseModel):
    """JSON request body for the predict endpoint."""

    # Raw user text to tokenize and feed to the model.
    text: str
@app.get("/health")
def health():
    """Liveness probe.

    Returns:
        A constant {"status": "ok"} payload.

    Fix: the handler was defined but never registered with the FastAPI app,
    so no /health route existed; @app.get("/health") registers it.
    """
    return {"status": "ok"}
@app.post("/predict")
def predict(req: PredictRequest):
    """Tokenize req.text, run the model, and greedy-decode a response.

    Args:
        req: request body carrying the input text.

    Returns:
        {"response": <decoded text>} from the per-position argmax ids.

    Fixes: the handler was never registered with the FastAPI app (no
    /predict route existed); the input tensor is now created on DEVICE so
    the code keeps working if DEVICE ever changes from "cpu".
    NOTE(review): model output is hidden-dim wide, not vocab-sized, so the
    argmax ids are only valid token ids if hidden_dim <= vocab_size —
    confirm against the training setup.
    """
    ids = tokenizer.pad(tokenizer.encode(req.text), MAX_LEN)
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        out = model(x)
    pred = torch.argmax(out, dim=-1)[0].tolist()
    return {"response": tokenizer.decode(pred)}