ThomasSimonini HF staff commited on
Commit
0704015
1 Parent(s): 5e3cbb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -18
app.py CHANGED
@@ -3,6 +3,8 @@ import requests.exceptions
3
  from huggingface_hub import HfApi, hf_hub_download
4
  from huggingface_hub.repocard import metadata_load
5
 
 
 
6
  def load_agent(model_id_1, model_id_2):
7
  """
8
  This function load the agent's video and results
@@ -51,26 +53,42 @@ def get_metadata(model_id):
51
  return None
52
 
53
 
54
- gr.Interface(load_agent,
55
- [
56
- gr.Textbox(
57
- label="Model 1",
58
- ),
59
- gr.Textbox(
60
- label="Model 2",
61
- ),
62
- ],
63
- [ "text", "video", gr.Textbox(
64
- label="Mean Reward +/- Std Reward",
65
- ), "text", "video", gr.Textbox(
66
- label="Mean Reward +/- Std Reward",
67
- )],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  examples=[
69
  ["sb3/a2c-AntBulletEnv-v0","sb3/ppo-AntBulletEnv-v0"],
70
  ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
71
  ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
72
  ["ThomasSimonini/ppo-QbertNoFrameskip-v4","sb3/ppo-QbertNoFrameskip-v4"],
73
- ],
74
- title="Compare Deep Reinforcement Learning agents",
75
- description="Type two models id you want to compare or check examples below."
76
- ).launch()
3
  from huggingface_hub import HfApi, hf_hub_download
4
  from huggingface_hub.repocard import metadata_load
5
 
6
+ app = gr.Blocks()
7
+
8
  def load_agent(model_id_1, model_id_2):
9
  """
10
  This function load the agent's video and results
53
  return None
54
 
55
 
56
+
57
+
58
+ with app:
59
+ gr.Markdown(
60
+ """
61
+ # Compare Deep Reinforcement Learning Agents 🤖
62
+
63
+ Type two models id you want to compare or check examples below.
64
+ """)
65
+ with gr.Row():
66
+ model1_input = gr.Textbox(label="Model 1")
67
+ model2_input = gr.Textbox(label="Model 2")
68
+ with gr.Row():
69
+ app_button = gr.Button("Compare models")
70
+ with gr.Row():
71
+ with gr.Column():
72
+ model1_name = gr.Markdown()
73
+ model1_video_output = gr.Video()
74
+ model1_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
75
+ with gr.Column():
76
+ model2_name = gr.Markdown()
77
+ model2_video_output = gr.Video()
78
+ model2_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
79
+
80
+ app_button.click(load_agent, inputs=[model1_input, model2_input], outputs=[model1_name, model1_video_output, model1_score_output, model2_name, model2_video_output, model2_score_output])
81
+
82
+ app.launch()
83
+
84
+
85
+
86
+
87
+ """
88
+
89
  examples=[
90
  ["sb3/a2c-AntBulletEnv-v0","sb3/ppo-AntBulletEnv-v0"],
91
  ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
92
  ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
93
  ["ThomasSimonini/ppo-QbertNoFrameskip-v4","sb3/ppo-QbertNoFrameskip-v4"],
94
+ """