"""FastAPI service that forwards virtual try-on requests to the IDM-VTON
Gradio Space and returns the generated image as base64-encoded JSON."""

from fastapi import FastAPI, UploadFile, Form, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, file
import os
import shutil
import base64
import traceback
import uuid

app = FastAPI()

# Allow CORS from any origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any website make credentialed requests to this service. Restrict the origin
# list before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Alternative Space that was previously configured here: "yisol/IDM-VTON".
# NOTE: constructing the Client contacts the Space at import time.
client = Client("kadirnar/IDM-VTON")

# Directories for uploaded model images and copied try-on results
# (relative paths, resolved against the process working directory).
UPLOAD_FOLDER = 'static/uploads'
RESULT_FOLDER = 'static/results'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)


@app.post("/")
async def hello():
    """Liveness check: report that the service is running."""
    return {"Wearon": "wearon model is running"}


@app.post("/process")
async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
    """Run the IDM-VTON try-on for an uploaded model photo and a garment image.

    Args:
        product_image_url: Location of the garment image, wrapped with
            gradio_client.file() and sent to the Space as ``garm_img``.
        model_image: Uploaded photo sent to the Space as the background layer.

    Returns:
        JSONResponse ``{"image": <base64 of the first output file>}``, status 200.

    Raises:
        HTTPException(400): no usable model image filename was provided.
        HTTPException(500): saving, prediction, copying or encoding failed.
    """
    # Keep only the final path component of the client-supplied name so a
    # name like "../../x" cannot write outside UPLOAD_FOLDER.
    original_name = os.path.basename(model_image.filename or "") if model_image else ""
    if not original_name:
        # Raised before the broad try/except below so it stays a 400
        # instead of being rewrapped as a 500.
        raise HTTPException(status_code=400, detail="No model image file provided")

    # Random prefix so concurrent uploads sharing a filename don't overwrite
    # each other's input while the (slow) prediction is running.
    filename = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4().hex}_{original_name}")
    # Same result as normpath(join(getcwd(), filename)).
    full_filename = os.path.abspath(filename)

    try:
        # Save the uploaded file to the upload directory.
        with open(filename, "wb") as buffer:
            shutil.copyfileobj(model_image.file, buffer)

        print("Product image = ", product_image_url)
        print("Model image = ", full_filename)

        # Client.predict is a blocking, synchronous call that returns the
        # result directly (it is not awaitable). Run it in a worker thread so
        # the event loop can keep serving other requests meanwhile.
        result = await run_in_threadpool(
            client.predict,
            dict={"background": file(full_filename), "layers": [], "composite": None},
            garm_img=file(product_image_url),
            garment_des="Hello!!",
            is_checked=True,
            is_checked_crop=False,
            denoise_steps=30,
            seed=42,
            api_name="/tryon",
        )
        print(result)

        # The first element of the result is taken as the output image path.
        output_image_path = result[0]

        # Copy the output image into RESULT_FOLDER under its original basename.
        output_image_filename = os.path.basename(output_image_path)
        local_output_path = os.path.join(RESULT_FOLDER, output_image_filename)
        shutil.copy(output_image_path, local_output_path)

        # Encode the copied output image in base64 for the JSON response.
        with open(local_output_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        return JSONResponse(content={"image": encoded_image}, status_code=200)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the uploaded file whether processing succeeded or failed.
        if os.path.exists(filename):
            os.remove(filename)


@app.get("/uploads/{filename}")
async def uploaded_file(filename: str):
    """Serve a regular file from UPLOAD_FOLDER by bare name, or return 404.

    Note that /process deletes its upload once it finishes, so files are
    only present here while a request is in progress.
    """
    # Reject anything that is not a bare file name (e.g. "..") so the
    # lookup cannot escape UPLOAD_FOLDER.
    if os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = os.path.join(UPLOAD_FOLDER, filename)
    # isfile (not exists) so directories are never handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path)