Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import numpy as np | |
| import torch | |
app = FastAPI()

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and sequence-classification model for the
# "kmack/malicious-url-detection" checkpoint, then move the model to `device`.
tokenizer = AutoTokenizer.from_pretrained("kmack/malicious-url-detection")
model = AutoModelForSequenceClassification.from_pretrained(
    "kmack/malicious-url-detection"
).to(device)
# Schema for the payload accepted by the prediction endpoint.
class URLRequest(BaseModel):
    """Payload carrying a single URL to be classified."""

    url: str  # URL string passed to get_prediction
# Prediction function
def get_prediction(input_text: str) -> dict:
    """Score *input_text* with the URL classifier and return per-label scores.

    Each logit goes through an element-wise sigmoid, so scores are computed
    independently per label and need not sum to 1.

    Args:
        input_text: The raw URL string to classify.

    Returns:
        A dict mapping every label name in ``model.config.label2id`` to its
        sigmoid score as a Python ``float``, ordered from highest to lowest.
    """
    # Read-only view of the label -> id mapping; it is never written to, so
    # the shared model config stays intact across requests.
    label2id = model.config.label2id
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
    # Inference only: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the single batch row explicitly; squeeze() would also collapse the
    # label axis for a one-label model and break the indexing below.
    probs = torch.sigmoid(logits[0]).cpu().numpy()
    # Pick each label's score by its own id rather than by dict iteration order.
    scores = {label: float(probs[int(idx)]) for label, idx in label2id.items()}
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
# API endpoint for URL prediction.
# NOTE(review): route path chosen to match the function name; confirm it
# matches what clients call.
@app.post("/predict")
async def predict(url_request: URLRequest):
    """Classify the submitted URL and return its per-label scores.

    Args:
        url_request: Parsed request body holding the URL to check.

    Returns:
        ``{"prediction": result}``, where ``result`` is the value returned by
        ``get_prediction`` for ``url_request.url``.
    """
    # NOTE(review): get_prediction runs model inference synchronously inside
    # this coroutine, blocking the event loop for the whole request. Consider
    # offloading it to a worker thread after confirming get_prediction is
    # safe to run concurrently.
    result = get_prediction(url_request.url)
    return {"prediction": result}
# Health check endpoint
@app.get("/")
async def read_root():
    """Return a fixed message confirming the service is up (health check)."""
    return {"message": "API is up and running"}