"""Client for the Kling "kolors virtual try-on" image API.

Submits a person image and a garment image, polls the resulting task, and
downloads the generated try-on image.
"""
import asyncio
import base64
import functools
import io
import logging
import sys
import time

import jwt
import requests
from PIL import Image
from requests.exceptions import RequestException

# Set up logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('virtual_tryon.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Constants
VALID_CLOTH_TYPES = ["upper", "lower", "full"]
VALID_IMAGE_SIZES = ["256x256", "512x512", "768x768"]
DEFAULT_IMAGE_SIZE = "512x512"
DEFAULT_NUM_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
DEFAULT_SEED = 42
API_BASE_URL = "https://api.klingai.com"

# (connect, read) timeout in seconds for every HTTP call; without a timeout
# `requests` can block forever on a stalled connection.
REQUEST_TIMEOUT = (10, 120)
# Verify TLS certificates. The requests carry a bearer token, so disabling
# verification (the previous hard-coded `verify=False`) exposes it to MITM.
VERIFY_SSL = True
# Seconds to back-date a token's "nbf" claim to tolerate clock skew.
TOKEN_NBF_LEEWAY = 5
# Default polling schedule for check_task_status (3 checks, 20 s apart).
DEFAULT_POLL_ATTEMPTS = 3
DEFAULT_POLL_INTERVAL = 20


def generate_api_token(access_key, secret_key):
    """Generate a JWT for API authentication.

    The token is HS256-signed with ``secret_key``, carries ``access_key`` as
    the ``iss`` claim, and expires 30 minutes after creation.

    Args:
        access_key: API access key (token issuer).
        secret_key: API secret used as the HMAC signing key.

    Returns:
        The encoded token as ``str``. PyJWT < 2.0 returns ``bytes``; it is
        decoded here so every caller can drop it straight into a header.

    Raises:
        Exception: anything raised by ``jwt.encode``; logged, then re-raised.
    """
    try:
        current_time = int(time.time())
        payload = {
            "iss": access_key,
            "exp": current_time + 1800,  # 30 minutes expiration
            # Back-date slightly so a server whose clock lags ours does not
            # reject the token as "not yet valid".
            "nbf": current_time - TOKEN_NBF_LEEWAY,
        }
        logger.debug(f"Generating token with payload: {payload}")
        token = jwt.encode(payload, secret_key, algorithm="HS256")
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        logger.debug("Token generated successfully")
        return token
    except Exception as e:
        logger.error(f"Error generating token: {str(e)}")
        raise


def encode_image_to_base64(image):
    """Convert a PIL Image to a base64 string of its PNG encoding.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        The base64 text, or ``None`` if ``image`` is not a PIL Image or
        encoding fails (the failure is logged).
    """
    try:
        if isinstance(image, Image.Image):
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
            logger.debug(f"Image encoded to base64 successfully. Length: {len(base64_string)}")
            return base64_string
        logger.error("Input is not a PIL Image")
        return None
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        return None


async def _request_async(method, url, **kwargs):
    """Run a blocking ``requests`` call in the event loop's default executor.

    ``requests`` is synchronous; calling it directly inside a coroutine would
    stall every other task on the loop for the whole HTTP round trip.
    ``timeout`` and ``verify`` default to the module constants unless given.
    """
    kwargs.setdefault("timeout", REQUEST_TIMEOUT)
    kwargs.setdefault("verify", VERIFY_SSL)
    loop = asyncio.get_event_loop()
    call = functools.partial(requests.request, method, url, **kwargs)
    return await loop.run_in_executor(None, call)


def _auth_headers(access_key, secret_key):
    """Build JSON request headers with a freshly generated bearer token."""
    token = generate_api_token(access_key, secret_key)
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }


async def check_task_status(
    task_id,
    access_key,
    secret_key,
    max_attempts=DEFAULT_POLL_ATTEMPTS,
    wait_interval=DEFAULT_POLL_INTERVAL,
):
    """Poll a try-on task until it finishes or the attempts run out.

    Each attempt sleeps ``wait_interval`` seconds *before* querying, and uses
    a freshly generated token. Transport errors, non-200 responses and API
    error codes are logged and the next attempt proceeds.

    Args:
        task_id: id returned by the submit call.
        access_key: API access key.
        secret_key: API secret key.
        max_attempts: number of status checks before giving up.
        wait_interval: seconds to wait before each check.

    Returns:
        ``(error_message, image_url)`` — exactly one of the two is non-None.
    """
    for attempt in range(1, max_attempts + 1):
        await asyncio.sleep(wait_interval)
        logger.info(f"Checking task status (Attempt {attempt}/{max_attempts})...")
        try:
            headers = _auth_headers(access_key, secret_key)
            # Status check endpoint
            url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}"
            response = await _request_async("GET", url, headers=headers)
            logger.debug(f"Status check response: {response.text}")
            result = response.json()

            if response.status_code == 200 and result.get('code') == 0:
                data = result.get('data') or {}
                # The field may be present but null; normalise before lower().
                task_status = (data.get('task_status') or '').lower()

                if task_status in ('completed', 'succeed'):
                    images = (data.get('task_result') or {}).get('images') or []
                    image_url = images[0].get('url') if images else None
                    if image_url:
                        return None, image_url
                    # Covers both "no images" and "image entry without a URL".
                    return "No images found in the task result.", None

                if task_status in ('failed', 'error'):
                    error_message = data.get('task_status_msg') or 'Task failed.'
                    return f"Task failed: {error_message}", None

                logger.info(f"Task status: {task_status}. Waiting for next attempt...")
            else:
                error_message = result.get('message', 'Unknown error occurred.')
                logger.error(f"Error fetching task status: {error_message}")
        except Exception as e:
            # Best-effort polling: a failed check just consumes one attempt.
            logger.error(f"Error checking task status: {str(e)}")

    return "Task did not complete within the expected time.", None


