"""Gradio demo: command a simulated UR5 arm with natural-language instructions.

Loads a trained policy checkpoint, rolls it out in TinyUR5Env, and renders the
rollout as a GIF.
"""

import random
import string

import gradio as gr
import imageio
import torch
import yaml
from PIL import Image
from skimage import img_as_ubyte

from initializer import Initializer
from test_model import model_forward_fn
from tiny_ur5 import TinyUR5Env


def load_model(ckpt, method, device):
    """Load a policy checkpoint for the BC-Z baseline ('bcz') or our model ('ours')."""
    if method == 'bcz':
        from models.film_model import Backbone
        model = Backbone(img_size=224, num_traces_out=8, embedding_size=256,
                         num_weight_points=12, input_nc=3, device=device)
    elif method == 'ours':
        from models.backbone_rgbd_sub_attn_tinyur5 import Backbone
        model = Backbone(img_size=224, embedding_size=256, num_traces_out=2,
                         num_joints=8, num_weight_points=12, input_nc=3,
                         device=device)
    else:
        raise ValueError(f'unknown method: {method}')
    model.load_state_dict(torch.load(ckpt, map_location=device)['model'], strict=True)
    return model.to(device)


device = torch.device('cpu')
ckpt = '160000.pth'  # alternative checkpoint: '580000.pth'
print('start loading model')
model = load_model(ckpt, 'ours', device)
print('model loaded')

with gr.Blocks() as demo:
    state = gr.State()

    def init(environment):
        """(Re)build the simulator for the chosen environment and return its first frame."""
        if environment == 'original':
            config_file = 'config.yaml'
        else:
            config_file = 'config_stable_diffusion.yaml'
        with open(config_file, 'r') as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
        # The stable-diffusion environment uses fewer objects per scene.
        if environment == 'original':
            initializer = Initializer(config, obj_num_low=3, obj_num_high=5)
        else:
            initializer = Initializer(config, obj_num_low=1, obj_num_high=2)
        config, task = initializer.get_config_and_task()
        sentence = initializer.get_sentence()
        env = TinyUR5Env(config)
        init_img = env.render('rgb_array')
        current_state = {
            'env': env,
            'id': ''.join(random.choice(string.ascii_letters) for _ in range(20)),
        }
        return init_img, current_state

    def execute(sentence, current_state, resolution):
        """Roll out the policy on `sentence`, saving subsampled frames to a GIF."""
        env = current_state['env']
        imgs = []
        time_step = 0
        # Roll out until roughly 150 simulation steps have elapsed.
        while time_step < 150:
            actions = model_forward_fn(env, model, sentence, 'ours', device)
            # Execute a middle slice of the predicted trajectory rather than
            # all actions.shape[-1] steps.
            for i in range(15, 50):
                action = actions[:, i]
                observation, reward, done, info = env.step(action, eef_z=80)
                img = Image.fromarray(env.render('rgb_array'))
                if resolution == 'low(3 sec)':
                    img = img.resize((240, 140))
                elif resolution == 'mid(5 sec)':
                    img = img.resize((480, 280))
                elif resolution == 'high(7 sec)':
                    img = img.resize((720, 420))
                # Subsample frames to keep the GIF small.
                if time_step % 12 == 0:
                    imgs.append(img)
                time_step += 1
                print(time_step)
        env.close()
        img_id = ''.join(random.choice(string.ascii_letters) for _ in range(20))
        imageio.mimsave(img_id + '.gif',
                        [img_as_ubyte(frame) for frame in imgs], 'GIF', fps=10)
        next_state = {
            'id': current_state['id'],
            'env': env,
        }
        return env.render('rgb_array'), img_id + '.gif', next_state

    with gr.Row():
        with gr.Column(scale=4):
            instruction = gr.Text(label='Input an Instruction Here:',
                                  placeholder='Push XXX to the right / Rotate XXX')
        with gr.Column(scale=2):
            resolution = gr.Radio(label='Image Quality',
                                  choices=['low(3 sec)', 'mid(5 sec)', 'high(7 sec)'],
                                  value='low(3 sec)')
        with gr.Column(scale=1):
            environment = gr.Radio(label='Environment',
                                   choices=['original', 'stable diffusion'],
                                   value='original')
    with gr.Row():
        action = gr.Button(value='Action!')
    with gr.Row():
        init_img_placeholder = gr.Image()
        gif_img_placeholder = gr.Image()
    with gr.Row():
        load_env = gr.Button(value='Reload Simulator')
    with gr.Row():
        with gr.Column():
            usage_notes = gr.Markdown(
                value="""
## Try Commanding the Robot Yourself!

(1) Type an instruction into the box at the top.

(2) Hit the 'Action!' button to start executing your instruction.

(3) Hit the 'Reload Simulator' button to re-initialize the simulator.

## Try Images Generated by Stable Diffusion!

Select the 'stable diffusion' radio button to initialize the environment with
images generated by Stable Diffusion.
""")
        with gr.Column():
            sample_instructions = gr.Markdown(
                value="""
## Sample Instructions:

The robot supports pushing objects in four directions, as well as rotating them:

```
\u2022 Push the apple to the right
\u2022 Rotate the watermelon clockwise
\u2022 Move the clock backwards
```
""")

    load_env.click(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )
    action.click(
        execute,
        inputs=[instruction, state, resolution],
        outputs=[init_img_placeholder, gif_img_placeholder, state],
        show_progress=True,
    )
    demo.load(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )
    environment.change(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )

demo.launch(share=False)
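
# --- Optional (a sketch, not part of the original demo) ---
# To expose the app publicly or pin the host/port, Gradio's launch() also
# accepts, e.g.:
#
#   demo.launch(share=True, server_name='0.0.0.0', server_port=7860)
#
# The handlers can also be smoke-tested without the UI (hypothetical usage,
# assuming the config files and checkpoint are present):
#
#   first_frame, s = init('original')
#   frame, gif_path, s = execute('Push the apple to the right', s, 'low(3 sec)')
#   print('rollout saved to', gif_path)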