mvi-ai-engine / app.py
import os

import torch
import torch.nn as nn
import sentencepiece as spm
from fastapi import FastAPI
from pydantic import BaseModel
# =========================================================
# CONFIG
# =========================================================
ARTIFACTS_DIR = "artifacts"
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "responder.pt")
SPM_MODEL_PATH = os.path.join(ARTIFACTS_DIR, "spm.model")
MAX_LEN = 64
DEVICE = "cpu"
app = FastAPI(title="Responder API", version="FINAL-GRU")
# =========================================================
# TOKENIZER
# =========================================================
class HybridTokenizer:
    """Thin wrapper around a SentencePiece model with fixed-length padding."""

    def __init__(self, path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(path)
        self.vocab_size = self.sp.get_piece_size()
        # Fall back to id 0 when the SentencePiece model defines no <pad> token.
        self.pad_id = self.sp.pad_id() if self.sp.pad_id() >= 0 else 0

    def encode(self, text):
        return self.sp.encode(text, out_type=int)

    def decode(self, ids):
        return self.sp.decode(ids)

    def pad(self, ids, max_len):
        # Right-pad short sequences; truncate long ones to exactly max_len.
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]
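# Illustrative round trip (a sketch: `_demo_tokenizer_round_trip` is a
# hypothetical helper, not part of the original app, and the exact ids
# depend on the trained spm.model shipped in artifacts/):
def _demo_tokenizer_round_trip():
    tok = HybridTokenizer(SPM_MODEL_PATH)
    ids = tok.pad(tok.encode("hello world"), MAX_LEN)  # exactly 64 ids
    # Filter out pad ids before decoding so padding doesn't leak into the text.
    return tok.decode([i for i in ids if i != tok.pad_id])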
# =========================================================
# MODEL (CORRECT: GRU)
# =========================================================
class ResponderModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 512)
        # GRU encoder/decoder instead of LSTM: single layer, hidden size 512.
        self.encoder = nn.GRU(512, 512, batch_first=True)
        self.decoder = nn.GRU(512, 512, batch_first=True)
        # Self-attention over encoder states (16 heads of width 32).
        self.attn = nn.MultiheadAttention(512, num_heads=16, batch_first=True)
        # Domain system (defined here, but not used in forward).
        self.domain_embed = nn.Embedding(6, 512)
        self.domain_head = nn.Linear(512, 6)
        # Brain system: per-token output head (512-dim).
        self.brain_head = nn.Linear(512, 512)
        # Projections (also unused in forward).
        self.brain_to_hidden = nn.Linear(512, 512)
        self.combined_to_embed = nn.Linear(512, 512)

    def forward(self, x):
        emb = self.embedding(x)                             # (B, T, 512)
        enc_out, _ = self.encoder(emb)                      # (B, T, 512)
        attn_out, _ = self.attn(enc_out, enc_out, enc_out)  # (B, T, 512)
        dec_out, _ = self.decoder(attn_out)                 # (B, T, 512)
        # The 512-dim brain_head output is treated as per-token scores
        # by the /predict endpoint below.
        out = self.brain_head(dec_out)                      # (B, T, 512)
        return out
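# A minimal shape check (hypothetical helper, not in the original file):
# a random batch confirms the model maps (B, T) token ids to (B, T, 512)
# scores, which /predict later argmaxes per position.
def _smoke_test_model_shapes():
    m = ResponderModel(vocab_size=1000).eval()  # toy vocab size for the sketch
    x = torch.randint(0, 1000, (2, MAX_LEN))
    with torch.no_grad():
        y = m(x)
    assert y.shape == (2, MAX_LEN, 512)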
# =========================================================
# LOAD MODEL
# =========================================================
def load_model(path, vocab_size):
    model = ResponderModel(vocab_size).to(DEVICE)
    state_dict = torch.load(path, map_location=DEVICE)
    # strict=False tolerates keys that the checkpoint and the module do not
    # share; mismatched layers are silently skipped.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
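# Optional sanity check (hypothetical helper, not in the original app):
# lists the keys on which the checkpoint and the module disagree, which
# strict=False loading would otherwise hide.
def _report_state_dict_mismatches(path, vocab_size):
    model = ResponderModel(vocab_size)
    state_dict = torch.load(path, map_location=DEVICE)
    missing = [k for k in model.state_dict() if k not in state_dict]
    unexpected = [k for k in state_dict if k not in model.state_dict()]
    return {"missing": missing, "unexpected": unexpected}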
# =========================================================
# INIT
# =========================================================
tokenizer = HybridTokenizer(SPM_MODEL_PATH)
model = load_model(MODEL_PATH, tokenizer.vocab_size)
# =========================================================
# API
# =========================================================
class PredictRequest(BaseModel):
    text: str


@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/predict")
def predict(req: PredictRequest):
    # Encode, then pad/truncate to exactly MAX_LEN ids.
    ids = tokenizer.encode(req.text)
    ids = tokenizer.pad(ids, MAX_LEN)
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        out = model(x)
    # Greedy decode: take the argmax id at every position, then detokenize.
    pred = torch.argmax(out, dim=-1)[0].tolist()
    return {"response": tokenizer.decode(pred)}