# V2APIToCheck / app.py
# Last change: "Update app.py" by ZealPyae (commit 609af13)
# from fastapi import FastAPI
# app = FastAPI()
# @app.get("/")
# def greet_json():
# return {"Hello": "World!"}
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
app = FastAPI()

# Target device for inference: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class URLRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Attributes:
        url: The URL string to be classified by the model.
    """

    url: str
# Text-classification pipeline for the malicious-URL model. The pipeline's
# `device` argument is given as an integer: the CUDA device index when `device`
# (chosen above) is a GPU, or -1 to run on the CPU.
_pipeline_device = device.index if device.type == "cuda" else -1
pipe = pipeline(
    "text-classification",
    model="kmack/malicious-url-detection",
    device=_pipeline_device,
)
def get_prediction(url_to_check: str):
    """Run the malicious-URL classifier on a single URL string.

    Args:
        url_to_check: The URL text to classify.

    Returns:
        The raw output of the module-level text-classification ``pipe`` for
        this input (per the transformers docs, a list of ``{"label", "score"}``
        dicts for a single string; label names come from the model config).
    """
    # Truncate to the model's maximum input length. Without this, a URL that
    # tokenizes to more tokens than the model accepts makes the forward pass
    # fail, which reaches the client as a 500 instead of a prediction.
    return pipe(url_to_check, truncation=True)
# API endpoint for URL prediction
@app.post("/predict")
async def predict(url_request: URLRequest):
    """Classify the submitted URL with the malicious-URL model.

    Args:
        url_request: Parsed request body; ``url_request.url`` is classified.

    Returns:
        ``{"prediction": result}`` where ``result`` is the value returned by
        ``get_prediction`` for the submitted URL.

    Raises:
        HTTPException: 422 if the URL is empty or whitespace-only.
    """
    url_to_check = url_request.url
    if not url_to_check.strip():
        # Nothing meaningful to classify; reject it the same way FastAPI
        # rejects other invalid request bodies.
        raise HTTPException(status_code=422, detail="url must not be empty")
    # Model inference is blocking CPU/GPU work. Calling it directly inside this
    # coroutine would freeze the event loop, and with it every other request
    # (including the health check), until it finished. Run it in FastAPI's
    # worker threadpool instead.
    result = await run_in_threadpool(get_prediction, url_to_check)
    return {"prediction": result}
# Health check endpoint
@app.get("/")
async def read_root():
    """Liveness probe: confirm the service is accepting requests."""
    status_payload = {"message": "API is up and running"}
    return status_payload