async def _download_image(image_url):
    """Download and decode the result image; return ``(image_or_None, message)``."""
    try:
        image_response = await _request_async("GET", image_url)
        if image_response.status_code != 200:
            return None, f"Failed to download result image: {image_response.status_code}"
        image = Image.open(io.BytesIO(image_response.content))
        # Decode eagerly so a corrupt payload fails here, inside this handler,
        # instead of later in the caller.
        image.load()
        return image, "Success"
    except (RequestException, OSError, ValueError) as e:
        return None, f"Error downloading result image: {str(e)}"


async def apply_virtual_tryon_async(
    person_image,
    garment_image,
    access_key,
    secret_key
):
    """Apply virtual try-on using the Kling API asynchronously.

    Submits both images to the kolors-virtual-try-on endpoint, polls the
    task via ``check_task_status``, then downloads the result.

    Args:
        person_image: PIL Image of the person.
        garment_image: PIL Image of the garment.
        access_key: API access key.
        secret_key: API secret key.

    Returns:
        ``(PIL.Image, "Success")`` on success, otherwise ``(None, message)``.
        Unexpected exceptions are caught, logged and reported as
        ``(None, "Error: ...")``; this coroutine does not raise.
    """
    try:
        logger.info("Starting virtual try-on process")

        # Token generation raises on failure (caught below); it never
        # returns an empty value, so no separate falsy check is needed.
        headers = _auth_headers(access_key, secret_key)

        # Prepare images
        logger.debug("Preparing images")
        person_base64 = encode_image_to_base64(person_image)
        garment_base64 = encode_image_to_base64(garment_image)
        if not person_base64 or not garment_base64:
            logger.error("Failed to convert images to base64")
            return None, "Error converting images to base64"

        payload = {
            "model_name": "kolors-virtual-try-on-v1",
            "human_image": person_base64,
            "cloth_image": garment_base64
        }

        # Submit task
        url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on"
        logger.debug(f"Making API request to {url}")
        response = await _request_async("POST", url, headers=headers, json=payload)
        result = response.json()

        if not (response.status_code == 200 and result.get('code') == 0):
            error_msg = result.get('message', 'Unknown error')
            logger.error(f"API Error: {error_msg}")
            return None, f"API Error: {error_msg}"

        task_id = (result.get('data') or {}).get('task_id')
        if not task_id:
            return None, "No task ID received"
        logger.info(f"Task submitted successfully. Task ID: {task_id}")

        error_message, image_url = await check_task_status(task_id, access_key, secret_key)
        if error_message:
            return None, error_message

        return await _download_image(image_url)

    except Exception as e:
        logger.error(f"Unexpected Error: {str(e)}")
        return None, f"Error: {str(e)}"


def apply_virtual_tryon(
    person_image,
    garment_image,
    access_key,
    secret_key,
    cloth_type="upper",
    image_size="512x512",
    num_steps=DEFAULT_NUM_STEPS,
    guidance_scale=DEFAULT_GUIDANCE_SCALE,
    seed=DEFAULT_SEED
):
    """Synchronous wrapper for ``apply_virtual_tryon_async``.

    Runs the coroutine on a fresh event loop that is closed afterwards.
    Must not be called from a thread that already has a running loop.

    NOTE: ``cloth_type``, ``image_size``, ``num_steps``, ``guidance_scale``
    and ``seed`` are accepted for backward compatibility only. They are NOT
    forwarded: the request payload carries just the model name and images.

    Returns:
        Same as ``apply_virtual_tryon_async``: ``(image_or_None, message)``.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            apply_virtual_tryon_async(
                person_image,
                garment_image,
                access_key,
                secret_key
            )
        )
    finally:
        # Unregister before closing so this thread is not left with a
        # closed loop as its "current" event loop.
        asyncio.set_event_loop(None)
        loop.close()