import base64
import gradio as gr
import json
import mimetypes
import os
import requests
import time


# Runtime configuration, read from the environment once at import time.
# Variables accessed with os.environ[...] are required: a missing one raises
# KeyError during import. Variables read with .get() are optional and are
# None when unset.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# When set, sent as the leading 'system' message of every request.
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
# Multimodal input is enabled only when this is exactly 'ON'.
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')
# Must contain valid JSON; the UI sliders read its 'tokens_to_generate',
# 'temperature' and 'top_p' entries as their initial values.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])
# Maps a message role to the value attached as that message's 'name' field;
# roles mapped to None (or absent from the map) get no 'name'.
NAME_MAP = {
    'system': os.environ.get('SYSTEM_NAME'),
    'user': os.environ.get('USER_NAME'),
}


def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the model's reply to ``message`` given the chat ``history``.

    Builds the message list (optional system prompt, prior turns, the new
    user message), POSTs it to ``API_URL`` with streaming enabled, and
    yields the reply as it grows.

    Args:
        message: The new user input, in any form accepted by
            ``convert_content``.
        history: Prior turns; each item is indexed by ``'role'`` and
            ``'content'``.
        max_tokens: Forwarded as the request's ``max_tokens``.
        temperature: Forwarded as the request's ``temperature``.
        top_p: Forwarded as the request's ``top_p``.

    Yields:
        The accumulated reply text after each content delta, or the full
        message content when a chunk carries a complete ``message``.

    Raises:
        gr.Error: If the HTTP status is an error, or a streamed chunk has
            no ``choices`` field.
    """
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for turn in history:
        messages.append({
            'role': turn['role'],
            'content': convert_content(turn['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    # Distinct loop name: the original loop reused ``message`` and shadowed
    # the parameter.
    for entry in messages:
        add_name_for_message(entry)

    payload = {
        'model': MODEL_VERSION,
        'messages': messages,
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    # The context manager releases the streamed connection even when the
    # consumer stops iterating early or an error is raised mid-stream.
    with requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(payload),
        stream=True,
    ) as r:
        # An error response without any 'data:' lines would otherwise end
        # the generator silently with no reply at all.
        if not r.ok:
            raise gr.Error('request failed (HTTP {})'.format(r.status_code))

        reply = ''
        for row in r.iter_lines():
            if not row.startswith(b'data:'):
                continue
            body = row[5:].strip()
            if not body:
                continue
            # End-of-stream sentinel used by OpenAI-style APIs; not JSON.
            if body == b'[DONE]':
                break
            chunk = json.loads(body)
            if 'choices' not in chunk:
                raise gr.Error('request failed')
            if not chunk['choices']:
                # e.g. a trailing usage-only chunk with an empty list.
                continue
            choice = chunk['choices'][0]
            if 'delta' in choice:
                # Role-only or final deltas may have no (or null) content.
                piece = (choice['delta'] or {}).get('content')
                if piece:
                    reply += piece
                    yield reply
            elif 'message' in choice:
                yield choice['message']['content']


def add_name_for_message(message):
    """Attach the configured ``name`` for the message's role, in place.

    Looks up ``message['role']`` in ``NAME_MAP``; when the lookup yields a
    value other than None, stores it under ``message['name']``. Otherwise
    the message is left untouched.
    """
    role_name = NAME_MAP.get(message['role'])
    if role_name is None:
        return
    message['name'] = role_name


def convert_content(content):
    """Convert chat content into the form sent in a request message.

    * ``str`` is returned unchanged.
    * ``tuple``: its first element is treated as a file path and returned
      as a one-element list holding an ``image_url`` part.
    * Anything else is iterated with ``.items()``: a ``'text'`` entry
      becomes a text part and each path under ``'files'`` becomes an
      ``image_url`` part, in the mapping's iteration order. Other keys are
      ignored.
    """
    def image_part(path):
        # Images are inlined as base64 data URLs.
        return {
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(path),
            },
        }

    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for field, value in content.items():
        if field == 'text':
            parts.append({'type': 'text', 'text': value})
        elif field == 'files':
            parts.extend(image_part(path) for path in value)
    return parts


def encode_base64(path):
    """Return the image file at ``path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name with
    ``mimetypes.guess_type``; the file contents are read in binary mode
    and base64-encoded.

    Raises:
        gr.Error: If the MIME type cannot be guessed or does not start
            with ``image/``.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type() gives None for unknown extensions; report that as
    # "not an image" instead of failing on None.startswith.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    return 'data:{};base64,{}'.format(guess_type, encoded)


# Chat UI backed by respond(). The three sliders are supplied as additional
# inputs, matching respond()'s max_tokens, temperature and top_p parameters
# in that order; their initial values come from MODEL_CONTROL_DEFAULTS.
demo = gr.ChatInterface(
    respond,
    # Multimodal input only when the MULTIMODAL env var is exactly 'ON'.
    multimodal=MULTIMODAL_FLAG == 'ON',
    type='messages',
    additional_inputs=[
        gr.Slider(minimum=1, maximum=1000000, value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'], step=1, label='Tokens to generate'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['temperature'], step=0.05, label='Temperature'),
        gr.Slider(minimum=0.1, maximum=1.0, value=MODEL_CONTROL_DEFAULTS['top_p'], step=0.05, label='Top-p (nucleus sampling)'),
    ],
)


# Script entry point: configure the queue with a default concurrency limit
# of 50, then launch the app.
if __name__ == '__main__':
    demo.queue(default_concurrency_limit=50).launch()