File size: 4,255 Bytes
297482a
39d4ddb
e692727
39d4ddb
d808b5b
04549f6
 
39d4ddb
 
1367e6b
7f1bd15
 
 
1367e6b
a5515e4
7f1bd15
 
 
 
 
 
 
a5515e4
e692727
a5515e4
7f1bd15
 
 
 
 
297482a
ecbcd62
 
 
 
 
 
a5515e4
297482a
 
 
e692727
ecbcd62
 
297482a
ecbcd62
de87bdc
ecbcd62
 
 
 
 
 
e692727
a5515e4
297482a
 
 
 
e692727
ecbcd62
297482a
 
 
1367e6b
e692727
ecbcd62
 
7f1bd15
 
1367e6b
297482a
e692727
 
de87bdc
e692727
1367e6b
e692727
 
1367e6b
 
e692727
 
 
 
 
 
 
1367e6b
e692727
 
 
 
 
 
 
 
ecbcd62
 
 
 
 
 
 
1367e6b
ecbcd62
 
 
 
 
 
 
1367e6b
ecbcd62
 
7f1bd15
297482a
7f1bd15
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import base64
import io
import time

import httpx
import streamlit as st
from openai import APIError, OpenAI
from PIL import Image

from .config import config


def txt2txt_generate(api_key, service, model, parameters, **kwargs):
    """Stream a chat completion into the Streamlit page.

    Args:
        api_key: Credential handed to the OpenAI client.
        service: Key into ``config.services`` that selects the base URL.
        model: Model identifier; for "Hugging Face" it is also part of the URL.
        parameters: Keyword arguments forwarded to ``chat.completions.create``.
        **kwargs: Additional keyword arguments forwarded to the same call.

    Returns:
        Whatever ``st.write_stream`` returns on success; otherwise the error
        message (or the error body, for streaming errors) describing the failure.
    """
    endpoint = config.services[service]
    if service == "Hugging Face":
        # Hugging Face serves each model under its own OpenAI-compatible path
        endpoint = f"{endpoint}/{model}/v1"
    client = OpenAI(api_key=api_key, base_url=endpoint)

    try:
        chunks = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
        return st.write_stream(chunks)
    except APIError as err:
        # OpenAI uses this message for streaming errors and attaches response.error to error.body
        # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
        if err.message == "An error occurred during streaming":
            return err.body
        return err.message
    except Exception as err:
        return str(err)


# Result states after which polling Black Forest Labs can never yield an image.
# NOTE(review): taken from the BFL API result-status enum — confirm against the current docs.
_BFL_FAILURE_STATUSES = ("Error", "Request Moderated", "Content Moderated")


def _fetch_image(url, headers):
    """Download ``url`` and decode it as a PIL image, or return an error string on non-2xx."""
    download = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
    if download.status_code // 100 != 2:
        return f"Error: {download.status_code} {download.text}"
    return Image.open(io.BytesIO(download.content))


def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    """Generate an image from a text prompt using one of the configured services.

    Args:
        api_key: Credential placed in the service-specific auth header.
        service: One of "Black Forest Labs", "Fal", "Hugging Face" or "Together";
            also the key into ``config.services`` that selects the base URL.
        model: Model identifier appended to the base URL (not used for "Together").
        inputs: The text prompt.
        parameters: Generation options sent in the request body.
        **kwargs: Extra generation options; these override ``parameters``.

    Returns:
        A ``PIL.Image.Image`` on success, otherwise a string describing the
        error (HTTP failure, API timeout, or the text of a raised exception).
    """
    options = {**parameters, **kwargs}

    # Each service has its own authentication header scheme.
    headers = {}
    if service == "Black Forest Labs":
        headers["x-key"] = api_key
    elif service == "Fal":
        headers["Authorization"] = f"Key {api_key}"
    elif service in ("Hugging Face", "Together"):
        headers["Authorization"] = f"Bearer {api_key}"
    if service == "Hugging Face":
        headers["X-Wait-For-Model"] = "true"
        headers["X-Use-Cache"] = "false"

    # Hugging Face nests the options; the other services take them flat next to the prompt.
    if service == "Hugging Face":
        payload = {"inputs": inputs, "parameters": options}
    elif service in ("Black Forest Labs", "Fal", "Together"):
        payload = {**options, "prompt": inputs}
    else:
        payload = {}

    endpoint = config.services[service]
    if service != "Together":
        endpoint = f"{endpoint}/{model}"

    # NOTE(review): the auth headers are also sent with the image downloads below,
    # as in the original code — confirm the image hosts actually need the API key.
    try:
        response = httpx.post(endpoint, headers=headers, json=payload, timeout=config.txt2img.timeout)
        if response.status_code // 100 != 2:
            return f"Error: {response.status_code} {response.text}"

        # BFL is async so we need to poll for result
        # https://api.bfl.ml/docs
        if service == "Black Forest Labs":
            request_id = response.json()["id"]
            poll_url = f"{config.services[service]}/get_result?id={request_id}"

            # One poll per second, so the configured timeout doubles as the attempt limit.
            attempts = 0
            while attempts < config.txt2img.timeout:
                poll = httpx.get(poll_url, timeout=config.txt2img.timeout)
                if poll.status_code // 100 != 2:
                    return f"Error: {poll.status_code} {poll.text}"

                result = poll.json()
                status = result["status"]
                if status == "Ready":
                    return _fetch_image(result["result"]["sample"], headers)
                if status in _BFL_FAILURE_STATUSES:
                    # Terminal failure: stop instead of polling until the timeout.
                    return f"Error: {status}"

                attempts += 1
                time.sleep(1)

            return "Error: API timeout"

        if service == "Fal":
            image_url = response.json()["images"][0]["url"]
            # Sync mode returns the image inline as a base64 string instead of a CDN link.
            # Inspect the value itself rather than the request options, so the right
            # path is taken whichever way sync_mode was (or was not) specified.
            if image_url.startswith(("http://", "https://")):
                return _fetch_image(image_url, headers)
            raw = base64.b64decode(image_url.split(",")[-1])
            return Image.open(io.BytesIO(raw))

        if service == "Hugging Face":
            return Image.open(io.BytesIO(response.content))

        if service == "Together":
            return _fetch_image(response.json()["data"][0]["url"], headers)

        # A configured service without a response handler: report it rather than return None.
        return f"Error: unsupported service {service}"
    except Exception as e:
        # UI boundary: surface any failure as text rather than crash the app.
        return str(e)