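# Gradio demo for TinyUR5: type a natural-language instruction and a learned
# policy drives a simulated UR5 arm to execute it. An alternative environment
# variant is initialized from Stable Diffusion-generated object images.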
import random
import string

import gradio as gr
import imageio
import torch
import yaml
from PIL import Image
from skimage import img_as_ubyte

from initializer import Initializer
from test_model import model_forward_fn
from tiny_ur5 import TinyUR5Env

def load_model(ckpt, method, device):
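    """Build the policy backbone for `method` ('bcz' or 'ours') and load its
    checkpoint weights onto `device`.

    Note: the model is returned in its default (train) mode; inference-only
    callers may additionally want to call model.eval().
    """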
    if method == 'bcz':
        from models.film_model import Backbone
        model = Backbone(img_size=224, num_traces_out=8, embedding_size=256,
                         num_weight_points=12, input_nc=3, device=device)
        model.load_state_dict(torch.load(ckpt, map_location=device)['model'], strict=True)
        model = model.to(device)
        return model
    elif method == 'ours':
        from models.backbone_rgbd_sub_attn_tinyur5 import Backbone
        model = Backbone(img_size=224, embedding_size=256, num_traces_out=2,
                         num_joints=8, num_weight_points=12, input_nc=3, device=device)
        model.load_state_dict(torch.load(ckpt, map_location=device)['model'], strict=True)
        model = model.to(device)
        return model
# Run the 'ours' backbone on CPU from the bundled checkpoint.
device = torch.device('cpu')
ckpt = '160000.pth'
print('loading model from', ckpt)
model = load_model(ckpt, 'ours', device)
print('model loaded')
with gr.Blocks() as demo:
state = gr.State()
def init(environment):
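        """(Re)create the simulator for the chosen environment variant.

        The 'original' variant uses config.yaml with more objects (3-5); the
        Stable Diffusion variant uses config_stable_diffusion.yaml with fewer
        (1-2). Returns the first rendered frame and a fresh session state.
        """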
if environment == 'original':
config_file = 'config.yaml'
else:
config_file = 'config_stable_diffusion.yaml'
with open(config_file, "r") as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                # A malformed config is unrecoverable here, so fail loudly
                # instead of silently leaving `config` undefined.
                raise RuntimeError(f'failed to parse {config_file}') from exc
if environment == 'original':
initializer = Initializer(config, obj_num_low=3, obj_num_high=5)
else:
initializer = Initializer(config, obj_num_low=1, obj_num_high=2)
config, task = initializer.get_config_and_task()
sentence = initializer.get_sentence()
env = TinyUR5Env(config)
init_img = env.render('rgb_array')
        # Fresh per-session state: the live env plus a random session id.
        current_state = {
            'env': env,
            'id': ''.join(random.choice(string.ascii_letters) for _ in range(20)),
        }
return init_img, current_state
    def run_instruction(sentence, current_state, resolution):
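        """Roll out the loaded policy on `sentence` and record an animation.

        Returns the final rendered frame, the path of the saved GIF, and the
        updated session state.
        """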
        env = current_state['env']
        imgs = []
        time_step = 0
        # Roll out the policy for a fixed step budget. Each model call
        # predicts a trajectory; only a middle window of it (indices 15-50)
        # is executed before re-planning.
        while time_step < 150:
            actions = model_forward_fn(env, model, sentence, 'ours', device)
            for i in range(15, 50):
                action = actions[:, i]
                observation, reward, done, info = env.step(action, eef_z=80)
                img = Image.fromarray(env.render('rgb_array'))
                if resolution == 'low(3 sec)':
                    img = img.resize((240, 140))
                elif resolution == 'mid(5 sec)':
                    img = img.resize((480, 280))
                elif resolution == 'high(7 sec)':
                    img = img.resize((720, 420))
                # Keep every 12th frame to bound the GIF size.
                if time_step % 12 == 0:
                    imgs.append(img)
                time_step += 1
        # Grab a final frame before closing; rendering from a closed
        # environment may not be supported.
        final_img = env.render('rgb_array')
        env.close()
        # Write the frames to a uniquely named GIF in the working directory
        # (these files are not cleaned up automatically); Gradio serves it
        # through the output Image component.
        gif_path = ''.join(random.choice(string.ascii_letters) for _ in range(20)) + '.gif'
        imageio.mimsave(gif_path, [img_as_ubyte(frame) for frame in imgs], 'GIF', fps=10)
        # The env was closed above; the user is expected to hit
        # 'Reload Simulator' or switch environments before the next run.
        next_state = {
            'id': current_state['id'],
            'env': env,
        }
        return final_img, gif_path, next_state
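    # ---- UI layout: instruction box, quality/environment selectors, action
    # button, image panes, reload button, and usage notes. ----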
with gr.Row():
with gr.Column(scale=4):
            instruction = gr.Text(label='Input an Instruction Here:', placeholder='Push XXX to the right / Rotate XXX')
with gr.Column(scale=2):
resolution = gr.Radio(
label='Image Quality',
choices=['low(3 sec)', 'mid(5 sec)', 'high(7 sec)'],
value='low(3 sec)')
with gr.Column(scale=1):
environment = gr.Radio(
label='Environment',
choices=['original', 'stable diffusion'],
value='original')
with gr.Row():
action = gr.Button(value='Action!')
with gr.Row():
init_img_placeholder = gr.Image()
gif_img_placeholder = gr.Image()
with gr.Row():
load_env = gr.Button(value='Reload Simulator')
with gr.Row():
with gr.Column():
            illustration = gr.Markdown(
                value="""
                ## Command the Robot Yourself!
                (1) Type an instruction in the box at the top.
                (2) Hit the 'Action!' button to start executing your instruction.
                (3) Hit the 'Reload Simulator' button if you want to re-initialize the simulator.
                ## Try Images Generated by Stable Diffusion!
                Select the 'stable diffusion' radio button to initialize the environment from images generated by Stable Diffusion.
                """,
            )
        with gr.Column():
            sample_instructions = gr.Markdown(
                value="""
                ## Sample Instructions:
                The robot supports pushing objects in four directions, as well as rotating them:
                ```
                \u2022 Push the apple to the right
                \u2022 Rotate the watermelon clockwise
                \u2022 Move the clock backwards
                ```
                """,
            )
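    # ---- Event wiring: page load, 'Reload Simulator', and switching the
    # environment radio all re-initialize the env; 'Action!' runs the policy. ----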
load_env.click(
init,
inputs=[environment],
outputs=[init_img_placeholder, state],
show_progress=True
)
    action.click(
        run_instruction,
        inputs=[instruction, state, resolution],
        outputs=[init_img_placeholder, gif_img_placeholder, state],
        show_progress=True
    )
    demo.load(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True
    )
environment.change(
init,
inputs=[environment],
outputs=[init_img_placeholder, state],
show_progress=True
)
demo.launch(share=False)