import os
import re
import pickle
import base64
import requests
import argparse
import numpy as np
import gradio as gr
from functools import partial
from PIL import Image

# Address of the remote server that every request below is POSTed to.
SERVER_URL = os.getenv('SERVER_URL')

# Paths of server-generated images that appear inside chat messages.
# Raw string so ``\.`` is a regex escape rather than an invalid string
# escape (DeprecationWarning today, SyntaxWarning on Python 3.12+).
_IMAGE_PATH_RE = re.compile(r'image/[0-9,a-z]+\.png')


def _call_server(method, args, kwargs):
    """Invoke ``method`` on the remote server and return its decoded reply.

    The call ``{'method', 'args', 'kwargs'}`` is pickled, base64-encoded and
    sent as a JSON string; the reply is decoded the same way in reverse.

    Args:
        method: name of the server-side method to run.
        args: positional arguments, sent as given (list or tuple).
        kwargs: keyword arguments.

    Returns:
        The unpickled reply object.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    payload = {'method': method, 'args': args, 'kwargs': kwargs}
    encoded = base64.b64encode(pickle.dumps(payload)).decode('utf-8')
    response = requests.post(SERVER_URL, json=encoded)
    # Surface HTTP failures directly instead of as a JSON/base64 decode error.
    response.raise_for_status()
    # SECURITY: pickle.loads can execute arbitrary code embedded in the
    # reply. Only point SERVER_URL at a server you fully trust.
    return pickle.loads(base64.b64decode(response.json().encode('utf-8')))


def get_images(state):
    """Download every image referenced in the chat history that is missing locally.

    Args:
        state: chat history as a sequence of message pairs. Entries that are
            not strings (e.g. ``None`` for a reply not yet produced) are
            ignored.
    """
    history = '\n'.join(
        message
        for pair in state
        for message in pair
        if isinstance(message, str)
    )
    for image_path in _IMAGE_PATH_RE.findall(history):
        if os.path.exists(image_path):
            continue
        image = _call_server('get_image', [image_path], {})
        # The script creates this directory at startup; also create it here
        # so the function works when the module is used on its own.
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        image.save(image_path)


def bot_request(method, *args, **kwargs):
    """Call ``method`` on the server and fetch any images its reply mentions.

    Args:
        method: name of the server-side method (e.g. ``'run_text'``).
        *args, **kwargs: forwarded to the server-side method.

    Returns:
        The server's decoded reply, or ``None``. For a non-``None`` reply,
        element ``[0]`` is treated as the chat history (callers unpack the
        reply as ``chatbot, state, text, buffer``).
    """
    response = _call_server(method, args, kwargs)
    if response is not None:
        get_images(response[0])
    return response


def _resize_for_model(image, max_side=512, multiple=64):
    """Resize ``image`` to fit in ``max_side`` x ``max_side`` with sides snapped to ``multiple``.

    The image is scaled (up or down) so its longer side becomes ``max_side``,
    then each side is rounded to the nearest multiple of ``multiple``. Each
    side is clamped to at least ``multiple`` so a very elongated image no
    longer rounds down to a zero-sized dimension.
    """
    width, height = image.size
    ratio = min(max_side / width, max_side / height)
    new_size = []
    for side in (width, height):
        scaled = round(side * ratio)
        snapped = int(np.round(scaled / multiple)) * multiple
        new_size.append(max(multiple, snapped))
    return image.resize(tuple(new_size))


def run_image(image, *args, **kwargs):
    """Send an (optionally resized) image to the server's ``run_image`` method.

    Args:
        image: PIL image from the upload widget, or ``None`` when no image was
            provided (forwarded unchanged).
        *args, **kwargs: forwarded to the server call after the image.

    Returns:
        Whatever ``bot_request`` returns for ``'run_image'``.
    """
    if image is not None:
        image = _resize_for_model(image).convert('RGB')
    return bot_request('run_image', image, *args, **kwargs)


def predict_example(temperature, top_p, max_new_token, keep_last_n_paragraphs,
                    image, text):
    """Run one example end to end: upload ``image``, then send ``text``.

    Starts from an empty history and buffer. Returns the Gradio outputs
    ``(chatbot, state, text, image, buffer)``, with ``None`` for the image so
    the input widget is cleared.
    """
    state = []
    buffer = ''
    chatbot, state, text, buffer = run_image(image, state, text, buffer)
    chatbot, state, text, buffer = bot_request(
        'run_text', text, state, temperature, top_p, max_new_token,
        keep_last_n_paragraphs, buffer)
    return chatbot, state, text, None, buffer


def _parse_args():
    """Parse command-line options and make sure ``SERVER_URL`` is configured."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.0,
                        help='temperature for the llm model')
    parser.add_argument('--max_new_tokens', type=int, default=256,
                        help='max number of new tokens to generate')
    parser.add_argument('--top_p', type=float, default=1.0,
                        help='top_p for the llm model')
    # NOTE(review): --top_k is accepted but never forwarded to the server.
    parser.add_argument('--top_k', type=int, default=40,
                        help='top_k for the llm model')
    parser.add_argument('--keep_last_n_paragraphs', type=int, default=0,
                        help='keep last n paragraphs in the memory')
    args = parser.parse_args()
    if not SERVER_URL:
        # Without this, the first request (e.g. example caching) fails deep
        # inside requests with an unhelpful error.
        parser.error('the SERVER_URL environment variable must be set')
    return args


