from pathlib import Path
import os
import shutil
import zipfile

import torch
from fastapi import FastAPI, Request, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}
VISUALIZER_CACHE = {}
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}
class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str


class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None


class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None
@app.post("/upload_model")  # NOTE: route path assumed from the function name
async def upload_model(file: UploadFile = File(...)):
    save_path = f"/data/models/{file.filename}"  # or wherever your disk is mounted
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        f.write(await file.read())
    return {"status": "uploaded", "path": save_path}
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    # Drop any cached visualizer for this model so it is rebuilt fresh.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print("tokenizing")
    try:
        # MNLI takes a premise/hypothesis pair; the other tasks take a single sentence.
        if req.task.lower() == "mnli":
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print("response ready")
    response = {
        "model": req.model,
        "tokens": token_output["tokens"],
        "num_layers": vis.num_attention_layers,
    }
    print("load model successful")
    print(response)
    return response
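# Example request body for /load_model (a sketch; host, port, and field values
# are illustrative):
#   curl -X POST http://localhost:8000/load_model \
#        -H "Content-Type: application/json" \
#        -d '{"model": "BERT", "sentence": "The movie was great.", "task": "sst", "hypothesis": ""}'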
@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print("predict: instantiating")
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print("predict: Run prediction")
    try:
        if req.task.lower() == "mnli":
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == "mlm":
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        # Return early: falling through would call .tolist() on an exception object.
        print("Prediction failed:", e)
        return {"error": f"Prediction failed: {str(e)}"}

    print("predict: response ready")
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict model successful")
    # Truncate long outputs so the log stays readable.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
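# Example MLM request (a sketch; maskID is assumed to be the index of the masked
# token as interpreted by the visualizer, and host/port are illustrative):
#   curl -X POST http://localhost:8000/predict_model \
#        -H "Content-Type: application/json" \
#        -d '{"model": "BERT", "sentence": "The capital of France is [MASK].",
#             "task": "mlm", "hypothesis": "", "maskID": 6}'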
@app.post("/get_grad_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print("\n--- /get_grad_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("run function")
        try:
            if req.task.lower() == "mnli":
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == "mlm":
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
        except Exception as e:
            # Return early: an exception object is not JSON-serializable as a matrix.
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attn computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print("grad attn successful")
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
##################################################


@app.get("/ping")  # NOTE: route path assumed from the function name
def ping():
    return {"message": "pong"}
@app.post("/upload_to_path")  # NOTE: route path assumed from the function name
async def upload_to_path(
    file: UploadFile = File(...),
    dest_path: str = Form(...),  # e.g., "models/model.pt"
):
    full_path = Path("/data") / dest_path
    full_path.parent.mkdir(parents=True, exist_ok=True)
    with open(full_path, "wb") as f:
        f.write(await file.read())
    return {"status": "uploaded", "path": str(full_path)}
@app.post("/make_directory")  # NOTE: route path assumed from the function name
def make_directory(
    dir_path: str = Form(...),  # e.g., "logs/test_run"
):
    full_dir = Path("/data") / dir_path
    full_dir.mkdir(parents=True, exist_ok=True)
    return {"status": "created", "directory": str(full_dir)}
@app.get("/list_data")  # NOTE: route path assumed from the function name
def list_data():
    base_path = Path("/data")
    all_items = []
    for path in base_path.rglob("*"):  # recursive glob
        all_items.append({
            "path": str(path.relative_to(base_path)),
            "type": "dir" if path.is_dir() else "file",
            "size": path.stat().st_size if path.is_file() else None,
        })
    return {"items": all_items}
@app.post("/purge_data")  # NOTE: route path assumed from the function name
def purge_data():
    base_path = Path("/data")
    if not base_path.exists():
        return {"status": "error", "message": "/data does not exist"}
    deleted = []
    for child in base_path.iterdir():
        try:
            if child.is_file() or child.is_symlink():
                child.unlink()
            elif child.is_dir():
                shutil.rmtree(child)
            deleted.append(str(child.name))
        except Exception as e:
            deleted.append(f"FAILED: {child.name} ({e})")
    return {
        "status": "done",
        "deleted": deleted,
        "total": len(deleted),
    }
""" | |
if __name__ == "__main__": | |
print('rim ') | |
BERTVisualizer('mlm') | |
BERTVisualizer('mnli') | |
BERTVisualizer('sst') | |
RoBERTaVisualizer('mlm') | |
RoBERTaVisualizer('mnli') | |
RoBERTaVisualizer('sst') | |
DistilBERTVisualizer('mlm') | |
DistilBERTVisualizer('mnli') | |
DistilBERTVisualizer('sst') | |
""" |