# Gradio chat frontend for an OpenAI-compatible streaming chat API.
import base64
import gradio as gr
import json
import mimetypes
import os
import requests
import time

# Required deployment configuration — fail fast at startup if missing.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])

# Optional configuration.
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')

# Optional per-role 'name' attached to outgoing chat messages
# (values left as None are skipped by add_name_for_message).
NAME_MAP = {
    'system': os.environ.get('SYSTEM_NAME'),
    'user': os.environ.get('USER_NAME'),
}
def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* given the chat *history*.

    Builds an OpenAI-style messages list (optional system prompt, prior
    turns, then the new user message), POSTs it with stream=True, and
    yields the accumulated assistant reply after each streamed chunk so
    gr.ChatInterface renders incremental output.

    Raises gr.Error when a streamed payload has no 'choices'.
    """
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for turn in history:
        messages.append({
            'role': turn['role'],
            'content': convert_content(turn['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    for item in messages:
        add_name_for_message(item)
    request_body = {
        'model': MODEL_VERSION,
        'messages': messages,
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    r = requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(request_body),
        stream=True,
    )
    reply = ''
    for row in r.iter_lines():
        if not row.startswith(b'data:'):
            continue
        payload = row[5:].strip()
        # OpenAI-compatible streams terminate with a literal "[DONE]"
        # sentinel that is not JSON; json.loads() on it would raise.
        if payload == b'[DONE]':
            break
        data = json.loads(payload)
        if 'choices' not in data:
            raise gr.Error('request failed')
        choice = data['choices'][0]
        if 'delta' in choice:
            # Role-only / finish chunks may carry a delta without
            # 'content' (or with content set to null) — treat as ''.
            reply += choice['delta'].get('content') or ''
            yield reply
        elif 'message' in choice:
            # Non-incremental fallback: full message in one payload.
            yield choice['message']['content']
def add_name_for_message(message):
    """Attach the configured participant 'name' for the message's role.

    Mutates *message* in place; roles with no configured name (or roles
    absent from NAME_MAP, e.g. 'assistant') are left untouched.
    """
    name = NAME_MAP.get(message['role'])
    if name is not None:
        message['name'] = name
def convert_content(content):
    """Normalize a Gradio message payload into OpenAI-style content.

    - str: passed through unchanged.
    - tuple: legacy gradio file payload; first element is a file path,
      converted to a single image_url part.
    - dict ({'text': ..., 'files': [...]}): converted to a list of
      text/image_url parts.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        return [{
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(content[0]),
            },
        }]
    content_list = []
    for key, val in content.items():
        if key == 'text':
            content_list.append({
                'type': 'text',
                'text': val,
            })
        elif key == 'files':
            for f in val:
                content_list.append({
                    'type': 'image_url',
                    'image_url': {
                        'url': encode_base64(f),
                    },
                })
    return content_list
def encode_base64(path):
    """Read an image file and return it as a base64 data: URL.

    Raises gr.Error when the path does not look like an image —
    including when the MIME type cannot be guessed at all.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type is None for unknown extensions; guard before
    # startswith() so we raise gr.Error, not AttributeError.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode(),
    )
# Chat UI; multimodal (image upload) input is enabled when the
# MULTIMODAL env var is exactly 'ON'. Slider defaults come from the
# MODEL_CONTROL_DEFAULTS deployment config.
demo = gr.ChatInterface(
    respond,
    multimodal=MULTIMODAL_FLAG == 'ON',
    type='messages',
    additional_inputs=[
        gr.Slider(minimum=1, maximum=1000000, value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'], step=1, label='Tokens to generate'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['temperature'], step=0.05, label='Temperature'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['top_p'], step=0.05, label='Top-p (nucleus sampling)'),
    ],
)

if __name__ == '__main__':
    # Queue allows up to 50 concurrent chat requests.
    demo.queue(default_concurrency_limit=50).launch()