import json
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder


@dataclass
class TryOnDiffusionAPIResponse:
    """Result of a single call made through ``TryOnDiffusionClient``."""

    # HTTP status code of the response; 0 when the request itself raised
    # before any response was received.
    status_code: int
    # Image decoded with cv2.imdecode(..., cv2.IMREAD_COLOR); only populated
    # for a 200 response when raw_response was False and decoding succeeded.
    image: Optional[np.ndarray] = None
    # Raw response body; populated when raw_response was True or the status
    # code was not 200.
    response_data: Optional[bytes] = None
    # Value of the "detail" key of a JSON error body, when one could be read.
    # NOTE(review): annotated as str, but the value is passed through from
    # the server's JSON unchanged - confirm it is always a string.
    error_details: Optional[str] = None
    # Integer parsed from the "X-Seed" response header of a 200 response.
    seed: Optional[int] = None


class TryOnDiffusionClient:
    """Thin HTTP client for the ``/try-on-file`` endpoint of a Try-On Diffusion API."""

    def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = ""):
        """Store connection settings and detect whether a RapidAPI host is used.

        Args:
            base_url: Root URL of the API. A single trailing slash is removed.
            api_key: Key sent in the authentication header of every request.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")

        self._base_url = base_url
        self._api_key = api_key

        # endswith() instead of indexing [-1], so an empty base_url does not
        # raise IndexError.
        if self._base_url.endswith("/"):
            self._base_url = self._base_url[:-1]

        parsed_url = urlparse(self._base_url)

        # Hosts under *.rapidapi.com need the RapidAPI-specific header pair.
        self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None

        if self._rapidapi_host is not None:
            self._logger.info(f"Using RapidAPI proxy: {self._rapidapi_host}")

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """Encode ``image`` as JPEG (quality 99) and wrap it as a multipart file field.

        Returns:
            A ``(filename, file_object, content_type)`` tuple as accepted by
            ``MultipartEncoder``.
        """
        _, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        jpeg_data = jpeg_data.tobytes()

        fp = BytesIO(jpeg_data)

        return "image.jpg", fp, "image/jpeg"

    def _build_headers(self, content_type: str) -> dict:
        """Return the request headers: content type plus the authentication header(s)."""
        headers = {"Content-Type": content_type}

        if self._rapidapi_host is not None:
            headers["X-RapidAPI-Key"] = self._api_key
            headers["X-RapidAPI-Host"] = self._rapidapi_host
        else:
            headers["X-API-Key"] = self._api_key

        return headers

    @staticmethod
    def _extract_error_details(body: Optional[bytes]):
        """Return the ``"detail"`` value of a JSON object body, or ``None``.

        ``None`` is returned when the body is missing, is not valid UTF-8
        JSON, or is JSON that is not an object containing ``"detail"``.
        """
        if body is None:
            return None

        try:
            payload = json.loads(body.decode("utf-8"))
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both derive from
            # ValueError; an unreadable body simply carries no details.
            return None

        if isinstance(payload, dict):
            return payload.get("detail")

        return None

    def try_on_file(
        self,
        clothing_image: np.ndarray = None,
        clothing_prompt: str = None,
        avatar_image: np.ndarray = None,
        avatar_prompt: str = None,
        avatar_sex: str = None,
        background_image: np.ndarray = None,
        background_prompt: str = None,
        seed: int = -1,
        raw_response: bool = False,
    ) -> TryOnDiffusionAPIResponse:
        """POST a multipart request to ``<base_url>/try-on-file``.

        Every argument that is not ``None`` is sent as a multipart field with
        the same name as the parameter; image arguments are JPEG-encoded
        first. ``seed`` is always sent, converted to a string.

        Args:
            clothing_image: Image uploaded as the ``clothing_image`` field.
            clothing_prompt: Text sent as the ``clothing_prompt`` field.
            avatar_image: Image uploaded as the ``avatar_image`` field.
            avatar_prompt: Text sent as the ``avatar_prompt`` field.
            avatar_sex: Text sent as the ``avatar_sex`` field.
            background_image: Image uploaded as the ``background_image`` field.
            background_prompt: Text sent as the ``background_prompt`` field.
            seed: Value sent as the ``seed`` field.
            raw_response: When True, the body of a 200 response is returned
                undecoded in ``response_data`` instead of being decoded into
                ``image``.

        Returns:
            A ``TryOnDiffusionAPIResponse``. Request failures are reported as
            ``status_code=0`` rather than raised.
        """
        url = self._base_url + "/try-on-file"

        request_data = {"seed": str(seed)}

        # (field name, value, needs JPEG encoding) - kept in the order in
        # which the fields are added to the multipart body.
        optional_fields = (
            ("clothing_image", clothing_image, True),
            ("clothing_prompt", clothing_prompt, False),
            ("avatar_image", avatar_image, True),
            ("avatar_prompt", avatar_prompt, False),
            ("avatar_sex", avatar_sex, False),
            ("background_image", background_image, True),
            ("background_prompt", background_prompt, False),
        )

        for field_name, value, is_image in optional_fields:
            if value is None:
                continue
            request_data[field_name] = self._image_to_upload_file(value) if is_image else value

        multipart_data = MultipartEncoder(fields=request_data)
        headers = self._build_headers(multipart_data.content_type)

        try:
            response = requests.post(
                url,
                data=multipart_data,
                headers=headers,
            )
        except Exception as e:
            # Boundary of the client: any transport failure is logged and
            # reported through the response object instead of propagating.
            self._logger.error(e, exc_info=True)

            return TryOnDiffusionAPIResponse(status_code=0)

        succeeded = response.status_code == 200

        if not succeeded:
            self._logger.warning(f"Request failed, status code: {response.status_code}, response: {response.content}")

        result = TryOnDiffusionAPIResponse(status_code=response.status_code)

        if succeeded and not raw_response:
            try:
                result.image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            except Exception:
                # Decoding is best-effort: an undecodable body leaves image
                # as None, but interpreter-level exits are no longer swallowed.
                self._logger.warning("Failed to decode response image", exc_info=True)
                result.image = None
        else:
            result.response_data = response.content

        if succeeded:
            if "X-Seed" in response.headers:
                try:
                    result.seed = int(response.headers["X-Seed"])
                except ValueError:
                    # A malformed header must not turn a successful call
                    # into an exception; the seed is simply left unset.
                    self._logger.warning(f"Invalid X-Seed header value: {response.headers['X-Seed']!r}")
        else:
            result.error_details = self._extract_error_details(result.response_data)

        return result