File size: 4,137 Bytes
9c4b01e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
import os
from prediction import genconvit_video_prediction
from utils.db import supabase_client
import json
import requests
from utils.utils import upload_file
import redis
from rq import Queue, Worker, Connection
import uvicorn
import torch
# Point torch.hub and the Hugging Face hub caches at a local ./cache directory.
# NOTE(review): these lines run after `from prediction import ...` above. If that
# module (or anything it imports) imports huggingface_hub, the hub may already
# have read HUGGINGFACE_HUB_CACHE at import time and ignore this value — confirm,
# and move these assignments above the imports if so.
os.environ['TORCH_HOME'] = './cache'
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
# Load variables from a .env file into os.environ.
# NOTE(review): with python-dotenv's default (override=False), values already set
# above are presumably not overridden by .env. Also, this runs after
# `from utils.db import supabase_client`; if utils.db reads its settings from the
# environment at import time, .env values will not be visible to it — verify.
load_dotenv()
# Settings read from the process environment (after load_dotenv() has run).
# Any variable that is not set resolves to None.
# Cloudflare R2 object storage credentials and endpoint.
R2_ACCESS_KEY = os.environ.get('R2_ACCESS_KEY')
R2_SECRET_KEY = os.environ.get('R2_SECRET_KEY')
R2_BUCKET_NAME = os.environ.get('R2_BUCKET_NAME')
R2_ENDPOINT_URL = os.environ.get('R2_ENDPOINT_URL')
# Upstash Redis endpoint and token (used for the job queue connection).
UPSTASH_REDIS_REST_URL = os.environ.get('UPSTASH_REDIS_REST_URL')
UPSTASH_REDIS_REST_TOKEN = os.environ.get('UPSTASH_REDIS_REST_TOKEN')
# Redis connection backing the RQ job queue.
def _redis_host_from_url(url):
    """Return the bare hostname contained in *url*.

    redis.Redis(host=...) needs a hostname, but UPSTASH_REDIS_REST_URL holds a
    REST endpoint URL (e.g. ``https://<db>.upstash.io``); passing the full URL
    as ``host`` cannot be resolved. Values without a ``scheme://`` prefix
    (plain hostnames) and None/empty values are returned unchanged.
    """
    from urllib.parse import urlsplit

    if not url:
        return url
    # urlsplit only fills .hostname when a "//netloc" part is present; plain
    # hostnames come back as None and fall through to the original value.
    return urlsplit(url).hostname or url


r = redis.Redis(
    host=_redis_host_from_url(UPSTASH_REDIS_REST_URL),
    port=6379,
    # NOTE(review): confirm the REST token is accepted as the password for the
    # Redis TCP endpoint of this database; the provider may issue a separate one.
    password=UPSTASH_REDIS_REST_TOKEN,
    ssl=True,
)
q = Queue('video-predictions', connection=r)
# FastAPI initialization
app = FastAPI()

# CORS: allowed origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable, defaulting to "*" (the previous hard-coded value).
# Restrict it to your front-end domain(s) in production: combining a wildcard
# origin with allow_credentials=True lets Starlette echo the caller's Origin on
# credentialed requests, so any site could call this API with user cookies.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv('CORS_ALLOW_ORIGINS', '*').split(',')
    if origin.strip()
] or ["*"]  # an empty/blank setting keeps the old wildcard behaviour

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models for request validation
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""

    # URL handed to genconvit_video_prediction.
    video_url: str
    # Matched against Result.query_id when the score is written back.
    query_id: str
class DetectFacesRequest(BaseModel):
    """Request body for POST /detect-faces."""

    # URL of the video submitted for face detection.
    video_url: str
