Bhushan26 commited on
Commit
46d3354
·
verified ·
1 Parent(s): a8d2ccf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -81
main.py CHANGED
@@ -1,128 +1,98 @@
1
- from flask import Flask
 
 
 
 
2
  import shutil
3
  import base64
4
- from gradio_client import Client, file
5
-
 
 
6
 
 
 
7
 
 
8
 
9
-
10
-
11
- app = Flask(__name__)
12
- CORS(app)
13
  # Directory to save uploaded and processed files
14
- UPLOAD_FOLDER = tempfile.mkdtemp()
15
- RESULT_FOLDER = tempfile.mkdtemp()
 
 
 
 
16
 
17
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
18
  app.config['RESULT_FOLDER'] = RESULT_FOLDER
19
 
20
-
21
-
22
-
23
- def predict_with_timeout(model_image_path, product_image_url, timeout=600):
24
- result = [None] # Mutable object to store the result
25
-
26
- def target():
27
- try:
28
-
29
- result[0] = client.predict(
30
- dict({"background": file(model_image_path), "layers": [], "composite": None}),
31
- garm_img=file(product_image_url),
32
- seed=42,
33
- api_name="/tryon"
34
- )
35
-
36
- except Exception as e:
37
-
38
- result[0] = str(e)
39
-
40
- thread = threading.Thread(target=target)
41
- thread.start()
42
- thread.join(timeout)
43
-
44
- if thread.is_alive():
45
- return "Prediction timed out after {} seconds".format(timeout)
46
-
47
- if isinstance(result[0], Exception):
48
- return str(result[0]) # Return the error message
49
- return result[0]
50
-
51
- @app.route('/')
52
- def index():
53
-
54
- return {'message': 'This is a wearon API'}
55
-
56
  @app.route('/process', methods=['POST'])
57
- def predict():
58
-
59
  try:
60
  # Get the product image URL from the request
61
- product_image_url = request.form.get('product_image_url')
62
- if not product_image_url:
63
-
64
- return jsonify(error='No product image URL provided'), 400
65
 
66
  # Handle the uploaded model image
67
  if 'model_image' not in request.files:
68
-
69
  return jsonify(error='No model image file provided'), 400
70
 
71
- model_image = request.files['model_image']
72
  if model_image.filename == '':
73
-
74
  return jsonify(error='No selected file'), 400
75
 
76
  # Save the uploaded file to the upload directory
77
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
78
- model_image.save(filename)
79
 
80
- full_filename = os.path.abspath(filename)
 
81
 
82
- print("Product image URL:", product_image_url)
83
- print("Model image path:", full_filename)
84
 
85
- # Perform prediction with a timeout
86
- result = predict_with_timeout(full_filename, product_image_url)
87
- if isinstance(result, str):
88
-
89
- return jsonify(error=result), 500
90
-
91
- print("Prediction result:", result)
 
 
 
 
 
 
 
 
92
 
93
- # Check if the result contains a valid path
 
94
  output_image_path = result[0]
95
- if not os.path.exists(output_image_path):
96
- return jsonify(error='Output image file not found: {}'.format(output_image_path)), 500
97
-
98
 
99
  # Copy the output image to the RESULT_FOLDER
100
  output_image_filename = os.path.basename(output_image_path)
101
  local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
102
- shutil.copy(output_image_path, local_output_path)
103
-
104
 
105
  # Remove the uploaded file after processing
106
- os.remove(full_filename)
107
-
108
 
109
  # Encode the output image in base64
110
- with open(local_output_path, "rb") as image_file:
111
- encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
112
 
113
  # Return the output image in JSON format
114
-
115
  return jsonify(image=encoded_image), 200
116
 
117
  except Exception as e:
118
-
119
  traceback.print_exc()
120
  return jsonify(error=str(e)), 500
121
 
122
  @app.route('/uploads/<filename>')
123
- def uploaded_file(filename):
124
-
125
- return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
126
 
127
  if __name__ == '__main__':
128
- app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, request, jsonify, send_from_directory
2
+ from gradio_client import Client, file
3
+ from flask_cors import CORS
4
+ import os
5
+ import traceback
6
  import shutil
7
  import base64
8
+ import asyncio
9
+ from quart import Quart, request, jsonify, send_from_directory
10
+ from quart_cors import cors
11
+ import aiofiles
12
 
13
+ app = Quart(__name__)
14
+ cors(app)
15
 
16
+ client = Client("kadirnar/IDM-VTON")
17
 
 
 
 
 
18
  # Directory to save uploaded and processed files
19
+ UPLOAD_FOLDER = 'static/uploads'
20
+ RESULT_FOLDER = 'static/results'
21
+ if not os.path.exists(UPLOAD_FOLDER):
22
+ os.makedirs(UPLOAD_FOLDER)
23
+ if not os.path.exists(RESULT_FOLDER):
24
+ os.makedirs(RESULT_FOLDER)
25
 
26
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
27
  app.config['RESULT_FOLDER'] = RESULT_FOLDER
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @app.route('/process', methods=['POST'])
30
+ async def predict():
 
31
  try:
32
  # Get the product image URL from the request
33
+ form = await request.form
34
+ product_image_url = form.get('product_image_url')
 
 
35
 
36
  # Handle the uploaded model image
37
  if 'model_image' not in request.files:
 
38
  return jsonify(error='No model image file provided'), 400
39
 
40
+ model_image = await request.files['model_image']
41
  if model_image.filename == '':
 
42
  return jsonify(error='No selected file'), 400
43
 
44
  # Save the uploaded file to the upload directory
45
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
46
+ await model_image.save(filename)
47
 
48
+ base_path = os.getcwd()
49
+ full_filename = os.path.normpath(os.path.join(base_path, filename))
50
 
51
+ print("Product image = ", product_image_url)
52
+ print("Model image = ", full_filename)
53
 
54
+ # Perform prediction
55
+ try:
56
+ result = await asyncio.to_thread(client.predict,
57
+ dict={"background": file(full_filename), "layers": [], "composite": None},
58
+ garm_img=file(product_image_url),
59
+ garment_des="Hello!!",
60
+ is_checked=True,
61
+ is_checked_crop=False,
62
+ denoise_steps=30,
63
+ seed=42,
64
+ api_name="/tryon"
65
+ )
66
+ except Exception as e:
67
+ traceback.print_exc()
68
+ raise
69
 
70
+ print(result)
71
+ # Extract the path of the first output image
72
  output_image_path = result[0]
 
 
 
73
 
74
  # Copy the output image to the RESULT_FOLDER
75
  output_image_filename = os.path.basename(output_image_path)
76
  local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
77
+ await asyncio.to_thread(shutil.copy, output_image_path, local_output_path)
 
78
 
79
  # Remove the uploaded file after processing
80
+ os.remove(filename)
 
81
 
82
  # Encode the output image in base64
83
+ async with aiofiles.open(local_output_path, "rb") as image_file:
84
+ encoded_image = base64.b64encode(await image_file.read()).decode('utf-8')
85
 
86
  # Return the output image in JSON format
 
87
  return jsonify(image=encoded_image), 200
88
 
89
  except Exception as e:
 
90
  traceback.print_exc()
91
  return jsonify(error=str(e)), 500
92
 
93
  @app.route('/uploads/<filename>')
94
+ async def uploaded_file(filename):
95
+ return await send_from_directory(app.config['UPLOAD_FOLDER'], filename)
 
96
 
97
  if __name__ == '__main__':
98
+ app.run(host='0.0.0.0', port=5000)