# BERTGradGraph / server.py
# Commit 04ccab0 ("Upload 5 files") by yifan0sun — verified.
from fastapi import FastAPI, Request
from pydantic import BaseModel
from pathlib import Path
import torch
from fastapi import UploadFile, File
import os
from fastapi.middleware.cors import CORSMiddleware
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil
from fastapi import Form
from fastapi import UploadFile, File, Form
from pathlib import Path
# Public model name (as sent by clients in `req.model`) -> visualizer class.
VISUALIZER_CLASSES = dict(
    BERT=BERTVisualizer,
    RoBERTa=RoBERTaVisualizer,
    DistilBERT=DistilBERTVisualizer,
)
# Most recently instantiated visualizer per model name; the endpoints store
# fresh instances here and /load_model evicts the old entry before rebuilding.
VISUALIZER_CACHE = dict()
app = FastAPI()

# NOTE(review): every origin is accepted while credentials are also allowed;
# this is very permissive — confirm it is intended before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Public model name -> checkpoint identifier (presumably a pretrained model
# name for the transformers hub — confirm against the visualizer modules).
# NOTE(review): not referenced in this file's endpoints; verify it is still used.
MODEL_MAP = dict(
    BERT="bert-base-uncased",
    RoBERTa="roberta-base",
    DistilBERT="distilbert-base-uncased",
)
# JSON body accepted by POST /load_model.
class LoadModelRequest(BaseModel):
    model: str  # key into VISUALIZER_CLASSES, e.g. "BERT"
    sentence: str  # input text to tokenize
    task: str  # compared case-insensitively; "mnli" also uses `hypothesis`
    hypothesis: str  # forwarded to the tokenizer only for the "mnli" task
# JSON body accepted by POST /get_grad_attn_matrix.
class GradAttnModelRequest(BaseModel):
    model: str  # key into VISUALIZER_CLASSES
    task: str  # lower-cased by the endpoint; "mnli"/"mlm" enable extra args
    sentence: str  # input text
    hypothesis: str  # forwarded only when task is "mnli"
    maskID: int | None = None  # forwarded only when task is "mlm"
# JSON body accepted by POST /predict_model.
class PredModelRequest(BaseModel):
    model: str  # key into VISUALIZER_CLASSES
    sentence: str  # input text
    task: str  # lower-cased by the endpoint; "mnli"/"mlm" enable extra args
    hypothesis: str  # forwarded only when task is "mnli"
    maskID: int | None = None  # forwarded only when task is "mlm"
@app.post("/upload_model")
async def upload_model(file: UploadFile = File(...)):
    """Save an uploaded file under /data/models/ and report where it went.

    Only the final path component of the client-supplied filename is used,
    so names such as ``../../etc/passwd`` or ``/etc/passwd`` cannot place the
    file outside the models directory.

    Args:
        file: Multipart upload; ``file.filename`` is client-controlled and may
            be missing.

    Returns:
        ``{"status": "uploaded", "path": <saved path>}`` on success, or
        ``{"status": "error", "message": ...}`` when no usable filename remains.
    """
    # file.filename is untrusted: keep only its basename to block directory
    # traversal and absolute paths. It may also be None.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".", ".."):
        return {"status": "error", "message": "Invalid filename"}
    save_path = f"/data/models/{safe_name}"  # or wherever your disk is mounted
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        f.write(await file.read())
    return {"status": "uploaded", "path": save_path}
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    """Build a fresh visualizer for ``req.model`` and tokenize the input.

    Any visualizer already cached under the same model name is evicted first,
    and ``torch.cuda.empty_cache()`` is called. A new instance is then
    created for the lower-cased task and cached.

    Args:
        req: Model name, task, sentence and hypothesis. The hypothesis is
            passed to the tokenizer only when the task is ``mnli``.

    Returns:
        ``{"model", "tokens", "num_layers"}`` on success; otherwise a dict with
        a single ``"error"`` message (unknown model, construction failure, or
        tokenization failure).
    """
    for line in (
        "\n--- /load_model request received ---",
        f"Model: {req.model}",
        f"Sentence: {req.sentence}",
        f"Task: {req.task}",
        f"hypothesis: {req.hypothesis}",
    ):
        print(line)
    task = req.task.lower()

    # Evict the previous instance and ask torch to release cached CUDA memory.
    if req.model in VISUALIZER_CACHE:
        VISUALIZER_CACHE.pop(req.model)
        torch.cuda.empty_cache()

    visualizer_cls = VISUALIZER_CLASSES.get(req.model)
    if visualizer_cls is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = visualizer_cls(task=task)
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as exc:
        print("Visualizer init failed:", exc)
        return {"error": f"Instantiation failed: {exc!s}"}

    print('tokenizing')
    # Only the MNLI task takes a second (hypothesis) segment.
    extra = {"hypothesis": req.hypothesis} if task == 'mnli' else {}
    try:
        token_output = vis.tokenize(req.sentence, **extra)
        print("0 Tokenization successful:", token_output["tokens"])
    except Exception as exc:
        print("Tokenization failed:", exc)
        return {"error": f"Tokenization failed: {exc!s}"}

    print('response ready')
    response = dict(
        model=req.model,
        tokens=token_output['tokens'],
        num_layers=vis.num_attention_layers,
    )
    print("load model successful")
    print(response)
    return response
