Bhushan26 commited on
Commit
cf1378b
1 Parent(s): 80222e0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +66 -47
main.py CHANGED
@@ -1,78 +1,97 @@
1
- from flask import Flask, request, jsonify, send_from_directory
 
 
2
  from gradio_client import Client, file
3
- from flask_cors import CORS
4
  import os
5
- import traceback
6
  import shutil
7
  import base64
 
 
 
8
 
9
- app = Flask(__name__)
10
- CORS(app)
 
 
 
 
 
 
11
 
 
12
  client = Client("kadirnar/IDM-VTON")
13
 
 
14
  UPLOAD_FOLDER = 'static/uploads'
15
  RESULT_FOLDER = 'static/results'
16
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
17
  os.makedirs(RESULT_FOLDER, exist_ok=True)
18
- app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
19
- app.config['RESULT_FOLDER'] = RESULT_FOLDER
20
 
 
 
 
21
 
22
- @app.route('/')
23
- def Home():
24
- return {"model": "Wearon is running"}
25
-
26
- @app.route('/process', methods=['POST'])
27
- def predict():
28
  try:
29
- product_image_url = request.form.get('product_image_url')
30
-
31
- if 'model_image' not in request.files:
32
- return jsonify(error='No model image file provided'), 400
33
-
34
- model_image = request.files['model_image']
35
- if model_image.filename == '':
36
- return jsonify(error='No selected file'), 400
37
-
38
- filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
39
- model_image.save(filename)
40
 
 
 
 
 
 
41
  base_path = os.getcwd()
42
  full_filename = os.path.normpath(os.path.join(base_path, filename))
43
-
44
  print("Product image = ", product_image_url)
45
  print("Model image = ", full_filename)
46
 
47
- result = client.predict(
48
- dict={"background": file(full_filename), "layers": [], "composite": None},
49
- garm_img=file(product_image_url),
50
- garment_des="Hello!!",
51
- is_checked=True,
52
- is_checked_crop=False,
53
- denoise_steps=30,
54
- seed=42,
55
- api_name="/tryon"
56
- )
57
-
 
 
 
 
 
58
  print(result)
 
59
  output_image_path = result[0]
60
-
 
61
  output_image_filename = os.path.basename(output_image_path)
62
- local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
63
  shutil.copy(output_image_path, local_output_path)
64
-
 
65
  os.remove(filename)
66
-
 
67
  with open(local_output_path, "rb") as image_file:
68
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
69
-
70
- return jsonify(image=encoded_image), 200
71
-
 
72
  except Exception as e:
73
  traceback.print_exc()
74
- return jsonify(error=str(e)), 500
75
 
76
- @app.route('/uploads/<filename>')
77
- def uploaded_file(filename):
78
- return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, Form, File, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from gradio_client import Client, file
 
5
  import os
 
6
  import shutil
7
  import base64
8
+ import traceback
9
+
10
+ app = FastAPI()
11
 
12
+ # Allow CORS
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"],
16
+ allow_credentials=True,
17
+ allow_methods=["*"],
18
+ allow_headers=["*"],
19
+ )
20
 
21
+ # client = Client("yisol/IDM-VTON")
22
  client = Client("kadirnar/IDM-VTON")
23
 
24
+ # Directory to save uploaded and processed files
25
  UPLOAD_FOLDER = 'static/uploads'
26
  RESULT_FOLDER = 'static/results'
27
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
28
  os.makedirs(RESULT_FOLDER, exist_ok=True)
 
 
29
 
30
+ @app.post("/")
31
+ async def hello():
32
+ return {"Wearon":"wearon model is running"}
33
 
34
+
35
+ @app.post("/process")
36
+ async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
 
 
 
37
  try:
38
+ if not model_image:
39
+ raise HTTPException(status_code=400, detail="No model image file provided")
 
 
 
 
 
 
 
 
 
40
 
41
+ # Save the uploaded file to the upload directory
42
+ filename = os.path.join(UPLOAD_FOLDER, model_image.filename)
43
+ with open(filename, "wb") as buffer:
44
+ shutil.copyfileobj(model_image.file, buffer)
45
+
46
  base_path = os.getcwd()
47
  full_filename = os.path.normpath(os.path.join(base_path, filename))
48
+
49
  print("Product image = ", product_image_url)
50
  print("Model image = ", full_filename)
51
 
52
+ # Perform prediction
53
+ try:
54
+ result = await client.predict(
55
+ dict={"background": file(full_filename), "layers": [], "composite": None},
56
+ garm_img=file(product_image_url),
57
+ garment_des="Hello!!",
58
+ is_checked=True,
59
+ is_checked_crop=False,
60
+ denoise_steps=30,
61
+ seed=42,
62
+ api_name="/tryon"
63
+ )
64
+ except Exception as e:
65
+ traceback.print_exc()
66
+ raise
67
+
68
  print(result)
69
+ # Extract the path of the first output image
70
  output_image_path = result[0]
71
+
72
+ # Copy the output image to the RESULT_FOLDER
73
  output_image_filename = os.path.basename(output_image_path)
74
+ local_output_path = os.path.join(RESULT_FOLDER, output_image_filename)
75
  shutil.copy(output_image_path, local_output_path)
76
+
77
+ # Remove the uploaded file after processing
78
  os.remove(filename)
79
+
80
+ # Encode the output image in base64
81
  with open(local_output_path, "rb") as image_file:
82
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
83
+
84
+ # Return the output image in JSON format
85
+ return JSONResponse(content={"image": encoded_image}, status_code=200)
86
+
87
  except Exception as e:
88
  traceback.print_exc()
89
+ raise HTTPException(status_code=500, detail=str(e))
90
 
91
+ @app.get("/uploads/{filename}")
92
+ async def uploaded_file(filename: str):
93
+ file_path = os.path.join(UPLOAD_FOLDER, filename)
94
+ if os.path.exists(file_path):
95
+ return FileResponse(file_path)
96
+ else:
97
+ raise HTTPException(status_code=404, detail="File not found")