odiaz1066 ThomasSimonini HF staff committed on
Commit
18dd671
0 Parent(s):

Duplicate from ThomasSimonini/Compare-Reinforcement-Learning-Agents

Browse files

Co-authored-by: Thomas Simonini <ThomasSimonini@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/Compare-Reinforcement-Learning-Agents.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Compare-Reinforcement-Learning-Agents.iml" filepath="$PROJECT_DIR$/.idea/Compare-Reinforcement-Learning-Agents.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Compare Reinforcement Learning Agents
3
+ emoji: 👀
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.1.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: ThomasSimonini/Compare-Reinforcement-Learning-Agents
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/app.cpython-38.pyc ADDED
Binary file (1.08 kB). View file
 
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests.exceptions
3
+ from huggingface_hub import HfApi, hf_hub_download
4
+ from huggingface_hub.repocard import metadata_load
5
+
6
+ app = gr.Blocks()
7
+
8
def load_agent(model_id_1, model_id_2):
    """
    Collect the reported score and the replay video of two agents
    :param model_id_1: repo id of the first model
    :param model_id_2: repo id of the second model
    :return: (model_id_1, video_path_1, results_1, model_id_2, video_path_2, results_2)
    """
    outputs = []
    for repo_id in (model_id_1, model_id_2):
        # Score first (README metadata), then the "replay.mp4" video of the repo
        score = parse_metrics_accuracy(get_metadata(repo_id))
        replay_path = hf_hub_download(repo_id, filename="replay.mp4")
        outputs.extend((repo_id, replay_path, score))

    return tuple(outputs)
32
+
33
def parse_metrics_accuracy(meta):
    """
    Extract the first reported metric value from a model card's metadata
    :param meta: metadata mapping parsed from the README, or None
    :return: the "value" of the first metric of the first result of the first
             "model-index" entry, or None when it is not available
    """
    # get_metadata() returns None when the README download fails, so a missing
    # metadata object must be treated the same as a card without "model-index".
    if not meta or "model-index" not in meta:
        return None

    try:
        results = meta["model-index"][0]["results"]
        metrics = results[0]["metrics"]
        return metrics[0]["value"]
    except (KeyError, IndexError, TypeError):
        # Incomplete or unexpectedly shaped model-index: no score to display
        return None
40
+
41
def get_metadata(model_id):
    """
    Download the README of a model repo and parse its metadata
    :param model_id: id of the model repo
    :return: the metadata loaded from the README, or None if the download
             raised an HTTPError
    """
    try:
        card_path = hf_hub_download(model_id, filename="README.md")
    except requests.exceptions.HTTPError:
        return None

    card_metadata = metadata_load(card_path)
    print(card_metadata)
    return card_metadata
54
+
55
+
56
+
57
+
58
with app:
    # Page header
    gr.Markdown(
    """
    # Compare Deep Reinforcement Learning Agents 🤖

    Type two models id you want to compare or check examples below.
    """)

    # Inputs: the two model ids to compare
    with gr.Row():
        model1_input = gr.Textbox(label="Model 1")
        model2_input = gr.Textbox(label="Model 2")

    with gr.Row():
        app_button = gr.Button("Compare models")

    # Outputs: one column per model (name, replay video, score)
    with gr.Row():
        with gr.Column():
            model1_name = gr.Markdown()
            model1_video_output = gr.Video()
            model1_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
        with gr.Column():
            model2_name = gr.Markdown()
            model2_video_output = gr.Video()
            model2_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")

    # Same order as the tuple returned by load_agent
    comparison_outputs = [
        model1_name,
        model1_video_output,
        model1_score_output,
        model2_name,
        model2_video_output,
        model2_score_output,
    ]
    app_button.click(
        load_agent,
        inputs=[model1_input, model2_input],
        outputs=comparison_outputs,
    )

    # Ready-made pairs of model ids that fill both text boxes
    example_pairs = [
        ["sb3/a2c-AntBulletEnv-v0", "sb3/ppo-AntBulletEnv-v0"],
        ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
        ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
        ["ThomasSimonini/ppo-QbertNoFrameskip-v4", "sb3/ppo-QbertNoFrameskip-v4"],
    ]
    examples = gr.Examples(
        examples=example_pairs,
        inputs=[model1_input, model2_input],
    )


app.launch()