@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    """Run a prediction with the requested model and return the top results.

    A new visualizer is instantiated on every call and stored in
    VISUALIZER_CACHE under ``req.model``. The cached instance is not reused.

    Args:
        req: Model name, task and sentence, plus ``hypothesis`` (forwarded only
            for the ``mnli`` task) and ``maskID`` (forwarded only for ``mlm``).

    Returns:
        ``{"decoded": ..., "top_probs": [...]}`` on success. On failure, a dict
        with a single ``"error"`` key: unknown model, model construction
        failure, or an exception raised while predicting or converting the
        probabilities.
    """
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")
    task = req.task.lower()

    print('predict: instantiating')
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # NOTE(review): the model is rebuilt on every request instead of being
        # read back from VISUALIZER_CACHE. An earlier (commented-out) workaround
        # here dealt with parameters left on the "meta" device, which may be
        # the reason. Confirm before switching to the cached instance.
        vis = vis_class(task=task)
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print('predict: Run prediction')
    try:
        if task == 'mnli':
            decoded, top_probs = vis.predict(task, req.sentence, hypothesis=req.hypothesis)
        elif task == 'mlm':
            decoded, top_probs = vis.predict(task, req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(task, req.sentence)
        # Convert inside the try so a failure here is also reported as an error.
        top_probs = top_probs.tolist()
    except Exception as e:
        # There is no usable result to serialize. Return the same error shape
        # as the other failure paths instead of letting a later `.tolist()` on
        # the exception object raise and surface as an HTTP 500.
        print(e)
        return {"error": f"Prediction failed: {str(e)}"}

    print('predict: response ready')
    response = {
        "decoded": decoded,
        "top_probs": top_probs,
    }
    print("predict: predict model successful")
    # Keep the log line short when the decoded output is long.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
@app.post("/get_grad_attn_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    """Compute gradient and attention matrices for the requested model/input.

    A new visualizer is instantiated for every call and stored in
    VISUALIZER_CACHE under ``req.model``.

    Args:
        req: Model name, task and sentence, plus ``hypothesis`` (forwarded only
            for ``mnli``) and ``maskID`` (forwarded only for ``mlm``).

    Returns:
        ``{"grad_matrix": ..., "attn_matrix": ...}`` on success; otherwise a
        dict with a single ``"error"`` key.
        NOTE(review): the matrices are returned exactly as
        ``get_all_grad_attn_matrix`` produces them, so they must already be
        JSON-serializable. Confirm against the visualizer classes.
    """
    try:
        print("\n--- /get_grad_attn_matrix request received ---")
        print(f"grad:Model: {req.model}")
        print(f"grad:Task: {req.task}")
        print(f"grad:sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")
        task = req.task.lower()
        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # NOTE(review): always rebuilt rather than reused from
            # VISUALIZER_CACHE. Confirm whether the cached instance is safe
            # to reuse.
            vis = vis_class(task=task)
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}
        print("run function")
        try:
            if task == 'mnli':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    task, req.sentence, hypothesis=req.hypothesis)
            elif task == 'mlm':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    task, req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    task, req.sentence)
        except Exception as e:
            print("Exception during grad/attn computation:", e)
            # Exception objects are not a meaningful response payload. Report
            # the failure in the same {"error": ...} shape as the other paths.
            return {"error": f"Grad/attn computation failed: {str(e)}"}
        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print('grad attn successful')
        return response
    except Exception as e:
        # Last-resort boundary: never let an unexpected error escape the endpoint.
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
##################################################
@app.get("/ping")
def ping():
    """Liveness probe: always answers with the fixed payload {"message": "pong"}."""
    reply = dict(message="pong")
    return reply
@app.post("/upload_to_path")
async def upload_to_path(
    file: UploadFile = File(...),
    dest_path: str = Form(...),  # e.g., "models/model.pt"
):
    """Write an uploaded file to ``/data/<dest_path>``, creating parent dirs.

    ``dest_path`` comes straight from the client, and joining it onto /data is
    not enough to contain it. ``Path("/data") / "/etc/passwd"`` is
    ``/etc/passwd``, and ``..`` segments or symlinks can walk out of /data.
    The resolved target is therefore required to lie strictly inside /data.

    Args:
        file: The uploaded file; its bytes are read fully into memory.
        dest_path: Destination path relative to /data.

    Returns:
        ``{"status": "uploaded", "path": ...}`` on success, or
        ``{"status": "error", "message": ...}`` when ``dest_path`` is rejected.
    """
    base = Path("/data")
    full_path = base / dest_path
    resolved = full_path.resolve()
    base_resolved = base.resolve()
    # Reject targets outside /data, and /data itself (it cannot be a file).
    if resolved == base_resolved or not resolved.is_relative_to(base_resolved):
        return {"status": "error", "message": f"Invalid destination path: {dest_path}"}
    full_path.parent.mkdir(parents=True, exist_ok=True)
    with open(full_path, "wb") as f:
        f.write(await file.read())
    return {"status": "uploaded", "path": str(full_path)}
@app.post("/make_dir")
def make_directory(
    dir_path: str = Form(...),  # e.g., "logs/test_run"
):
    """Create ``/data/<dir_path>`` and any missing parents.

    ``dir_path`` is client-controlled, so the resolved target must stay within
    /data. Absolute paths, and ``..`` segments that would land elsewhere, are
    rejected rather than creating directories outside the data volume.

    Returns:
        ``{"status": "created", "directory": ...}`` on success (also when the
        directory already existed), or ``{"status": "error", "message": ...}``
        when ``dir_path`` points outside /data.
    """
    base = Path("/data")
    full_dir = base / dir_path
    if not full_dir.resolve().is_relative_to(base.resolve()):
        return {"status": "error", "message": f"Invalid directory path: {dir_path}"}
    full_dir.mkdir(parents=True, exist_ok=True)
    return {"status": "created", "directory": str(full_dir)}
@app.get("/list_data")
def list_data():
    """Recursively list every entry under /data.

    Returns:
        ``{"items": [...]}``. Each item holds the entry's path relative to
        /data, its ``"type"`` and its ``"size"``. Anything that is not a
        directory (a broken symlink, for example) is labelled ``"file"``.
        ``"size"`` is the byte size for regular files and None otherwise.
    """
    root = Path("/data")

    def describe(entry):
        # One listing record per filesystem entry.
        return {
            "path": str(entry.relative_to(root)),
            "type": "dir" if entry.is_dir() else "file",
            "size": entry.stat().st_size if entry.is_file() else None,
        }

    return {"items": [describe(entry) for entry in root.rglob("*")]}
@app.post("/purge_data_123456789")
def purge_data():
    """Delete every entry directly under /data (files, symlinks, directory trees).

    NOTE(review): this destructive endpoint has no authentication. The
    numeric suffix in its route is its only protection.

    Returns:
        An error dict if /data does not exist. Otherwise
        ``{"status": "done", "deleted": [...], "total": n}``, where
        ``deleted`` holds one record per child: either its name or
        ``"FAILED: <name> (<error>)"``. ``total`` counts both kinds.
    """
    root = Path("/data")
    if not root.exists():
        return {"status": "error", "message": "/data does not exist"}

    def remove(entry):
        # Symlinks are unlinked, never followed into their target tree.
        if entry.is_file() or entry.is_symlink():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)

    outcomes = []
    for entry in root.iterdir():
        try:
            remove(entry)
        except Exception as err:
            outcomes.append(f"FAILED: {entry.name} ({err})")
        else:
            outcomes.append(str(entry.name))

    return dict(status="done", deleted=outcomes, total=len(outcomes))
# NOTE(review): the block below is a disabled smoke test kept as a bare string
# literal, so it never executes. If enabled, it would construct each
# visualizer class for the 'mlm', 'mnli' and 'sst' tasks. Either turn it into
# a real `if __name__ == "__main__":` guard or delete it.
"""
if __name__ == "__main__":
print('rim ')
BERTVisualizer('mlm')
BERTVisualizer('mnli')
BERTVisualizer('sst')
RoBERTaVisualizer('mlm')
RoBERTaVisualizer('mnli')
RoBERTaVisualizer('sst')
DistilBERTVisualizer('mlm')
DistilBERTVisualizer('mnli')
DistilBERTVisualizer('sst')
"""