|
from fastapi import FastAPI, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel |
|
from dotenv import load_dotenv |
|
import os |
|
from prediction import genconvit_video_prediction |
|
from utils.db import supabase_client |
|
import json |
|
import requests |
|
from utils.utils import upload_file |
|
import redis |
|
from rq import Queue, Worker, Connection |
|
import uvicorn |
|
import torch |
|
# Point TORCH_HOME, HUGGINGFACE_HUB_CACHE and the Torch Hub directory at ./cache.
# NOTE(review): these are assigned after the imports above - confirm that no
# already-imported module has read the variables before this point.
os.environ["TORCH_HOME"] = os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
torch.hub.set_dir("./cache")

# Load the .env file (if present) before the environment lookups that follow.
load_dotenv()
|
|
|
|
|
# R2 object-storage settings; each one is None when its variable is unset.
R2_ACCESS_KEY = os.environ.get("R2_ACCESS_KEY")
R2_SECRET_KEY = os.environ.get("R2_SECRET_KEY")
R2_BUCKET_NAME = os.environ.get("R2_BUCKET_NAME")
R2_ENDPOINT_URL = os.environ.get("R2_ENDPOINT_URL")

# Upstash Redis settings, consumed by the Redis connection created below.
UPSTASH_REDIS_REST_URL = os.environ.get("UPSTASH_REDIS_REST_URL")
UPSTASH_REDIS_REST_TOKEN = os.environ.get("UPSTASH_REDIS_REST_TOKEN")
|
|
|
|
|
# Redis client (port 6379, ssl=True) authenticated with the Upstash token.
# NOTE(review): the host is read from UPSTASH_REDIS_REST_URL - confirm that this
# variable holds a bare hostname and not an https:// REST URL.
r = redis.Redis(
    host=UPSTASH_REDIS_REST_URL,
    password=UPSTASH_REDIS_REST_TOKEN,
    port=6379,
    ssl=True,
)

# RQ queue named 'video-predictions', bound to the connection above.
q = Queue('video-predictions', connection=r)
|
|
|
|
|
# Application object; the route decorators further down register handlers on it.
app = FastAPI()

# CORS: any origin, method and header is accepted, with credentials allowed.
# NOTE(review): FastAPI's CORS documentation says wildcard origins cannot be
# combined with allow_credentials=True - either list the real origins or drop
# credentials once the set of front-ends is known.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class PredictionRequest(BaseModel):
    # Request body accepted by POST /predict.

    # Video location handed to genconvit_video_prediction.
    video_url: str
    # Value matched against the Result table's query_id column when the score
    # is stored.
    query_id: str
|
|
|
class DetectFacesRequest(BaseModel):
    # Request body accepted by POST /detect-faces.

    # Video location passed on by the /detect-faces handler.
    video_url: str
|
|
|
|
|
def predictionQueueResolver(prediction_data):
    """Run a prediction described by a JSON payload and store its score.

    Args:
        prediction_data: JSON text of an object carrying the keys
            ``mediaUrl`` (video to analyse) and ``queryId`` (value matched
            against ``Result.query_id`` when the score is written).

    Returns:
        The value returned by ``genconvit_video_prediction``; its ``score``
        entry is what gets written to the ``Result`` table.

    Raises:
        HTTPException: 400 when ``mediaUrl`` is missing or empty; 500 when the
            prediction or the database update fails (the original exception
            is attached as ``__cause__``).
        KeyError: when ``queryId`` is missing from the payload.
        ValueError: when ``prediction_data`` is not valid JSON.
    """
    data = json.loads(prediction_data)

    # .get() lets an absent key reach the 400 below instead of escaping as a
    # bare KeyError; the URL is validated before anything else is read.
    video_url = data.get('mediaUrl')
    if not video_url:
        raise HTTPException(status_code=400, detail="No video URL provided")
    query_id = data['queryId']

    try:
        result = genconvit_video_prediction(video_url)
        # Every per-detector entry is stored as the fixed string "0".
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        print(result)
        supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
        return result
    except Exception as e:
        # Chain the cause so the original traceback is kept.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/predict")
def predict(request: PredictionRequest):
    """Score the requested video and store the score on its ``Result`` row.

    The video at ``request.video_url`` is passed to
    ``genconvit_video_prediction``; the resulting ``score`` is written to the
    ``Result`` row whose ``query_id`` equals ``request.query_id``, and the
    prediction result is returned to the caller.

    Raises:
        HTTPException: 500 when the prediction or the database update fails;
            the detail carries the text of the underlying error.
    """
    try:
        result = genconvit_video_prediction(request.video_url)
        # Every per-detector entry is stored as the fixed string "0".
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        supabase_client.table('Result').update(transaction).eq('query_id', request.query_id).execute()
        return result
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/detect-faces")
def detect_faces(request: DetectFacesRequest):
    """Upload the frames produced for a video and return their public URLs.

    Each frame path is uploaded under the ``outputs`` prefix using the R2
    settings, and the matching public URL is collected in the returned list.

    Raises:
        HTTPException: 500 when frame extraction or an upload fails.
    """
    # NOTE(review): no other `detect_faces` is imported in the visible part of
    # this module, so the call below resolves to this endpoint itself and is
    # handed a plain string; reading `.video_url` on it fails and every request
    # ends in a 500. Import the real face-detection helper under a different
    # name (e.g. `... import detect_faces as _detect_faces`) and call that.
    public_base = 'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/outputs/'
    try:
        frames = detect_faces(request.video_url)

        res = []
        for frame in frames:
            # Last path component, computed once and shared by the upload key
            # and the public URL.
            file_name = frame.rsplit('/', 1)[-1]
            upload_file(frame, 'outputs', file_name, R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
            res.append(f'{public_base}{file_name}')

        return res
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: serve the app with uvicorn on 0.0.0.0:8000.
    uvicorn.run(
        app,
        port=8000,
        host='0.0.0.0',
    )
|
|
|
|
|
|
|
|
|
|