"""FastAPI front end that forwards uploaded images to a Gradio Space.

The ``/upload/`` endpoint saves the uploaded file to a temporary path and
sends it to the ``/predict`` API of the Gradio app at
``https://ashrafb-ifur.hf.space/`` together with the ``version`` and ``scale``
form fields. It returns the first prediction output inlined as a base64 data
URI and the second output unchanged.
"""
import base64
import io
import logging
import mimetypes
import os
import tempfile

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    StreamingResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client

logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face token used to access the Space; None when HF_TOKEN is unset.
hf_token = os.environ.get('HF_TOKEN')
client = Client("https://ashrafb-ifur.hf.space/", hf_token=hf_token)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any site issue credentialed requests to this API; restrict the origin list
# before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust as needed, '*' allows requests from any origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Media type used in the returned data URI when the output file's extension
# does not map to an image type (this was previously always used).
_DEFAULT_IMAGE_MEDIA_TYPE = "image/png"


async def _save_upload_to_temp(file: UploadFile) -> str:
    """Write the uploaded file's bytes to a new temporary file and return its path.

    The upload's filename extension (if any) is kept as the temp file's
    suffix so the path preserves the original file type. If reading or
    writing fails, the temp file is removed before the exception propagates.
    On success the caller is responsible for deleting the returned path.
    """
    suffix = os.path.splitext(file.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        try:
            temp_file.write(await file.read())
        except BaseException:
            # Close before unlinking (required on Windows), then re-raise so
            # the failure surfaces exactly as it did before.
            temp_file.close()
            os.unlink(temp_file.name)
            raise
        return temp_file.name


def _encode_image_data_uri(path: str) -> str:
    """Return the file at ``path`` as a ``data:<media type>;base64,...`` URI.

    The media type is guessed from the file extension and falls back to
    ``image/png`` when the guess is missing or is not an image type.
    """
    media_type = mimetypes.guess_type(path)[0]
    if not media_type or not media_type.startswith("image/"):
        media_type = _DEFAULT_IMAGE_MEDIA_TYPE
    with open(path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{media_type};base64,{image_data}"


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), version: str = Form(...), scale: int = Form(...)):
    """Run the uploaded image through the Gradio ``/predict`` API.

    Args:
        file: Uploaded file, forwarded to the prediction API as a local path.
        version: Passed through unchanged as the second prediction argument.
        scale: Passed through unchanged as the third prediction argument.

    Returns:
        On success, ``{"sketch_image_base64": <data URI of output 0>,
        "result_file": <output 1 as returned by the API>}``. On failure,
        ``{"error": <message>}``. Failures are reported in the body with
        HTTP 200, matching the endpoint's existing contract.
    """
    temp_file_path = await _save_upload_to_temp(file)
    try:
        # client.predict is a synchronous network call; running it in a worker
        # thread keeps the event loop free to serve other requests meanwhile.
        result = await run_in_threadpool(
            client.predict, temp_file_path, version, scale, api_name="/predict"
        )
        # The API is expected to return exactly two outputs.
        if isinstance(result, (list, tuple)) and len(result) == 2:
            return {
                "sketch_image_base64": _encode_image_data_uri(result[0]),
                "result_file": result[1],
            }
        return {"error": "Invalid result from the prediction API."}
    except Exception as e:
        # Keep the {"error": ...} response shape clients already handle, but
        # record the traceback instead of discarding it.
        logger.exception("Prediction failed for upload %r", file.filename)
        return {"error": str(e)}
    finally:
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)