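"""Gradio demo for a language-conditioned TinyUR5 tabletop simulator.

Loads a trained policy checkpoint, initializes a simulated UR5 tabletop
environment (optionally textured with Stable Diffusion images), executes a
free-form pushing/rotation instruction, and returns the rollout as a GIF.
"""
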
import random
import string

import gradio as gr
import imageio
import torch
import yaml
from PIL import Image
from skimage import img_as_ubyte

from initializer import Initializer
from test_model import model_forward_fn
from tiny_ur5 import TinyUR5Env
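
# Note: tiny_ur5, initializer, test_model, and the models.* packages imported
# inside load_model() are project-local modules, assumed to sit alongside this
# script in the Space repository rather than being installed from PyPI.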


def load_model(ckpt, method, device):
    """Build the policy backbone for `method` ('bcz' or 'ours') and load `ckpt`."""
    if method == 'bcz':
        from models.film_model import Backbone
        model = Backbone(img_size=224, num_traces_out=8, embedding_size=256,
                         num_weight_points=12, input_nc=3, device=device)
    elif method == 'ours':
        from models.backbone_rgbd_sub_attn_tinyur5 import Backbone
        model = Backbone(img_size=224, embedding_size=256, num_traces_out=2,
                         num_joints=8, num_weight_points=12, input_nc=3, device=device)
    else:
        raise ValueError(f'unsupported method: {method}')
    model.load_state_dict(torch.load(ckpt, map_location=device)['model'], strict=True)
    return model.to(device)
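

# The demo runs on CPU; the '160000.pth' checkpoint is loaded via a relative
# path, so it is assumed to be shipped in the working directory of the Space.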
device = torch.device('cpu')
ckpt = '160000.pth'
print('start loading model')
model = load_model(ckpt, 'ours', device)
print('model loaded')

with gr.Blocks() as demo:
    # Per-session state: the live simulator instance plus a random session id.
    state = gr.State()

    def init(environment):
        """(Re-)initialize the simulator and return its first rendered frame."""
        if environment == 'original':
            config_file = 'config.yaml'
        else:
            config_file = 'config_stable_diffusion.yaml'
        with open(config_file, 'r') as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
        # The original environment spawns more objects than the Stable Diffusion one.
        if environment == 'original':
            initializer = Initializer(config, obj_num_low=3, obj_num_high=5)
        else:
            initializer = Initializer(config, obj_num_low=1, obj_num_high=2)
        config, task = initializer.get_config_and_task()
        sentence = initializer.get_sentence()
        env = TinyUR5Env(config)
        init_img = env.render('rgb_array')
        current_state = {
            'env': env,
            # string.ascii_letters already covers both upper and lower case.
            'id': ''.join(random.choice(string.ascii_letters) for _ in range(20)),
        }
        return init_img, current_state

    # Renamed from `exec` to avoid shadowing the Python builtin.
    def execute(sentence, current_state, resolution):
        """Roll out the policy on `sentence`; return (last frame, GIF path, state)."""
        env = current_state['env']
        imgs = []
        time_step = 0
        while time_step < 150:
            actions = model_forward_fn(env, model, sentence, 'ours', device)
            # Execute a fixed slice of the predicted trajectory before re-planning.
            for i in range(15, 50):
                action = actions[:, i]
                observation, reward, done, info = env.step(action, eef_z=80)
                img = Image.fromarray(env.render('rgb_array'))
                if resolution == 'low(3 sec)':
                    img = img.resize((240, 140))
                elif resolution == 'mid(5 sec)':
                    img = img.resize((480, 280))
                elif resolution == 'high(7 sec)':
                    img = img.resize((720, 420))
                # Keep every 12th frame so the GIF stays small.
                if time_step % 12 == 0:
                    imgs.append(img)
                time_step += 1
                print(time_step)
        final_img = env.render('rgb_array')  # capture the last frame before closing
        env.close()
        img_id = ''.join(random.choice(string.ascii_letters) for _ in range(20))
        gif_path = img_id + '.gif'
        imageio.mimsave(gif_path, [img_as_ubyte(frame) for frame in imgs], 'GIF', fps=10)
        next_state = {
            'id': current_state['id'],
            'env': env,
        }
        return final_img, gif_path, next_state

    with gr.Row():
        with gr.Column(scale=4):
            instruction = gr.Text(label='Input an Instruction Here:',
                                  placeholder='Push XXX to the right / Rotate XXX')
        with gr.Column(scale=2):
            resolution = gr.Radio(
                label='Image Quality',
                choices=['low(3 sec)', 'mid(5 sec)', 'high(7 sec)'],
                value='low(3 sec)')
        with gr.Column(scale=1):
            environment = gr.Radio(
                label='Environment',
                choices=['original', 'stable diffusion'],
                value='original')
    with gr.Row():
        action = gr.Button(value='Action!')
    with gr.Row():
        init_img_placeholder = gr.Image()
        gif_img_placeholder = gr.Image()
    with gr.Row():
        load_env = gr.Button(value='Reload Simulator')
    with gr.Row():
        with gr.Column():
            usage_md = gr.Markdown(
                value="""
## Command the Robot Yourself!
(1) Type an instruction into the box at the top.
(2) Hit the 'Action!' button to start executing your instruction.
(3) Hit the 'Reload Simulator' button to re-initialize the simulator.
## Try Images Generated by Stable Diffusion!
Select the 'stable diffusion' radio button to initialize the environment with images generated by Stable Diffusion.
""")
        with gr.Column():
            samples_md = gr.Markdown(
                value="""
## Sample Instructions:
The robot supports pushing objects in four directions, as well as rotating them:
```
\u2022 Push the apple to the right
\u2022 Rotate the watermelon clockwise
\u2022 Move the clock backwards
```
""")

    load_env.click(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )
    action.click(
        execute,
        inputs=[instruction, state, resolution],
        outputs=[init_img_placeholder, gif_img_placeholder, state],
        show_progress=True,
    )
    demo.load(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )
    environment.change(
        init,
        inputs=[environment],
        outputs=[init_img_placeholder, state],
        show_progress=True,
    )

demo.launch(share=False)
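# Set share=True above to expose a temporary public URL when running locally.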