# Hugging Face Spaces page status at capture time: Sleeping
import asyncio
import math
import os
import secrets
import time

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi
from pydantic import BaseModel
# Pull HF_TOKEN / PASSWORD / HF_SPACE_ID / HF_DATASET_ID from a local .env
# file when present (real environment variables are used otherwise).
load_dotenv()
api = HfApi(token=os.getenv("HF_TOKEN"))
# Shared access password; None when the variable is unset.
PASSWORD = os.getenv("PASSWORD")
app = FastAPI()

# hf.space subdomain of this Space, derived from HF_SPACE_ID ("owner/name").
# Browsers send the Host/Origin hostname in lowercase and both middlewares
# below compare it exactly, so the id is lowercased; "/", "_" and "." are
# mapped to "-" to follow Hugging Face's subdomain naming.
# NOTE(review): confirm the mapping against the live URL for ids with "_"/".".
repo_url = (
    os.environ["HF_SPACE_ID"]
    .lower()
    .replace("/", "-")
    .replace("_", "-")
    .replace(".", "-")
)
space_host = f"{repo_url}.hf.space"

# Only answer requests whose Host header names this Space or a local address.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["localhost", "127.0.0.1", space_host],
)
# Browser cross-origin access is limited to the local dev server and the Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:7860", f"https://{space_host}"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Rate limiting
class RateLimiter:
    """Sliding-window rate limiter keyed by client IP, held in process memory.

    Each allowed call to :meth:`check_rate_limit` records one attempt for the
    IP. When the IP already has ``max_attempts`` attempts within the last
    ``window_seconds`` seconds, the call is rejected with HTTP 429; rejected
    calls are not recorded. State is a plain dict, so it is per-process and
    lost on restart.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 300):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # ip -> time.time() stamps of that IP's recorded attempts
        self.attempts = {}
        # When IPs without in-window attempts were last purged (see _sweep).
        self._last_sweep = time.time()

    def _recent(self, stamps, now):
        """Return the timestamps from *stamps* that are still inside the window."""
        return [t for t in stamps if now - t < self.window_seconds]

    def _sweep(self, now: float) -> None:
        """Drop every IP whose recorded attempts have all expired.

        Without this, each IP that ever called in would keep a dict entry
        forever, so memory would grow with the number of distinct clients.
        """
        fresh = {}
        for ip, stamps in self.attempts.items():
            recent = self._recent(stamps, now)
            if recent:
                fresh[ip] = recent
        self.attempts = fresh
        self._last_sweep = now

    async def check_rate_limit(self, ip: str) -> bool:
        """Record an attempt for *ip*, or reject it if *ip* is over the limit.

        Args:
            ip: Client address used as the rate-limit key.

        Returns:
            True when the attempt is allowed (it is then recorded).

        Raises:
            HTTPException: 429, with a ``Retry-After`` header, when *ip*
                already has ``max_attempts`` attempts inside the window.
        """
        now = time.time()
        # Purge idle IPs at most once per window to keep per-call cost low.
        if now - self._last_sweep >= self.window_seconds:
            self._sweep(now)

        recent = self._recent(self.attempts.get(ip, ()), now)
        self.attempts[ip] = recent
        if len(recent) >= self.max_attempts:
            # The IP is unblocked once its oldest in-window attempt ages out.
            oldest = min(recent, default=now)
            retry_after = max(1, math.ceil(self.window_seconds - (now - oldest)))
            raise HTTPException(
                status_code=429,
                detail=f"Too many attempts. Try again in {retry_after} seconds",
                headers={"Retry-After": str(retry_after)},
            )
        recent.append(now)
        return True


rate_limiter = RateLimiter()
# Request body for the password check, e.g. {"password": "..."}.
class PasswordCheck(BaseModel):
    password: str  # candidate password supplied by the client
async def verify_password(password_check: PasswordCheck, request: Request):
    """Check the submitted password and, if it is correct, list the dataset files.

    Every allowed call, successful or not, counts against the caller's rate
    limit.

    Args:
        password_check: Request body carrying the candidate password.
        request: Incoming request; its client host keys the rate limiter.

    Returns:
        Sorted list of file paths in the ``HF_DATASET_ID`` dataset repo.

    Raises:
        HTTPException: 429 when the caller is rate limited; 401 when the
            password is wrong or no non-empty ``PASSWORD`` is configured.
    """
    # NOTE(review): no route decorator is visible on this function, so as
    # shown it is never registered on `app` -- confirm one was not lost.
    # NOTE(review): behind a reverse proxy, request.client.host may be the
    # proxy's address, making this limit shared by all users -- confirm.
    client_ip = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_ip)

    # An unset or empty PASSWORD must never grant access. compare_digest's
    # running time does not depend on where the inputs differ, so response
    # timing cannot leak the password; comparing bytes avoids its ASCII-only
    # restriction on str arguments.
    if PASSWORD and secrets.compare_digest(
        password_check.password.encode("utf-8"), PASSWORD.encode("utf-8")
    ):
        # list_repo_files is a blocking HTTP call; keep it off the event loop.
        items = await asyncio.to_thread(
            api.list_repo_files,
            repo_id=os.environ["HF_DATASET_ID"],
            repo_type="dataset",
        )
        return sorted(items)
    raise HTTPException(status_code=401, detail="Invalid password")
async def download_item(item_name: str, request: Request):
    """Fetch one file from the dataset repo and send it to the client.

    Args:
        item_name: Path of the file inside the ``HF_DATASET_ID`` dataset repo.
        request: Incoming request; its client host keys the rate limiter.

    Returns:
        A FileResponse for the locally cached copy, with *item_name* as the
        download filename.

    Raises:
        HTTPException: 429 when the caller is rate limited; 404 when
            *item_name* is not a file in the dataset repo.
    """
    # NOTE(review): unlike verify_password, this handler never checks the
    # password, so anyone who knows or guesses a file name can download it.
    # Requiring a credential here would change the HTTP API (and the static
    # front-end), so it is flagged rather than changed.
    # NOTE(review): no route decorator is visible on this function -- confirm
    # one was not lost.
    client_ip = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_ip)

    repo_id = os.environ["HF_DATASET_ID"]
    # Both hub calls block on network I/O, so they run in a worker thread
    # instead of stalling the event loop.
    available = await asyncio.to_thread(
        api.list_repo_files, repo_id=repo_id, repo_type="dataset"
    )
    # Only serve names that exist in the repo: unknown names get a clean 404
    # instead of an unhandled hub-client error (500).
    if item_name not in available:
        raise HTTPException(status_code=404, detail="Item not found")
    filepath = await asyncio.to_thread(
        api.hf_hub_download,
        repo_id=repo_id,
        filename=item_name,
        repo_type="dataset",
    )
    return FileResponse(filepath, filename=item_name)
# Serve the front-end from ./static; with html=True, directory requests are
# answered with that directory's index.html. A mount at "/" matches every
# path, so it is attached after the rest of the app is configured.
static_files = StaticFiles(directory="static", html=True)
app.mount("/", static_files, name="static")

if __name__ == "__main__":
    # Local entry point. Port 7860 matches the http://localhost:7860 origin
    # allowed by the CORS configuration.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")