import os
import io
import re
import json
import time
import base64
import hashlib
import datetime
from typing import Optional
from dataclasses import dataclass, field

import numpy as np
import cv2
import torch
import transformers
import gradio as gr
import spaces
import hydra
import pyrootutils
from PIL import Image
from omegaconf import OmegaConf
from flask import Flask
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline

from utils import build_logger
from conversation import conv_seed_llama2
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'
IMG_FLAG = '<image>'
num_img_in_tokens = 64
num_img_out_tokens = 64
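
# A minimal sketch (helper added for illustration only, never called by the app) of the
# placeholder span that replaces each IMG_FLAG before tokenization: 64 indexed image tokens
# wrapped in BOI/EOI markers, e.g. '<img><img_00000><img_00001>...<img_00063></img>'.
def _example_image_token_span() -> str:
    return BOI_TOKEN + ''.join(IMG_TOKEN.format(i) for i in range(num_img_in_tokens)) + EOI_TOKEN
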
resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2',
                    '2x3', '3x2', '2x4', '4x2']
base_resolution = 448

app = Flask(__name__)

def decode_image(encoded_image: str) -> Image.Image:
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image
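
# A minimal round-trip sketch (illustration only, never called by the app): images travel
# through the chat layer as base64-encoded PNG strings via encode_image/decode_image.
def _example_image_roundtrip() -> bool:
    original = Image.new('RGB', (64, 64), color=(200, 30, 30))
    encoded = encode_image(original)   # PIL.Image -> base64 PNG string
    restored = decode_image(encoded)   # base64 PNG string -> PIL.Image
    return restored.size == original.size
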
@dataclass
class Arguments:
    image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
                                           metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml',
                               metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default='configs/visual_tokenizer/qwen_vitg_448.yaml',
                                          metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
                                      metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
                                 metadata={"help": "config path of agent model"})
    diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
                                          metadata={"help": "diffusion model path"})
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mixed precision dtype"})


parser = transformers.HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses()
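
# Example invocation (hypothetical script name and devices; adjust to your setup). HfArgumentParser
# turns each dataclass field above into a CLI flag, so the defaults can be overridden like:
#   python app.py --dtype bf16 --llm_device cuda:0 --vit_sd_device cuda:1 \
#       --diffusion_path stabilityai/stable-diffusion-xl-base-1.0
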

class LLMService:

    def __init__(self, args) -> None:
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device

        dtype = args.dtype
        if dtype == 'fp16':
            self.dtype = torch.float16
        elif dtype == 'bf16':
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f'Unsupported dtype: {dtype}')

        image_transform_cfg = OmegaConf.load(args.image_transform)
        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)

        visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
        self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm_cfg = OmegaConf.load(args.llm)
        llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
        print('Init llm done.')

        agent_cfg = OmegaConf.load(args.agent)
        self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        print('Init agent model done.')

        # Multi-resolution (any-resolution) input is disabled in this demo; generate() checks this flag.
        self.multi_resolution = False

        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device,
                                                                                     dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(self.vit_sd_device,
                                                                                              dtype=self.dtype)

        sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
        self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(self.vit_sd_device,
                                                                                       dtype=self.dtype)

        # Alternative placement (disabled): keep the detokenizer pipe on CPU until needed.
        # self.sd_adapter.init_pipe(vae=vae,
        #                           scheduler=noise_scheduler,
        #                           visual_encoder=self.visual_encoder.cpu(),
        #                           image_transform=self.image_transform,
        #                           discrete_model=None,
        #                           dtype=self.dtype,
        #                           device="cpu")

        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder,
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device=self.vit_sd_device)
        print('Init sd adapter pipe done.')

        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        # Optional img2img polishing pipeline (disabled):
        # model_id_or_path = "stablediffusionapi/realistic-vision-v51"
        # self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None,
        #                                                                torch_dtype=torch.float16)

        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]


service = LLMService(args)

def generate(text_list, image_list, max_new_tokens):
    with torch.no_grad():
        text_list = text_list.split(IMG_FLAG)
        top_p = 0.5
        assert len(text_list) == len(image_list) + 1

        image_tokens = BOI_TOKEN + ''.join(
            [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN

        input_images = []
        if len(image_list) > 0:
            image_tensor_list = []
            embeds_cmp_mask = []
            embeds_gen_mask = []

            if service.multi_resolution:
                patch_pos = []
                image_patch_length = []
                image_size_list = []

            for idx, image_item in enumerate(image_list):
                if isinstance(image_item, str):
                    image = decode_image(image_item)
                    print('after decode image size:', image.size)
                    input_images.append(image)

                    # Any-resolution preprocessing (disabled while service.multi_resolution is False):
                    # if service.multi_resolution:
                    #     image_size_list.append(image.size)
                    #     print('image size:', image.size)
                    #     image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform,
                    #                                                           service.grid_pinpoints,
                    #                                                           service.base_resolution)
                    #     image_tensor_list.append(image_tensor)
                    #     patch_pos.append(patch_pos_tensor)
                    #     image_patch_length.append(image_tensor.shape[0])
                    #     print('image_patch_length', image_patch_length)
                    #     embeds_cmp_mask.extend([True] * image_tensor.shape[0])
                    #     embeds_gen_mask.extend([False] * image_tensor.shape[0])
                    # else:
                    image_tensor = service.image_transform(image)
                    image_tensor_list.append(image_tensor)
                    embeds_cmp_mask.append(True)
                    embeds_gen_mask.append(False)
                else:
                    raise ValueError('image_list items must be base64-encoded strings')

            if service.multi_resolution:
                # NOTE: this branch expects BOP_TOKEN/EOP_TOKEN and the any-resolution helpers, which are
                # not set up in this demo; it is unreachable while multi_resolution is False.
                pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
                patch_position = torch.cat(patch_pos, dim=0)

                image_tokens_list = []
                for patch_length in image_patch_length:
                    image_tokens = ''
                    for _ in range(patch_length - 1):
                        image_tokens += BOP_TOKEN + ''.join(
                            IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
                    image_tokens += BOI_TOKEN + ''.join(
                        IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
                    image_tokens_list.append(image_tokens)
            else:
                pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)

            image_embeds = service.visual_encoder(pixel_values)
            image_embeds = image_embeds.to(service.llm_device)

            embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
            embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
        else:
            image_embeds = None
            patch_position = 0
            embeds_cmp_mask = None
            embeds_gen_mask = None

        input_text = image_tokens.join(text_list)
        print('input_text:', input_text)

        input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
        input_ids = [service.tokenizer.bos_token_id] + input_ids
        input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)

        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
        ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)

        boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
        eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        input_ids = input_ids.unsqueeze(0)
        ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
        ids_gen_mask = ids_gen_mask.unsqueeze(0)

        error_msg = []

        output = service.agent.generate(
            tokenizer=service.tokenizer,
            input_ids=input_ids,
            image_embeds=image_embeds,
            embeds_cmp_mask=embeds_cmp_mask,
            ids_cmp_mask=ids_cmp_mask,
            num_img_gen_tokens=num_img_out_tokens,
            max_new_tokens=max_new_tokens,
            dtype=service.dtype,
            device=service.llm_device,
            top_p=top_p,
        )

        gen_imgs_base64_list = []
        generated_text = output['text']
        generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
        torch.cuda.empty_cache()

        if output['has_img_output']:
            # print('loading visual encoder and llm to CPU, and sd to GPU')
            # a = time.time()
            # service.agent = service.agent.cpu()
            # service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
            # print("Loading finished: ", time.time() - a)

            img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)

            for img_idx in range(output['num_gen_imgs']):
                img_feat = img_gen_feat[img_idx:img_idx + 1]
                generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                gen_imgs_base64_list.append(encode_image(generated_image))

            # a = time.time()
            # service.sd_adapter = service.sd_adapter.cpu()
            # service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
            # service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
            # print("Loading finished: ", time.time() - a)

    print(input_text + generated_text)
    return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
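
# A minimal usage sketch (illustration only, never called by the app): text and images are
# interleaved by splitting on IMG_FLAG, so `text_list` must contain exactly one IMG_FLAG per
# entry in `image_list`, each supplied as a base64-encoded PNG string.
def _example_generate_call() -> dict:
    img_b64 = encode_image(Image.new('RGB', (448, 448), color=(120, 160, 90)))
    return generate(
        text_list=f'Here is an illustration:{IMG_FLAG} Continue the story.',
        image_list=[img_b64],
        max_new_tokens=256,
    )
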

def http_bot(dialog_state, input_state, max_new_tokens, max_turns,
             request: gr.Request):
    print('input_state:', input_state)

    if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
            dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    if len(dialog_state.messages) > max_turns * 2:
        output_state = init_input_state()
        output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
        dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
        input_state = init_input_state()
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)

    prompt = dialog_state.get_prompt()
    text = prompt['text']
    max_new_tokens = int(max_new_tokens)
    images = prompt['images']

    results = generate(text, images, max_new_tokens)
    print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})

    output_state = init_input_state()
    image_dir = get_conv_image_dir()
    output_state['text'] = results['text']

    for image_base64 in results['images']:
        if image_base64 == '':
            image_path = ''
        else:
            image = decode_image(image_base64)
            image = image.convert('RGB')
            image_path = get_image_name(image=image, image_dir=image_dir)
            if not os.path.exists(image_path):
                image.save(image_path)
        output_state['images'].append(image_path)

    dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
    vote_last_response(dialog_state, 'common', request)
    input_state = init_input_state()
    chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
    return (dialog_state, input_state, chatbot) + (enable_btn,) * 4

LOGDIR = 'log'
logger = build_logger("gradio_seed_x", LOGDIR)
headers = {"User-Agent": "SEED-X Client"}

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

conv_seed_llama = conv_seed_llama2

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_conv_image_dir():
    name = os.path.join(LOGDIR, 'images')
    os.makedirs(name, exist_ok=True)
    return name


def get_image_name(image, image_dir=None):
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    image_bytes = buffer.getvalue()
    md5 = hashlib.md5(image_bytes).hexdigest()

    if image_dir is not None:
        image_name = os.path.join(image_dir, md5 + '.png')
    else:
        image_name = md5 + '.png'
    return image_name
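
# A minimal sketch (illustration only, never called by the app): conversation images are
# content-addressed by the MD5 of their PNG bytes, so saving identical content reuses one path.
def _example_image_naming() -> bool:
    img = Image.new('RGB', (32, 32), color=(0, 0, 255))
    first = get_image_name(img, image_dir=get_conv_image_dir())
    second = get_image_name(img, image_dir=get_conv_image_dir())
    return first == second  # identical content -> identical filename
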

def resize_image_square(image, target_size=448):
    resized_image = image.resize((target_size, target_size))
    return resized_image


def resize_image(image, max_size=512):
    width, height = image.size
    aspect_ratio = float(width) / float(height)

    if width > height:
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    resized_image = image.resize((new_width, new_height))
    return resized_image

def center_crop_image(image, max_aspect_ratio=1.5):
    width, height = image.size
    aspect_ratio = max(width, height) / min(width, height)

    if aspect_ratio >= max_aspect_ratio:
        if width > height:
            new_width = int(height * max_aspect_ratio)
            left = (width - new_width) // 2
            right = (width + new_width) // 2
            top = 0
            bottom = height
        else:
            new_height = int(width * max_aspect_ratio)
            left = 0
            right = width
            top = (height - new_height) // 2
            bottom = (height + new_height) // 2

        cropped_image = image.crop((left, top, right, bottom))
        return cropped_image
    else:
        return image
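
# A minimal sketch (illustration only, never called by the app) of the two preprocessing helpers:
# resize_image caps the longer side while keeping the aspect ratio; center_crop_image trims overly
# elongated images down to the given maximum aspect ratio.
def _example_preprocessing() -> tuple:
    wide = Image.new('RGB', (2000, 400))
    resized = resize_image(wide, max_size=512)               # -> 512 x 102
    cropped = center_crop_image(wide, max_aspect_ratio=1.5)  # -> 600 x 400 center crop
    return resized.size, cropped.size
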

def vote_last_response(state, vote_type, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return (disable_btn,) * 2


def downvote_last_response(state, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return (disable_btn,) * 2


def regenerate(dialog_state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
        dialog_state.messages.pop()
    return (
        dialog_state,
        dialog_state.to_gradio_chatbot(),
    ) + (disable_btn,) * 4


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4

def init_input_state():
    return {'images': [], 'text': ''}


def add_text(dialog_state, input_state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}.")

    if text is None or len(text) == 0:
        return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    input_state['text'] += text

    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
    print('add_text: ', dialog_state.to_gradio_chatbot())

    return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def is_blank(image):
    image_array = np.array(image)
    unique_colors = np.unique(image_array)
    print('unique_colors', len(unique_colors))
    return len(unique_colors) == 1
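
# A minimal sketch (illustration only, never called by the app): an image counts as blank
# when it contains a single unique pixel value across all channels.
def _example_is_blank() -> bool:
    return is_blank(Image.new('RGB', (16, 16), color=(255, 255, 255)))  # -> True
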

def add_image(dialog_state, input_state, image, request: gr.Request):
    logger.info(f"add_image. ip: {request.client.host}.")

    if image is None:
        return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    image = image.convert('RGB')
    print('image size:', image.size)
    image = center_crop_image(image, max_aspect_ratio=10)

    image_dir = get_conv_image_dir()
    image_path = get_image_name(image=image, image_dir=image_dir)
    if not os.path.exists(image_path):
        image.save(image_path)

    input_state['images'].append(image_path)
    input_state['text'] += IMG_FLAG

    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
    print('add_image:', dialog_state)

    return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def update_error_msg(chatbot, error_msg):
    if len(error_msg) > 0:
        info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
            error_msg)
        chatbot[-1][-1] = chatbot[-1][-1] + info
    return chatbot


def load_demo(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return dialog_state, input_state

title = ("""
# SEED-Story
[[Paper]](https://arxiv.org/abs/2407.08683) [[Code]](https://github.com/TencentARC/SEED-Story)

Demo of SEED-Story-George, a multimodal story generation model trained on the StoryStream-Curious George subset.
SEED-Story is an MLLM capable of generating long multimodal stories consisting of rich, coherent narrative text, along with images that are consistent in character and style.

## Tips:
* Check out the conversation examples (at the bottom) for inspiration.
* You can adjust "Max History Rounds" to hold a conversation of up to **three rounds due to limited GPU memory**. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
* Our demo supports a mix of images and text as input. You can freely upload an image or enter text, then click "Add Image/Text". You can repeat this step multiple times, and click "Submit" at the end to run model inference.
* SEED-Story was trained with English-only data. It may handle other languages thanks to the inherent capabilities of LLaMA, but results may be unstable.
""")
css = """ | |
img { | |
font-family: 'Helvetica'; | |
font-weight: 300; | |
line-height: 2; | |
text-align: center; | |
width: auto; | |
height: auto; | |
display: block; | |
position: relative; | |
} | |
img:before { | |
content: " "; | |
display: block; | |
position: absolute; | |
top: -10px; | |
left: 0; | |
height: calc(100% + 10px); | |
width: 100%; | |
background-color: rgb(230, 230, 230); | |
border: 2px dotted rgb(200, 200, 200); | |
border-radius: 5px; | |
} | |
img:after { | |
content: " "; | |
display: block; | |
font-size: 16px; | |
font-style: normal; | |
font-family: FontAwesome; | |
color: rgb(100, 100, 100); | |
position: absolute; | |
top: 5px; | |
left: 0; | |
width: 100%; | |
text-align: center; | |
} | |
""" | |

if __name__ == '__main__':
    examples_mix = [
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/bank.png?raw=true',
         'Can I connect with an advisor on Sunday?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/ground.png?raw=true',
         'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/arrow.jpg?raw=true',
         'What is the object pointed by the red arrow?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/shanghai.png?raw=true',
         'Where was this image taken? Explain your answer.'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/GPT4.png?raw=true',
         'How long does it take to make GPT-4 safer?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/twitter.png?raw=true',
         'Please provide a comprehensive description of this image.'],
    ]

    examples_text = [
        ['I want to build a two-story cabin in the woods, with many commanding windows. Can you show me a picture?'],
        ['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
        ['Can you design an illustration for "The Three-Body Problem" to depict a scene from the novel? Show me a picture.'],
        ['My four-year-old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
        ['Generate an image of a portrait of a young Nordic girl, age 25, freckled skin, neck tattoo, blue eyes, 35mm lens, photography, ultra details.'],
        ['Generate an impressionist painting of an astronaut in a jungle.'],
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        dialog_state = gr.State()
        input_state = gr.State()

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input_image')
                with gr.Row():
                    text = gr.Textbox(lines=5,
                                      show_label=False,
                                      label='input_text',
                                      elem_id='textbox',
                                      placeholder="Enter text and an image, then press Submit.",
                                      container=False)
                with gr.Row():
                    add_image_btn = gr.Button("Add Image")
                    add_text_btn = gr.Button("Add Text")
                    submit_btn = gr.Button("Submit")

                with gr.Row():
                    max_new_tokens = gr.Slider(minimum=64,
                                               maximum=1024,
                                               value=768,
                                               step=64,
                                               interactive=True,
                                               label="Max Output Tokens")
                    max_turns = gr.Slider(minimum=1, maximum=3, value=3, step=1, interactive=True,
                                          label="Max History Rounds")
                    force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
                    force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
                    force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')

            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
                with gr.Row():
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        with gr.Row():
            with gr.Column(scale=7):
                gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text], cache_examples=False)
            with gr.Column(scale=3):
                gr.Examples(examples=examples_text, label='Input examples', inputs=[text], cache_examples=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
        upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])

        # http_bot takes (dialog_state, input_state, max_new_tokens, max_turns, request),
        # so only those components are passed as inputs.
        regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
            http_bot, [dialog_state, input_state, max_new_tokens, max_turns],
            [dialog_state, input_state, chatbot] + btn_list)
        add_image_btn.click(add_image, [dialog_state, input_state, image],
                            [dialog_state, input_state, image, chatbot] + btn_list)
        add_text_btn.click(add_text, [dialog_state, input_state, text],
                           [dialog_state, input_state, text, chatbot] + btn_list)

        submit_btn.click(
            add_image, [dialog_state, input_state, image],
            [dialog_state, input_state, image, chatbot] + btn_list).then(
            add_text, [dialog_state, input_state, text],
            [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
            http_bot,
            [dialog_state, input_state, max_new_tokens, max_turns],
            [dialog_state, input_state, chatbot] + btn_list)
        clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)

        demo.load(load_demo, None, [dialog_state, input_state])

    demo.launch(debug=True)