File size: 3,175 Bytes
cf1378b
 
 
46d3354
 
b4cb033
 
cf1378b
 
 
d7228b2
cf1378b
 
 
 
 
 
 
 
d7228b2
cf1378b
46d3354
d7228b2
cf1378b
46d3354
 
ec90435
 
b4cb033
cf1378b
 
 
80222e0
cf1378b
 
 
21d171d
cf1378b
 
ec90435
cf1378b
 
 
 
 
46d3354
 
cf1378b
ec90435
 
 
cf1378b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46d3354
cf1378b
b4cb033
cf1378b
 
b4cb033
cf1378b
ec90435
cf1378b
 
46d3354
cf1378b
 
ec90435
 
cf1378b
 
 
 
21d171d
b4cb033
cf1378b
21d171d
cf1378b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import base64
import os
import shutil
import traceback

from fastapi import FastAPI, UploadFile, Form, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from gradio_client import Client, file

# ASGI application object; the route decorators below register on it.
app = FastAPI()

# CORS policy: any origin, method and header is accepted, with credentials
# enabled.
# NOTE(review): this is wide open -- confirm it is intended outside development.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Gradio client bound to the "kadirnar/IDM-VTON" app; the /process route sends
# its prediction requests through it. ("yisol/IDM-VTON" is an alternative
# target that is currently disabled.)
client = Client("kadirnar/IDM-VTON")

# Working directories for incoming uploads and for copies of model outputs.
# Both are created at import time if they do not exist yet.
UPLOAD_FOLDER = 'static/uploads'
RESULT_FOLDER = 'static/results'
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_folder, exist_ok=True)

@app.post("/")
async def hello():
    """Liveness route: reply with a fixed message showing the service is up."""
    status_message = {"Wearon": "wearon model is running"}
    return status_message


@app.post("/process")
async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
    """Run the try-on model on an uploaded photo and return the result image.

    Args:
        product_image_url: Location of the garment image; it is handed to
            ``gradio_client.file`` unchanged.
        model_image: Uploaded photo used as the ``background`` layer of the
            request sent to the ``/tryon`` API.

    Returns:
        A ``JSONResponse`` with status 200 whose ``"image"`` key holds the
        first output image, base64-encoded.

    Raises:
        HTTPException: 400 when the upload carries no usable filename, 500
            (with the underlying error text as detail) for any failure while
            saving, predicting or reading the result.
    """
    # Validate outside the broad ``try`` below so the 400 is not turned into
    # a 500 by the generic handler.
    raw_name = (model_image.filename or "") if model_image else ""
    # The filename is client-controlled: keep only its last path component
    # (treating backslashes as separators too) so it cannot escape
    # UPLOAD_FOLDER.
    safe_name = os.path.basename(raw_name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="No model image file provided")

    filename = os.path.join(UPLOAD_FOLDER, safe_name)
    try:
        # Save the uploaded file to the upload directory
        with open(filename, "wb") as buffer:
            shutil.copyfileobj(model_image.file, buffer)

        full_filename = os.path.normpath(os.path.join(os.getcwd(), filename))

        print("Product image = ", product_image_url)
        print("Model image = ", full_filename)

        # ``Client.predict`` is a blocking call that returns its result
        # directly, so it must not be awaited itself. Run it in the worker
        # thread pool to keep the event loop free while the model works.
        result = await run_in_threadpool(
            client.predict,
            dict={"background": file(full_filename), "layers": [], "composite": None},
            garm_img=file(product_image_url),
            garment_des="Hello!!",
            is_checked=True,
            is_checked_crop=False,
            denoise_steps=30,
            seed=42,
            api_name="/tryon",
        )

        print(result)
        # The first element of the result is the path of the output image.
        output_image_path = result[0]

        # Keep a copy of the output image in RESULT_FOLDER
        local_output_path = os.path.join(RESULT_FOLDER, os.path.basename(output_image_path))
        shutil.copy(output_image_path, local_output_path)

        # Encode the output image in base64
        with open(local_output_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        # Return the output image in JSON format
        return JSONResponse(content={"image": encoded_image}, status_code=200)

    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Remove the uploaded file whether or not processing succeeded, so
        # failed requests do not accumulate files in UPLOAD_FOLDER.
        try:
            if os.path.exists(filename):
                os.remove(filename)
        except OSError:
            # Cleanup is best-effort; never mask the real response with it.
            traceback.print_exc()

@app.get("/uploads/{filename}")
async def uploaded_file(filename: str):
    """Serve a file stored directly inside ``UPLOAD_FOLDER``.

    Args:
        filename: Name of the file to serve, taken from the URL path.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to a regular file in ``UPLOAD_FOLDER``.
    """
    # Only bare file names are accepted: anything that changes when reduced
    # to its last path component (e.g. "..\\secret" or "a/b") is rejected so
    # a request cannot reach outside UPLOAD_FOLDER.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    file_path = os.path.join(UPLOAD_FOLDER, safe_name)

    # isfile (not exists) so directories such as ".." yield a clean 404
    # instead of failing later inside the response.
    if safe_name != filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(file_path)