Bhushan26 commited on
Commit
b4cb033
·
verified ·
1 Parent(s): c43143f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +79 -19
main.py CHANGED
@@ -1,26 +1,70 @@
1
  import tempfile
2
- import os
3
- from flask import Flask, request, jsonify
 
 
4
  from flask_cors import CORS
5
- from diffusers import DiffusionPipeline
 
 
 
 
 
6
 
7
  app = Flask(__name__)
8
  CORS(app)
9
 
10
- # Load the diffusion model
11
- pipeline = DiffusionPipeline.from_pretrained("yisol/IDM-VTON")
12
 
13
- # Directory to save uploaded files
14
  UPLOAD_FOLDER = tempfile.mkdtemp()
 
 
 
 
 
 
15
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.route('/')
18
  def index():
19
- return {'hello': 'This is a wearon API using Diffusers'}
20
 
21
  @app.route('/process', methods=['POST'])
22
- def process():
23
  try:
 
 
 
24
  # Handle the uploaded model image
25
  if 'model_image' not in request.files:
26
  return jsonify(error='No model image file provided'), 400
@@ -33,24 +77,40 @@ def process():
33
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
34
  model_image.save(filename)
35
 
36
- # Get the product image URL from the request
37
- product_image_url = request.form.get('product_image_url')
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Here you would process the images with your pipeline
40
- # Example (you would replace this with the actual method to process the images):
41
- result = pipeline(product_image_url, model_image_path=filename)
 
42
 
43
- # Assuming the pipeline returns a path to the result image
44
- result_image_path = result["path"]
45
 
46
- # Read the result image and encode it in base64
47
- with open(result_image_path, "rb") as image_file:
48
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
49
 
 
50
  return jsonify(image=encoded_image), 200
51
 
52
  except Exception as e:
 
53
  return jsonify(error=str(e)), 500
54
 
55
- if __name__ == '__main__':
56
- app.run(debug=True)
 
 
1
  import tempfile
2
+ import threading
3
+ import time
4
+ import signal
5
+ from flask import Flask, request, jsonify, send_from_directory
6
  from flask_cors import CORS
7
+ import os
8
+ import traceback
9
+ import shutil
10
+ import base64
11
+ from gradio_client import Client, file
12
+
13
 
14
  app = Flask(__name__)
15
  CORS(app)
16
 
17
+ # Define Gradio client instance
18
+ client = Client("yisol/IDM-VTON")
19
 
20
+ # Directory to save uploaded and processed files
21
  UPLOAD_FOLDER = tempfile.mkdtemp()
22
+ RESULT_FOLDER = tempfile.mkdtemp()
23
+ if not os.path.exists(UPLOAD_FOLDER):
24
+ os.makedirs(UPLOAD_FOLDER)
25
+ if not os.path.exists(RESULT_FOLDER):
26
+ os.makedirs(RESULT_FOLDER)
27
+
28
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
29
+ app.config['RESULT_FOLDER'] = RESULT_FOLDER
30
+
31
+ def predict_with_timeout(model_image_path, product_image_url, timeout=600):
32
+ result = [None] # Mutable object to store the result
33
+
34
+ def target():
35
+ try:
36
+ result[0] = client.predict(
37
+ dict({"background": file(model_image_path), "layers": [], "composite": None}),
38
+ garm_img=file(product_image_url),
39
+ garment_des="Hello!!",
40
+ is_checked=True,
41
+ is_checked_crop=False,
42
+ denoise_steps=30,
43
+ seed=42,
44
+ api_name="/tryon"
45
+ )
46
+ except Exception as e:
47
+ result[0] = str(e)
48
+
49
+ thread = threading.Thread(target=target)
50
+ thread.start()
51
+ thread.join(timeout)
52
+ if thread.is_alive():
53
+ return None # Timeout
54
+ if isinstance(result[0], Exception):
55
+ return str(result[0]) # Return the error message
56
+ return result[0]
57
 
58
  @app.route('/')
59
  def index():
60
+ return {'hello': 'This is a wearon API'}
61
 
62
  @app.route('/process', methods=['POST'])
63
+ def predict():
64
  try:
65
+ # Get the product image URL from the request
66
+ product_image_url = request.form.get('product_image_url')
67
+
68
  # Handle the uploaded model image
69
  if 'model_image' not in request.files:
70
  return jsonify(error='No model image file provided'), 400
 
77
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
78
  model_image.save(filename)
79
 
80
+ base_path = os.getcwd()
81
+ full_filename = os.path.normpath(os.path.join(base_path, filename))
82
+
83
+ print("Product image = ", product_image_url)
84
+ print("Model image = ", full_filename)
85
+
86
+ # Perform prediction with a timeout
87
+ result = predict_with_timeout(full_filename, product_image_url)
88
+ if result is None:
89
+ return jsonify(error='Prediction timed out after 10 minutes'), 500
90
+
91
+ print(result)
92
+ # Extract the path of the first output image
93
+ output_image_path = result[0]
94
 
95
+ # Copy the output image to the RESULT_FOLDER
96
+ output_image_filename = os.path.basename(output_image_path)
97
+ local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
98
+ shutil.copy(output_image_path, local_output_path)
99
 
100
+ # Remove the uploaded file after processing
101
+ os.remove(filename)
102
 
103
+ # Encode the output image in base64
104
+ with open(local_output_path, "rb") as image_file:
105
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
106
 
107
+ # Return the output image in JSON format
108
  return jsonify(image=encoded_image), 200
109
 
110
  except Exception as e:
111
+ traceback.print_exc()
112
  return jsonify(error=str(e)), 500
113
 
114
+ @app.route('/uploads/<filename>')
115
+ def uploaded_file(filename):
116
+ return send_from_directory(app.config['UPLOAD_FOLDER'], filename)