# mimosa-ai / gradio1.py
# Uploaded by vivekk3 via huggingface_hub ("Upload folder using huggingface_hub"), commit 9c4b01e.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
import os
from prediction import genconvit_video_prediction
from utils.db import supabase_client
import json
import requests
from utils.utils import upload_file
import redis
from rq import Queue, Worker, Connection
import uvicorn
import torch
# Point the torch hub and Hugging Face hub caches at a project-local ./cache
# directory, then load any variables defined in a .env file.
# NOTE(review): these assignments run after `torch` and `prediction` have
# already been imported above; a library that reads its cache location at
# import time may not pick them up — consider moving them before those imports.
os.environ.update({'TORCH_HOME': './cache', 'HUGGINGFACE_HUB_CACHE': './cache'})
torch.hub.set_dir('./cache')
load_dotenv()
# Environment variables (each is None when unset).
# R2 object-storage settings, passed to upload_file by the /detect-faces endpoint.
R2_ACCESS_KEY = os.environ.get('R2_ACCESS_KEY')
R2_SECRET_KEY = os.environ.get('R2_SECRET_KEY')
R2_BUCKET_NAME = os.environ.get('R2_BUCKET_NAME')
R2_ENDPOINT_URL = os.environ.get('R2_ENDPOINT_URL')
# Upstash settings, used to build the Redis connection for the job queue.
UPSTASH_REDIS_REST_URL = os.environ.get('UPSTASH_REDIS_REST_URL')
UPSTASH_REDIS_REST_TOKEN = os.environ.get('UPSTASH_REDIS_REST_TOKEN')
# Redis connection
def _redis_host_from_url(url):
    """Return the hostname to pass as redis-py's ``host`` argument.

    ``UPSTASH_REDIS_REST_URL`` is, by its name, an HTTP(S) REST endpoint
    (e.g. ``https://<name>.upstash.io``), but ``host`` must be a bare hostname;
    passing the full URL makes the connection fail on first use. The scheme,
    credentials, port and path are therefore stripped. If no hostname can be
    parsed (bare hostname, ``None``, empty string, malformed URL) the value is
    returned unchanged, matching the previous behaviour.
    """
    if not url:
        return url
    from urllib.parse import urlsplit
    try:
        hostname = urlsplit(url).hostname
    except ValueError:
        # e.g. an unbalanced IPv6 bracket; fall back to the raw value.
        return url
    return hostname or url


# NOTE(review): the REST token is used as the Redis password and the port is
# fixed at 6379 with TLS — confirm these match the provider's Redis endpoint.
# Connection errors presumably surface only when the queue is first used.
r = redis.Redis(
    host=_redis_host_from_url(UPSTASH_REDIS_REST_URL),
    port=6379,
    password=UPSTASH_REDIS_REST_TOKEN,
    ssl=True,
)
q = Queue('video-predictions', connection=r)
# FastAPI initialization
app = FastAPI()

# CORS middleware.
# Allowed origins are read from CORS_ALLOW_ORIGINS as a comma-separated list
# (e.g. "https://app.example.com,https://admin.example.com"). When the variable
# is unset or empty, the previous default of allowing every origin ("*") is kept.
# NOTE(review): "*" together with allow_credentials=True is very permissive —
# set CORS_ALLOW_ORIGINS to your real domain(s) in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv('CORS_ALLOW_ORIGINS', '').split(',')
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models for request validation
class DetectFacesRequest(BaseModel):
    """JSON body accepted by POST /detect-faces."""

    video_url: str  # video whose faces are extracted


class PredictionRequest(BaseModel):
    """JSON body accepted by POST /predict."""

    video_url: str  # video passed to genconvit_video_prediction
    query_id: str  # selects the `Result` row that receives the score
# Prediction queue resolver
def predictionQueueResolver(prediction_data):
    """Run a queued video prediction and store its score in the `Result` table.

    Args:
        prediction_data: JSON text (``str``/``bytes``) or an already-decoded
            mapping with keys ``'mediaUrl'`` (video to analyse) and
            ``'queryId'`` (identifies the ``Result`` row to update).

    Returns:
        The value returned by ``genconvit_video_prediction``. It is indexed with
        ``'score'`` here, so it is expected to be a mapping containing that key.

    Raises:
        HTTPException: 400 if ``mediaUrl`` is missing/empty or ``queryId`` is
            missing; 500 wrapping any error from prediction or the DB update.

    NOTE(review): HTTPException is raised even though this function is meant
    to run outside a request (e.g. as an RQ job); kept for compatibility.
    """
    # Accept both raw JSON and a pre-decoded payload.
    if isinstance(prediction_data, (str, bytes, bytearray)):
        data = json.loads(prediction_data)
    else:
        data = prediction_data

    # .get() so a missing key reaches the explicit validation below instead of
    # escaping as a bare KeyError.
    video_url = data.get('mediaUrl')
    query_id = data.get('queryId')
    if not video_url:
        raise HTTPException(status_code=400, detail="No video URL provided")
    if query_id is None:
        raise HTTPException(status_code=400, detail="No query ID provided")

    try:
        result = genconvit_video_prediction(video_url)
        # Sub-model outputs are hard-coded to "0" here; presumably placeholders
        # for detectors that are not run by this service — confirm.
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        print(result)
        supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# @app.get("/")