# Prediction queue resolver
def predictionQueueResolver(prediction_data):
    """Run one queued prediction and store its score on the matching Result row.

    Job function for the 'video-predictions' queue (see the commented-out
    enqueue code in this module). HTTPException is kept as the error type for
    backward compatibility even though it carries no HTTP meaning in a worker.

    Args:
        prediction_data: JSON text (str/bytes) or an already-decoded dict with
            ``mediaUrl`` (video URL) and ``queryId`` (Result.query_id) keys.

    Returns:
        The dict returned by ``genconvit_video_prediction``.

    Raises:
        HTTPException: 400 if the payload is not valid JSON / not an object,
            ``mediaUrl`` is missing or empty, or ``queryId`` is missing;
            500 (with the original error chained) if prediction or the
            database update fails.
    """
    if isinstance(prediction_data, (str, bytes, bytearray)):
        try:
            data = json.loads(prediction_data)
        except ValueError as exc:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            raise HTTPException(status_code=400, detail="Invalid prediction payload") from exc
    else:
        data = prediction_data
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid prediction payload")

    # .get() so a missing key reaches the explicit 400s below instead of
    # escaping as an unhandled KeyError.
    video_url = data.get('mediaUrl')
    query_id = data.get('queryId')
    if not video_url:
        raise HTTPException(status_code=400, detail="No video URL provided")
    if query_id is None:
        raise HTTPException(status_code=400, detail="No query ID provided")

    try:
        result = genconvit_video_prediction(video_url)
        # Only the score is model-derived; these output fields are stored as
        # fixed "0" strings.
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        print(result)
        supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return result
# @app.get("/")
# def health():
# return "APP is Ready"
# @app.get("/health")
# def health():
# return "Healthy AI API"
@app.post("/predict")
def predict(request: PredictionRequest):
    """Score a video with GenConViT and store the score on its Result row.

    Args:
        request: ``video_url`` is passed to ``genconvit_video_prediction``;
            ``query_id`` selects the Result row to update.

    Returns:
        The dict returned by ``genconvit_video_prediction`` (it must contain
        a ``'score'`` key).

    Raises:
        HTTPException: 500 with the underlying error message if prediction or
            the Result update fails; the original exception is chained as the
            cause so its traceback is not lost.
    """
    try:
        result = genconvit_video_prediction(request.video_url)
        # Only the score is model-derived; these output fields are stored as
        # fixed "0" strings.
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        supabase_client.table('Result').update(transaction).eq('query_id', request.query_id).execute()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return result
# Face-detection callable used by POST /detect-faces. It is called with the
# request's video URL and must return an iterable of '/'-separated local file
# path strings (one per frame) to upload. It is None until the real helper is
# wired in: import it under a name that does not clash with the endpoint below
# (e.g. `... import detect_faces as extract_face_frames`) and assign
# `FACE_DETECTOR = extract_face_frames`.
FACE_DETECTOR = None

# Public base URL of the R2 bucket serving uploaded frames (overridable).
R2_PUBLIC_URL = os.getenv(
    'R2_PUBLIC_URL', 'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev'
).rstrip('/')


@app.post("/detect-faces")
def detect_faces(request: DetectFacesRequest):
    """Detect faces in a video and return public URLs of the uploaded frames.

    Each frame returned by ``FACE_DETECTOR`` is uploaded to the ``outputs``
    folder of R2 under its file name. Its public URL is built from
    ``R2_PUBLIC_URL``.

    Previously this handler called ``detect_faces(request.video_url)``, which
    re-entered the endpoint itself with a plain str. Every request then failed
    with a 500 ("'str' object has no attribute 'video_url'"). The detector is
    now looked up explicitly so the missing wiring is reported clearly.

    Returns:
        list[str]: public frame URLs, in the order the detector returned them.

    Raises:
        HTTPException: 501 if no face detector is configured; 500 (original
            error chained) if detection or an upload fails.
    """
    if FACE_DETECTOR is None:
        raise HTTPException(status_code=501, detail="Face detection is not configured")
    try:
        frames = FACE_DETECTOR(request.video_url)
        res = []
        for frame in frames:
            file_name = frame.split('/')[-1]
            upload_file(frame, 'outputs', file_name, R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
            res.append(f'{R2_PUBLIC_URL}/outputs/{file_name}')
        return res
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# Uncomment to start worker and fetch queue data
# def fetch_and_enqueue():
# response = requests.get(UPSTASH_REDIS_REST_URL)
# if response.status_code == 200:
# data = response.json()
# for item in data['items']:
# prediction_data = item.get('prediction')
# q.enqueue(predictionQueueResolver, prediction_data)
if __name__ == '__main__':
    # Serve the API on every interface, port 8000, when run as a script.
    _serve_host, _serve_port = '0.0.0.0', 8000
    uvicorn.run(app, host=_serve_host, port=_serve_port)
    # Alternative: run an RQ worker on the prediction queue.
    # with Connection(r):
    #     worker = Worker([q])
    #     worker.work()
    # fetch_and_enqueue()
|