Spaces:
Sleeping
Sleeping
import base64 | |
import time | |
from io import BytesIO | |
from typing import Any, List, Mapping, Optional, Tuple, Union | |
from aiohttp import ClientResponse | |
from httpx import Headers, Response | |
from litellm.llms.base_llm.chat.transformation import ( | |
BaseLLMException, | |
LiteLLMLoggingObj, | |
) | |
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams | |
from litellm.types.utils import ( | |
FileTypes, | |
HttpHandlerRequestFields, | |
ImageObject, | |
ImageResponse, | |
) | |
from ...base_llm.image_variations.transformation import BaseImageVariationConfig | |
from ..common_utils import TopazException, TopazModelInfo | |
class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig):
    """
    Request/response transformation for Topaz image variations.

    Requests are sent as multipart form data (an ``image`` file plus the mapped
    optional params) to ``{api_base}/image/v1/enhance``; the raw response body
    is treated as image bytes and returned base64-encoded in an ImageResponse.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image-variation params this provider can map."""
        return ["response_format", "size"]

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the enhance endpoint URL.

        Falls back to ``https://api.topazlabs.com`` when ``api_base`` is empty.
        Trailing slashes on ``api_base`` are stripped so the joined URL never
        contains ``//``.
        """
        base = (api_base or "https://api.topazlabs.com").rstrip("/")
        return f"{base}/image/v1/enhance"

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI-style params into Topaz form fields.

        - ``response_format`` -> ``output_format``
        - ``size`` (``"<width>x<height>"``) -> ``output_width`` / ``output_height``
          (kept as strings, as before)

        Other keys are ignored. ``optional_params`` is updated in place and
        returned.

        Raises:
            ValueError: if ``size`` is not of the form ``<width>x<height>``.
        """
        for k, v in non_default_params.items():
            if k == "response_format":
                optional_params["output_format"] = v
            elif k == "size":
                split_v = v.split("x")
                # Explicit check instead of `assert`, which is removed under -O.
                if len(split_v) != 2 or not all(split_v):
                    raise ValueError(
                        f"size must be in the format of widthxheight, got {v!r}"
                    )
                optional_params["output_width"] = split_v[0]
                optional_params["output_height"] = split_v[1]
        return optional_params

    def prepare_file_tuple(
        self,
        file_data: FileTypes,
    ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
        """
        Convert various file input formats to a consistent tuple format for HTTPX.

        Accepted inputs:
            - bare content (``bytes``, ``BytesIO`` or another file-like object)
            - ``(filename, content)``
            - ``(filename, content, content_type)``
            - ``(filename, content, content_type, headers)``

        A falsy filename / content type falls back to ``"image.png"`` /
        ``"image/png"``; falsy headers fall back to ``{}``.

        Returns:
            (filename, file_content, content_type, headers)

        Raises:
            ValueError: if a tuple with fewer than 2 or more than 4 items is given.
        """
        filename = "image.png"
        content_type = "image/png"
        headers: Mapping[str, str] = {}

        if not isinstance(file_data, tuple):
            # Bare content. Only bytes/BytesIO used to be recognized, so an
            # opened file object was silently sent with content=None.
            return (filename, file_data, content_type, headers)

        if not 2 <= len(file_data) <= 4:
            raise ValueError(
                "file tuple must be (filename, content[, content_type[, headers]]), "
                f"got a tuple of length {len(file_data)}"
            )

        filename = file_data[0] or filename
        content = file_data[1]
        if len(file_data) >= 3:
            content_type = file_data[2] or content_type
        if len(file_data) == 4:
            headers = file_data[3] or {}
        return (filename, content, content_type, headers)

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build multipart request fields: the image file plus mapped params as form data."""
        request_params = HttpHandlerRequestFields(
            files={"image": self.prepare_file_tuple(image)},
            data=optional_params,
        )
        return request_params

    def _common_transform_response_image_variation(
        self,
        image_content: bytes,
        response_ms: float,
    ) -> ImageResponse:
        """Wrap raw image bytes as a single base64-encoded ImageObject."""
        base64_image = base64.b64encode(image_content).decode("utf-8")
        return ImageResponse(
            created=int(time.time()),
            data=[
                ImageObject(
                    b64_json=base64_image,
                    url=None,
                    revised_prompt=None,
                )
            ],
            response_ms=response_ms,
        )

    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Read the aiohttp response body and convert it to an ImageResponse."""
        image_content = await raw_response.read()
        response_ms = logging_obj.get_response_ms()
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """
        Convert the httpx response body to an ImageResponse.

        Latency comes from ``raw_response.elapsed`` when httpx can provide it.
        Otherwise it falls back to the logging object's timing, as in the async
        path.
        """
        image_content = raw_response.content
        try:
            # httpx raises RuntimeError when `elapsed` is unavailable
            # (e.g. a Response that was not sent through a client).
            response_ms = raw_response.elapsed.total_seconds() * 1000
        except RuntimeError:
            response_ms = logging_obj.get_response_ms()
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap a provider error in a TopazException."""
        return TopazException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )