tuxedocat's picture
Clean up
9e03ce7
raw
history blame contribute delete
No virus
4.54 kB
import base64
import io
from functools import partial
import gradio as gr
import httpx
from const import CSS, FOOTER, HEADER, MODELS, PLACEHOLDER
from openai import OpenAI
from PIL import Image
from cycloud.auth import load_default_credentials
def get_headers(host: str) -> dict:
    """Return the HTTP headers attached to every model-endpoint request.

    Credentials are loaded afresh through ``load_default_credentials`` on each
    call, and their ``access_token`` is sent as a bearer token. *host* is used
    verbatim as the ``Host`` header. Both ``Accept`` and ``Content-Type`` are
    set to JSON.
    """
    credentials = load_default_credentials()
    bearer = f"Bearer {credentials.access_token}"
    headers = {"Authorization": bearer, "Host": host}
    json_type = "application/json"
    headers["Accept"] = json_type
    headers["Content-Type"] = json_type
    return headers
def proxy(request: httpx.Request, model_info: dict) -> httpx.Request:
    """Rewrite an outgoing httpx request so it targets the configured model endpoint.

    Intended as an httpx ``request`` event hook (bound with ``functools.partial``).

    * The URL path is replaced by ``model_info["endpoint"]``, whatever path the
      OpenAI client built.
    * Auth/JSON headers from ``get_headers`` are merged in. ``Host`` is taken
      from ``model_info["host"]`` with any ``"https://"`` substring removed.

    The request is mutated in place and also returned.
    """
    endpoint_path = model_info["endpoint"]
    request.url = request.url.copy_with(path=endpoint_path)
    bare_host = model_info["host"].replace("https://", "")
    request.headers.update(get_headers(host=bare_host))
    return request
def encode_image_with_pillow(image_path: str) -> str:
    """Load *image_path* and return it as a base64-encoded JPEG string.

    The image is shrunk in place with ``Image.thumbnail`` to fit within
    384x384 pixels. It is then converted to RGB (JPEG has no alpha channel) and
    encoded. The result is the ASCII base64 text, ready for a ``data:`` URL.
    """
    jpeg_buffer = io.BytesIO()
    with Image.open(image_path) as picture:
        picture.thumbnail((384, 384))
        picture.convert("RGB").save(jpeg_buffer, format="JPEG")
    encoded_bytes = base64.b64encode(jpeg_buffer.getvalue())
    return encoded_bytes.decode("utf-8")
def _file_path(entry):
    """Return the filesystem path of a Gradio file entry (plain path or ``{"path": ...}`` dict)."""
    if isinstance(entry, dict):
        return entry["path"]
    return entry


def _resolve_image_path(message, history):
    """Pick the image to send with this request, or ``None`` if there is none.

    Preference order: the last file attached to the new *message*, then the
    last file upload found in *history*.
    """
    files = message.get("files")
    if files:
        if isinstance(files, dict):
            # A single file dict instead of a list (handled for older Gradio output).
            return _file_path(files)
        return _file_path(files[-1])
    image_path = None
    for entry in history:
        # In tuple-format history a file upload has a tuple ``(path, ...)`` in the user slot.
        if isinstance(entry[0], tuple) and entry[0]:
            image_path = _file_path(entry[0][0])
    return image_path


def _build_openai_messages(text, history, image_base64):
    """Convert tuple-format *history* plus the new user *text* into OpenAI chat messages.

    The image (when *image_base64* is not ``None``) and the first user text
    share the first user message. Every later turn becomes a plain user message
    followed by its assistant reply, when there is one. File-upload entries
    carry no text and are skipped. A ``None`` user text (an image-only turn) is
    sent as ``""``.
    """
    turns = [
        (human or "", assistant)
        for human, assistant in history
        if not isinstance(human, tuple)
    ]
    turns.append((text, None))

    image_parts = []
    if image_base64 is not None:
        image_parts.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
            }
        )

    messages = []
    for index, (user_text, reply) in enumerate(turns):
        if index == 0:
            messages.append(
                {
                    "role": "user",
                    "content": [*image_parts, {"type": "text", "text": user_text}],
                }
            )
        else:
            messages.append({"role": "user", "content": user_text})
        if reply is not None:
            messages.append({"role": "assistant", "content": reply})
    return messages


def call_chat_api(message, history, model_name):
    """Stream a chat completion for the multimodal ``gr.ChatInterface``.

    Args:
        message: New user input from ``gr.MultimodalTextbox``, a dict with
            ``"text"`` and ``"files"`` keys.
        history: Prior turns as ``[user, assistant]`` pairs. File uploads
            appear as tuples in the user slot.
        model_name: Key into ``MODELS`` selecting the backend. It is also used
            to build the served model id.

    Yields:
        The assistant reply accumulated so far, once per streamed chunk.

    When no image is attached and none is found in *history*, the request is
    sent as text only.
    """
    image_path = _resolve_image_path(message, history)
    image_base64 = (
        encode_image_with_pillow(image_path) if image_path is not None else None
    )
    messages = _build_openai_messages(
        message.get("text") or "", history, image_base64
    )

    model_info = MODELS[model_name]
    # NOTE(review): verify=False disables TLS certificate verification for the
    # model endpoint; confirm this is intended for the deployment network.
    # The client is a context manager so its connection pool is closed when
    # streaming finishes or the generator is closed early.
    with httpx.Client(
        event_hooks={"request": [partial(proxy, model_info=model_info)]},
        verify=False,
    ) as http_client:
        client = OpenAI(
            api_key="",  # the real Authorization header is injected by the proxy hook
            base_url=model_info["host"],
            http_client=http_client,
        )
        stream = client.chat.completions.create(
            model=f"/data/cyberagent/{model_name}",
            messages=messages,
            temperature=0.2,
            top_p=1.0,
            max_tokens=1024,
            stream=True,
            extra_body={"repetition_penalty": 1.1},
        )
        response = ""
        for chunk in stream:
            # Guard against chunks with an empty ``choices`` list, which some
            # OpenAI-compatible servers send.
            if not chunk.choices:
                continue
            response += chunk.choices[0].delta.content or ""
            yield response
def run():
    """Build the Gradio chat demo and launch it (blocks until the server stops).

    The page contains ``HEADER``, a model dropdown populated from ``MODELS``
    (defaulting to the first entry), a multimodal chat interface backed by
    ``call_chat_api`` with the selected model as an extra input, and ``FOOTER``.
    """
    model_names = list(MODELS)
    chatbot = gr.Chatbot(
        elem_id="chatbot", placeholder=PLACEHOLDER, scale=1, height=700
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    # Each example is a one-element list holding a MultimodalTextbox-style message.
    examples = [
        [
            {
                # "Please describe this image in detail."
                "text": "この画像を詳しく説明してください。",
                "files": ["./examples/cat.jpg"],
            },
        ],
        [
            {
                # "Please tell me in detail what this dish tastes like."
                "text": "この料理はどんな味がするか詳しく教えてください。",
                "files": ["./examples/takoyaki.jpg"],
            },
        ],
    ]
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_names,
                value=model_names[0],
                label="Model",
            )
        gr.ChatInterface(
            fn=call_chat_api,
            stop_btn="Stop Generation",
            examples=examples,
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
            additional_inputs=[model_selector],
        )
        gr.Markdown(FOOTER)
    demo.queue().launch(share=False)
if __name__ == "__main__":
    # Launch the demo only when this file is executed as a script, not on import.
    run()