def _build_demo(args):
    """Construct the Gradio UI; slider defaults and example settings come from ``args``."""
    examples = [
        ['images/example-1.jpg', 'What is unusual about this image?'],
        ['images/example-2.jpg', 'Make the image look like a cartoon.'],
        ['images/example-3.jpg', 'Segment the tie in the image.'],
        ['images/example-4.jpg', 'Generate a man watching a sea based on the pose of the woman.'],
        ['images/example-5.jpg', 'Replace the dog with a monkey.'],
    ]

    with gr.Blocks() as demo:
        # Hidden components carrying the history and text buffer between calls.
        state = gr.Chatbot([], visible=False)
        buffer = gr.Textbox('', visible=False)

        with gr.Row():
            with gr.Column(scale=0.3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input image')
                with gr.Row():
                    txt = gr.Textbox(
                        lines=7,
                        show_label=False,
                        elem_id='textbox',
                        placeholder='Enter text and press submit, or upload an image',
                    ).style(container=False)
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
                with gr.Row():
                    # Display only: its value is not passed to any callback.
                    gr.Radio(
                        ["Vicuna-13B"],
                        label="LLM Backend",
                        value="Vicuna-13B",
                        interactive=True)
                    keep_last_n_paragraphs = gr.Slider(
                        minimum=0, maximum=3,
                        value=args.keep_last_n_paragraphs, step=1,
                        interactive=True, label='Remember Last N Paragraphs')
                    max_new_token = gr.Slider(
                        minimum=64, maximum=512, value=args.max_new_tokens,
                        step=1, interactive=True, label='Max New Tokens')
                    # Hidden, but still passed to 'run_text' below.
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.0, value=args.temperature,
                        step=0.1, interactive=True, visible=False,
                        label='Temperature')
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, value=args.top_p, step=0.1,
                        interactive=True, visible=False, label='Top P')
            with gr.Column(scale=0.7):
                chatbot = gr.Chatbot(
                    elem_id='chatbot', label='🦙 GPT4Tools').style(height=690)

        # Uploading a new image empties the prompt box.
        image.upload(lambda: '', None, txt)
        # Submit: send the image, then the text, then clear the image input.
        submit.click(
            run_image,
            [image, state, txt, buffer],
            [chatbot, state, txt, buffer],
        ).then(
            partial(bot_request, 'run_text'),
            [txt, state, temperature, top_p, max_new_token,
             keep_last_n_paragraphs, buffer],
            [chatbot, state, txt, buffer],
        ).then(lambda: None, None, image)
        # Clear: call the server's 'clear' method and reset the local widgets.
        clear.click(partial(bot_request, 'clear'))
        clear.click(lambda: [[], [], '', ''], None,
                    [chatbot, state, txt, buffer])

        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=partial(predict_example, args.temperature, args.top_p,
                           args.max_new_tokens, args.keep_last_n_paragraphs),
                inputs=[image, txt],
                outputs=[chatbot, state, txt, image, buffer],
                cache_examples=True,
            )
    return demo


def main():
    """Parse the CLI, build the UI and launch it."""
    args = _parse_args()
    # Images downloaded by get_images are stored under this directory.
    os.makedirs('image', exist_ok=True)
    demo = _build_demo(args)
    demo.queue(concurrency_count=6)
    demo.launch()


if __name__ == '__main__':
    main()