Spaces:
Running
Running
import time | |
from io import BytesIO | |
import httpx | |
import streamlit as st | |
from anthropic import Anthropic | |
from anthropic import APIError as AnthropicAPIError | |
from openai import APIError as OpenAIAPIError | |
from openai import OpenAI | |
from PIL import Image | |
from .config import config | |
from .util import base64_decode_image_data_url | |
def txt2txt_generate(api_key, service, parameters, **kwargs):
    """Stream a chat completion into the Streamlit page and return the outcome.

    The "anthropic" service goes through the Anthropic client; every other
    service goes through the OpenAI client pointed at the service URL taken
    from ``config.services``. For "hf" the model name is appended to that URL.

    Args:
        api_key: Credential handed to the selected client.
        service: Key into ``config.services`` selecting the backend.
        parameters: Mapping of request arguments forwarded to the client
            call; its "model" entry is also used to build the "hf" URL.
        **kwargs: Extra request arguments forwarded alongside ``parameters``.

    Returns:
        Whatever ``st.write_stream`` returns on success. On failure the
        exception is not raised; the error message (or, for OpenAI streaming
        errors, the error body) is returned instead.
    """
    model_name = parameters.get("model", "")
    endpoint = config.services[service].url
    if service == "hf":
        endpoint = f"{endpoint}/{model_name}/v1"

    try:
        if service != "anthropic":
            openai_client = OpenAI(api_key=api_key, base_url=endpoint)
            chunks = openai_client.chat.completions.create(stream=True, **parameters, **kwargs)
            return st.write_stream(chunks)

        anthropic_client = Anthropic(api_key=api_key)
        with anthropic_client.messages.stream(**parameters, **kwargs) as message_stream:
            return st.write_stream(message_stream.text_stream)
    except AnthropicAPIError as err:
        return err.message
    except OpenAIAPIError as err:
        # For streaming failures OpenAI raises with this fixed message and puts
        # the real response.error on error.body
        # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
        if err.message == "An error occurred during streaming":
            return err.body
        return err.message
    except Exception as err:
        return str(err)
def _download_image(url, headers, timeout):
    """Fetch ``url`` with httpx and open the response body as a PIL image."""
    # NOTE(review): the service auth headers are forwarded to the image host,
    # as the code has always done; confirm the hosts actually need them.
    image = httpx.get(url, headers=headers, timeout=timeout)
    return Image.open(BytesIO(image.content))


def _poll_bfl_result(service_url, task_id, headers, timeout):
    """Poll the BFL ``get_result`` endpoint until the task is "Ready".

    Makes up to ``timeout`` attempts, sleeping one second after each attempt
    that is not ready. Returns the downloaded image, or an error string when
    a poll returns a non-2xx status or the attempts run out.
    """
    url = f"{service_url}/get_result?id={task_id}"
    retries = 0
    while retries < timeout:
        response = httpx.get(url, timeout=timeout)
        if response.status_code // 100 != 2:
            return f"Error: {response.status_code} {response.text}"
        result = response.json()  # parse once per poll
        if result["status"] == "Ready":
            return _download_image(result["result"]["sample"], headers, timeout)
        retries += 1
        time.sleep(1)
    return "Error: API timeout"


def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    """Request an image from a text-to-image service and return it.

    Builds the service-specific headers and JSON payload, POSTs the request
    and turns the response into a PIL image.

    Args:
        api_key: Credential placed in the service-specific auth header.
        service: One of "bfl", "fal", "hf" or "together"; also the key into
            ``config.services``.
        model: Model path appended to the service URL (not used for
            "together").
        inputs: The prompt text.
        parameters: Mapping of request parameters.
        **kwargs: Extra request parameters; they override ``parameters``.

    Returns:
        A PIL image on success (for "fal" in sync mode, whatever
        ``base64_decode_image_data_url`` returns). On failure an error string
        is returned instead of raising: ``"Error: <status> <body>"`` for
        non-2xx responses, ``"Error: API timeout"`` when BFL polling runs
        out, or ``str(exception)`` for anything raised along the way.
        Returns ``None`` if the request succeeds but ``service`` is not one
        of the handled names.
    """
    headers = {}
    payload = {**parameters, **kwargs}

    if service == "bfl":
        headers["x-key"] = api_key
        payload["prompt"] = inputs
    elif service == "fal":
        headers["Authorization"] = f"Key {api_key}"
        payload["prompt"] = inputs
    elif service == "hf":
        headers["Authorization"] = f"Bearer {api_key}"
        headers["X-Wait-For-Model"] = "true"
        headers["X-Use-Cache"] = "false"
        # HF takes the prompt as "inputs" with everything else nested.
        payload = {
            "inputs": inputs,
            "parameters": {**parameters, **kwargs},
        }
    elif service == "together":
        headers["Authorization"] = f"Bearer {api_key}"
        payload["prompt"] = inputs

    service_url = config.services[service].url
    # Together takes the request at the service URL itself; the other
    # services get the model appended as a path segment.
    base_url = service_url if service == "together" else f"{service_url}/{model}"

    try:
        timeout = config.timeout
        response = httpx.post(base_url, headers=headers, json=payload, timeout=timeout)
        if response.status_code // 100 != 2:  # anything but 2xx
            return f"Error: {response.status_code} {response.text}"

        if service == "bfl":
            # BFL is async so we need to poll for result
            # https://api.bfl.ml/docs
            task_id = response.json()["id"]
            return _poll_bfl_result(service_url, task_id, headers, timeout)

        if service == "fal":
            # Sync mode means wait for image base64 string instead of CDN link
            url = response.json()["images"][0]["url"]
            # Read the flag from the payload actually sent, so that a
            # sync_mode override passed through kwargs is honoured.
            if payload.get("sync_mode", True):
                return base64_decode_image_data_url(url)
            return _download_image(url, headers, timeout)

        if service == "hf":
            # The response body is the image itself.
            return Image.open(BytesIO(response.content))

        if service == "together":
            url = response.json()["data"][0]["url"]
            return _download_image(url, headers, timeout)

        return None
    except Exception as e:
        return str(e)