Abhlash commited on
Commit
c65d70c
·
verified ·
1 Parent(s): 42935f3

Create virtual_tryon.py

Browse files
Files changed (1) hide show
  1. virtual_tryon.py +222 -0
virtual_tryon.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import requests
3
+ import io
4
+ import base64
5
+ import jwt
6
+ import time
7
+ import logging
8
+ import sys
9
+ import asyncio
10
+ from requests.exceptions import RequestException
11
+
12
+ # Set up logging
13
+ logging.basicConfig(
14
+ level=logging.DEBUG,
15
+ format='%(asctime)s - %(levelname)s - %(message)s',
16
+ handlers=[
17
+ logging.FileHandler('virtual_tryon.log'),
18
+ logging.StreamHandler(sys.stdout)
19
+ ]
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Constants
24
+ VALID_CLOTH_TYPES = ["upper", "lower", "full"]
25
+ VALID_IMAGE_SIZES = ["256x256", "512x512", "768x768"]
26
+ DEFAULT_IMAGE_SIZE = "512x512"
27
+ DEFAULT_NUM_STEPS = 30
28
+ DEFAULT_GUIDANCE_SCALE = 7.5
29
+ DEFAULT_SEED = 42
30
+ API_BASE_URL = "https://api.klingai.com"
31
+
32
+ def generate_api_token(access_key, secret_key):
33
+ """Generate JWT token for API authentication"""
34
+ try:
35
+ current_time = int(time.time())
36
+ payload = {
37
+ "iss": access_key,
38
+ "exp": current_time + 1800, # 30 minutes expiration
39
+ "nbf": current_time
40
+ }
41
+
42
+ logger.debug(f"Generating token with payload: {payload}")
43
+ token = jwt.encode(payload, secret_key, algorithm="HS256")
44
+ logger.debug("Token generated successfully")
45
+ return token
46
+
47
+ except Exception as e:
48
+ logger.error(f"Error generating token: {str(e)}")
49
+ raise
50
+
51
+ def encode_image_to_base64(image):
52
+ """Convert PIL Image to base64 string"""
53
+ try:
54
+ if isinstance(image, Image.Image):
55
+ buffered = io.BytesIO()
56
+ image.save(buffered, format="PNG")
57
+ base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
58
+ logger.debug(f"Image encoded to base64 successfully. Length: {len(base64_string)}")
59
+ return base64_string
60
+ logger.error("Input is not a PIL Image")
61
+ return None
62
+ except Exception as e:
63
+ logger.error(f"Error encoding image to base64: {str(e)}")
64
+ return None
65
+
66
+ async def check_task_status(task_id, access_key, secret_key):
67
+ """Check the status of a task"""
68
+ max_attempts = 3
69
+ wait_interval = 20
70
+ attempt = 1
71
+
72
+ while attempt <= max_attempts:
73
+ await asyncio.sleep(wait_interval)
74
+ logger.info(f"Checking task status (Attempt {attempt}/{max_attempts})...")
75
+
76
+ try:
77
+ # Generate new token for status check
78
+ token = generate_api_token(access_key, secret_key)
79
+
80
+ headers = {
81
+ "Content-Type": "application/json",
82
+ "Authorization": f"Bearer {token}"
83
+ }
84
+
85
+ # Status check endpoint
86
+ url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}"
87
+
88
+ response = requests.get(url, headers=headers, verify=False)
89
+ logger.debug(f"Status check response: {response.text}")
90
+
91
+ result = response.json()
92
+ if response.status_code == 200 and result.get('code') == 0:
93
+ data = result.get('data', {})
94
+ task_status = data.get('task_status', '').lower()
95
+
96
+ if task_status in ['completed', 'succeed']:
97
+ images = data.get('task_result', {}).get('images', [])
98
+ if images:
99
+ image_url = images[0].get('url')
100
+ return None, image_url
101
+ else:
102
+ return "No images found in the task result.", None
103
+ elif task_status in ['failed', 'error']:
104
+ error_message = data.get('task_status_msg', 'Task failed.')
105
+ return f"Task failed: {error_message}", None
106
+ else:
107
+ logger.info(f"Task status: {task_status}. Waiting for next attempt...")
108
+ else:
109
+ error_message = result.get('message', 'Unknown error occurred.')
110
+ logger.error(f"Error fetching task status: {error_message}")
111
+
112
+ except Exception as e:
113
+ logger.error(f"Error checking task status: {str(e)}")
114
+
115
+ attempt += 1
116
+
117
+ return "Task did not complete within the expected time.", None
118
+
119
+ async def apply_virtual_tryon_async(
120
+ person_image,
121
+ garment_image,
122
+ access_key,
123
+ secret_key
124
+ ):
125
+ """Apply virtual try-on using Kling API asynchronously"""
126
+ try:
127
+ logger.info("Starting virtual try-on process")
128
+
129
+ # Generate API token
130
+ jwt_token = generate_api_token(access_key, secret_key)
131
+ if not jwt_token:
132
+ return None, "Failed to generate JWT token"
133
+
134
+ # Ensure token is string
135
+ if isinstance(jwt_token, bytes):
136
+ jwt_token = jwt_token.decode('utf-8')
137
+
138
+ # Prepare images
139
+ logger.debug("Preparing images")
140
+ person_base64 = encode_image_to_base64(person_image)
141
+ garment_base64 = encode_image_to_base64(garment_image)
142
+
143
+ if not person_base64 or not garment_base64:
144
+ logger.error("Failed to convert images to base64")
145
+ return None, "Error converting images to base64"
146
+
147
+ # Prepare request
148
+ headers = {
149
+ "Content-Type": "application/json",
150
+ "Authorization": f"Bearer {jwt_token}"
151
+ }
152
+
153
+ # Payload structure
154
+ payload = {
155
+ "model_name": "kolors-virtual-try-on-v1",
156
+ "human_image": person_base64,
157
+ "cloth_image": garment_base64
158
+ }
159
+
160
+ # Submit task
161
+ url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on"
162
+ logger.debug(f"Making API request to {url}")
163
+
164
+ response = requests.post(url, headers=headers, json=payload, verify=False)
165
+ result = response.json()
166
+
167
+ if response.status_code == 200 and result.get('code') == 0:
168
+ task_id = result.get('data', {}).get('task_id')
169
+ if not task_id:
170
+ return None, "No task ID received"
171
+
172
+ logger.info(f"Task submitted successfully. Task ID: {task_id}")
173
+
174
+ # Check task status
175
+ error_message, image_url = await check_task_status(task_id, access_key, secret_key)
176
+
177
+ if error_message:
178
+ return None, error_message
179
+
180
+ # Download result image
181
+ try:
182
+ image_response = requests.get(image_url)
183
+ if image_response.status_code == 200:
184
+ return Image.open(io.BytesIO(image_response.content)), "Success"
185
+ else:
186
+ return None, f"Failed to download result image: {image_response.status_code}"
187
+ except Exception as e:
188
+ return None, f"Error downloading result image: {str(e)}"
189
+ else:
190
+ error_msg = result.get('message', 'Unknown error')
191
+ logger.error(f"API Error: {error_msg}")
192
+ return None, f"API Error: {error_msg}"
193
+
194
+ except Exception as e:
195
+ logger.error(f"Unexpected Error: {str(e)}")
196
+ return None, f"Error: {str(e)}"
197
+
198
+ def apply_virtual_tryon(
199
+ person_image,
200
+ garment_image,
201
+ access_key,
202
+ secret_key,
203
+ cloth_type="upper",
204
+ image_size="512x512",
205
+ num_steps=DEFAULT_NUM_STEPS,
206
+ guidance_scale=DEFAULT_GUIDANCE_SCALE,
207
+ seed=DEFAULT_SEED
208
+ ):
209
+ """Synchronous wrapper for async virtual try-on function"""
210
+ loop = asyncio.new_event_loop()
211
+ asyncio.set_event_loop(loop)
212
+ try:
213
+ return loop.run_until_complete(
214
+ apply_virtual_tryon_async(
215
+ person_image,
216
+ garment_image,
217
+ access_key,
218
+ secret_key
219
+ )
220
+ )
221
+ finally:
222
+ loop.close()