# def health():
# return "APP is Ready"
# @app.get("/health")
# def health():
# return "Healthy AI API"
@app.post("/predict")
def predict(request: PredictionRequest):
    """Score a video and record the result against ``request.query_id``.

    Runs ``genconvit_video_prediction`` on ``request.video_url``. It then updates
    the ``Result`` row whose ``query_id`` matches, setting status ``"success"``,
    the score, and a JSON ``output`` blob.

    Returns:
        The prediction result, which is indexed with ``'score'`` below.

    Raises:
        HTTPException: 400 if ``video_url`` is empty (the same check
            ``predictionQueueResolver`` performs); 500 wrapping any error from
            prediction or the DB update.
    """
    # Validate before the try so the 400 is not re-wrapped as a 500.
    if not request.video_url:
        raise HTTPException(status_code=400, detail="No video URL provided")
    try:
        result = genconvit_video_prediction(request.video_url)
        # Sub-model outputs are hard-coded to "0"; presumably placeholders —
        # confirm.
        output = {
            "fd": "0",
            "gan": "0",
            "wave_grad": "0",
            "wave_rnn": "0",
        }
        transaction = {
            "status": "success",
            "score": result['score'],
            "output": json.dumps(output),
        }
        supabase_client.table('Result').update(transaction).eq('query_id', request.query_id).execute()
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Public base URL under which uploaded frames are served. It can be overridden
# with R2_PUBLIC_BASE_URL; the default is the URL previously hard-coded here.
R2_PUBLIC_BASE_URL = os.getenv(
    'R2_PUBLIC_BASE_URL',
    'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev',
).rstrip('/')


def _load_face_detector():
    """Return the face-detection helper, or None if it cannot be imported.

    The endpoint below is itself named ``detect_faces``, so the old call
    ``detect_faces(request.video_url)`` recursed into the endpoint with a plain
    string and always failed. The helper is therefore resolved explicitly.
    """
    try:
        # NOTE(review): assumes the helper lives in `prediction` (next to
        # genconvit_video_prediction), takes a video URL, and returns local
        # frame file paths — confirm the real location/signature.
        from prediction import detect_faces as face_detector
    except ImportError:
        return None
    return face_detector


@app.post("/detect-faces")
def detect_faces(request: DetectFacesRequest):
    """Extract face frames from a video, upload them to R2 and return their URLs.

    Each frame path returned by the detector is uploaded under the ``outputs``
    prefix via ``upload_file``. The response is the list of public URLs, one
    per frame.

    Raises:
        HTTPException: 501 if no face-detection helper is available; 500
            wrapping any error from detection or upload.
    """
    face_detector = _load_face_detector()
    if face_detector is None:
        # Raised outside the try so it is not re-wrapped as a 500.
        raise HTTPException(status_code=501, detail="Face detection is not available")
    try:
        frames = face_detector(request.video_url)
        urls = []
        for frame in frames:
            frame_path = f'{frame}'
            file_name = frame_path.rsplit('/', 1)[-1]
            upload_file(frame_path, 'outputs', file_name, R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
            urls.append(f'{R2_PUBLIC_BASE_URL}/outputs/{file_name}')
        return urls
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Uncomment to start worker and fetch queue data
# def fetch_and_enqueue():
# response = requests.get(UPSTASH_REDIS_REST_URL)
# if response.status_code == 200:
# data = response.json()
# for item in data['items']:
# prediction_data = item.get('prediction')
# q.enqueue(predictionQueueResolver, prediction_data)
if __name__ == '__main__':
    # Serve the API on all interfaces. The port may be overridden with the PORT
    # environment variable (a common hosting-platform convention); it defaults
    # to 8000 as before.
    uvicorn.run(app, host='0.0.0.0', port=int(os.getenv('PORT', '8000')))
    # Disabled RQ worker path (uvicorn.run blocks, so this would only run
    # after the server stops).
    # NOTE(review): newer rq releases removed `Connection`; confirm the pinned
    # rq version before re-enabling.
    # with Connection(r):
    # worker = Worker([q])
    # worker.work()
    # fetch_and_enqueue()