ClementBM commited on
Commit
3532cbe
1 Parent(s): 7db569a
Files changed (33) hide show
  1. connectfour/app.py +40 -19
  2. connectfour/checkpoint/policies/always_same/policy_state.pkl +0 -3
  3. connectfour/checkpoint/policies/always_same/rllib_checkpoint.json +0 -1
  4. connectfour/checkpoint/policies/beat_last/policy_state.pkl +0 -3
  5. connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json +0 -1
  6. connectfour/checkpoint/policies/learned/policy_state.pkl +0 -3
  7. connectfour/checkpoint/policies/learned/rllib_checkpoint.json +0 -1
  8. connectfour/checkpoint/policies/learned_v1/policy_state.pkl +0 -3
  9. connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json +0 -1
  10. connectfour/checkpoint/policies/learned_v2/policy_state.pkl +0 -3
  11. connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json +0 -1
  12. connectfour/checkpoint/policies/learned_v3/policy_state.pkl +0 -3
  13. connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json +0 -1
  14. connectfour/checkpoint/policies/learned_v4/policy_state.pkl +0 -3
  15. connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json +0 -1
  16. connectfour/checkpoint/policies/learned_v5/policy_state.pkl +0 -3
  17. connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json +0 -1
  18. connectfour/checkpoint/policies/linear/policy_state.pkl +0 -3
  19. connectfour/checkpoint/policies/linear/rllib_checkpoint.json +0 -1
  20. connectfour/checkpoint/policies/random/policy_state.pkl +0 -3
  21. connectfour/checkpoint/policies/random/rllib_checkpoint.json +0 -1
  22. connectfour/checkpoint/rllib_checkpoint.json +0 -1
  23. connectfour/training/__pycache__/callbacks.cpython-38.pyc +0 -0
  24. connectfour/training/__pycache__/dummy_policies.cpython-38.pyc +0 -0
  25. connectfour/training/__pycache__/wrappers.cpython-38.pyc +0 -0
  26. connectfour/training/callbacks.py +93 -51
  27. connectfour/training/dummy_policies.py +4 -0
  28. connectfour/training/train.py +117 -45
  29. connectfour/training/wrappers.py +7 -96
  30. models/__init__.py +3 -0
  31. connectfour/checkpoint/algorithm_state.pkl → models/model.onnx +2 -2
  32. poetry.lock +416 -33
  33. pyproject.toml +4 -1
connectfour/app.py CHANGED
@@ -2,25 +2,25 @@ import time
2
 
3
  import gradio as gr
4
  import numpy as np
5
- import ray
6
- import ray.rllib.algorithms.ppo as ppo
7
  from pettingzoo.classic import connect_four_v3
 
8
  from ray.tune import register_env
9
 
10
- from connectfour.checkpoint import CHECKPOINT
11
- from connectfour.training.models import Connect4MaskModel
12
  from connectfour.training.wrappers import Connect4Env
 
 
 
 
13
 
14
- POLICY_ID = "learned_v5"
15
 
16
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
 
17
  # gradio connectfour/app.py
18
 
19
 
20
  class Connect4:
21
  def __init__(self, who_plays_first) -> None:
22
- ray.init(include_dashboard=False, ignore_reinit_error=True)
23
-
24
  # define how to make the environment
25
  env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
26
 
@@ -44,25 +44,46 @@ class Connect4:
44
 
45
  return self.render_and_state
46
 
47
- def get_algo(self, checkpoint):
48
- config = (
49
- ppo.PPOConfig()
50
- .environment("connect4")
51
- .framework("torch")
52
- .training(model={"custom_model": Connect4MaskModel})
 
 
 
 
 
 
 
 
53
  )
54
- config.explore = False
55
- self.algo = config.build()
56
- self.algo.restore(checkpoint)
57
 
58
  def play(self, action=None):
59
  if self.has_erroneous_state():
60
  return self.blue_screen()
61
 
62
  if self.human != self.player_id:
63
- action = self.algo.compute_single_action(
64
- self.obs[self.player_id], policy_id=POLICY_ID
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
 
66
 
67
  if action not in self.legal_moves:
68
  action = np.random.choice(self.legal_moves)
@@ -114,7 +135,7 @@ demo = gr.Blocks()
114
 
115
  with demo:
116
  connect4 = Connect4("You")
117
- connect4.get_algo(str(CHECKPOINT))
118
 
119
  with gr.Row():
120
  with gr.Column(scale=1):
 
2
 
3
  import gradio as gr
4
  import numpy as np
 
 
5
  from pettingzoo.classic import connect_four_v3
6
+ from ray.rllib.utils.framework import try_import_torch
7
  from ray.tune import register_env
8
 
 
 
9
  from connectfour.training.wrappers import Connect4Env
10
+ from models import MODEL_PATH
11
+ import onnxruntime as ort
12
+ from ray.rllib.algorithms.algorithm import Algorithm
13
+ from connectfour.checkpoint import CHECKPOINT
14
 
15
+ torch, nn = try_import_torch()
16
 
17
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
18
+ # tensorboard --logdir ~/ray_results/
19
  # gradio connectfour/app.py
20
 
21
 
22
  class Connect4:
23
  def __init__(self, who_plays_first) -> None:
 
 
24
  # define how to make the environment
25
  env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
26
 
 
44
 
45
  return self.render_and_state
46
 
47
+ def get_algo(self):
48
+ # self.pytorch_model = torch.load(MODEL_PATH / "model.pt")
49
+ # self.algo = Algorithm.from_checkpoint(checkpoint=CHECKPOINT)
50
+ self.session = ort.InferenceSession(str(MODEL_PATH / "model.onnx"), None)
51
+
52
+ def compute_action(self, obs):
53
+ return self.pytorch_model(
54
+ input_dict={"obs": self.flatten_obs(obs)},
55
+ )
56
+
57
+ def flatten_obs(self, obs):
58
+ flatten_action_mask = torch.from_numpy(obs["action_mask"])
59
+ flatten_observation = torch.flatten(
60
+ torch.from_numpy(obs["observation"]), end_dim=2
61
  )
62
+ flatten_obs = torch.concat([flatten_action_mask, flatten_observation])
63
+ return flatten_obs[None, :]
 
64
 
65
  def play(self, action=None):
66
  if self.has_erroneous_state():
67
  return self.blue_screen()
68
 
69
  if self.human != self.player_id:
70
+ # action = self.algo.compute_single_action(
71
+ # self.obs[self.player_id], policy_id="learned_v9"
72
+ # )
73
+ # Torch
74
+ # action = self.compute_action(self.obs[self.player_id])
75
+ # action = int(torch.argmax(action[0]))
76
+ # ONNX
77
+ action = self.session.run(
78
+ ["output"],
79
+ {
80
+ "obs": self.flatten_obs(self.obs[self.player_id])
81
+ .numpy()
82
+ .astype(np.float32),
83
+ "state_ins": [],
84
+ },
85
  )
86
+ action = int(np.argmax(action[0]))
87
 
88
  if action not in self.legal_moves:
89
  action = np.random.choice(self.legal_moves)
 
135
 
136
  with demo:
137
  connect4 = Connect4("You")
138
+ connect4.get_algo()
139
 
140
  with gr.Row():
141
  with gr.Column(scale=1):
connectfour/checkpoint/policies/always_same/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d278413093ad1bc4f227279e3dab7be04ebd70ca1ed156a1363515c69d0a858e
3
- size 10992
 
 
 
 
connectfour/checkpoint/policies/always_same/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/beat_last/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd422258c16de0866599730a5a5b2b48e2ee81cbae69f9d5471deeae76c42b47
3
- size 10992
 
 
 
 
connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a517583e5fcad7e483bca619723583cc6928499390c1fcfc25d907e109cd4b4
3
- size 2139442
 
 
 
 
connectfour/checkpoint/policies/learned/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned_v1/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:276c26007c2419a688c27f9dfa70c20fecb468a0aa07d28d6a9e8099bbc849be
3
- size 2139439
 
 
 
 
connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned_v2/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e37a485d3a54f7a8b194693e7a61f790e67071358130178fa01cdbd840c4a4da
3
- size 2139439
 
 
 
 
connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned_v3/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9f90899ae98a387e312333b234041c68b9c50da4af92ee5250686087a39eebb3
3
- size 2139439
 
 
 
 
connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned_v4/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3af3b3fe41bac489cb693af387b1ccc4437a532a78d539b3abb4cc5f77929592
3
- size 2139439
 
 
 
 
connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/learned_v5/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f2b28b979e2f4411d196e03ca75ea7f25f7601bb997aa8bcdcf1d49c9ea30754
3
- size 2139439
 
 
 
 
connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/linear/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4f70d44ac661632dc0557204abe34308dfb25b800a668b49c2efd9a2a73a7bc0
3
- size 10992
 
 
 
 
connectfour/checkpoint/policies/linear/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/policies/random/policy_state.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f3b1ab86bada035779feedb2b92ae0a64f6d9474bb4f0ae44324e17d65659764
3
- size 10992
 
 
 
 
connectfour/checkpoint/policies/random/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/checkpoint/rllib_checkpoint.json DELETED
@@ -1 +0,0 @@
1
- {"type": "Algorithm", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
 
 
connectfour/training/__pycache__/callbacks.cpython-38.pyc CHANGED
Binary files a/connectfour/training/__pycache__/callbacks.cpython-38.pyc and b/connectfour/training/__pycache__/callbacks.cpython-38.pyc differ
 
connectfour/training/__pycache__/dummy_policies.cpython-38.pyc CHANGED
Binary files a/connectfour/training/__pycache__/dummy_policies.cpython-38.pyc and b/connectfour/training/__pycache__/dummy_policies.cpython-38.pyc differ
 
connectfour/training/__pycache__/wrappers.cpython-38.pyc CHANGED
Binary files a/connectfour/training/__pycache__/wrappers.cpython-38.pyc and b/connectfour/training/__pycache__/wrappers.cpython-38.pyc differ
 
connectfour/training/callbacks.py CHANGED
@@ -1,33 +1,44 @@
1
- from ray.rllib.algorithms.callbacks import DefaultCallbacks
 
2
  import numpy as np
 
 
 
 
 
 
 
 
3
 
4
 
5
- def create_self_play_callback(win_rate_thr, opponent_policies):
6
  class SelfPlayCallback(DefaultCallbacks):
7
  win_rate_threshold = win_rate_thr
8
 
9
  def __init__(self):
10
  super().__init__()
11
  self.current_opponent = 0
 
 
 
 
 
 
 
 
12
 
13
  def on_train_result(self, *, algorithm, result, **kwargs):
14
- # Get the win rate for the train batch.
15
- # Note that normally, one should set up a proper evaluation config,
16
- # such that evaluation always happens on the already updated policy,
17
- # instead of on the already used train_batch.
 
 
 
 
18
  main_rew = result["hist_stats"].pop("policy_learned_reward")
19
  opponent_rew = result["hist_stats"].pop("episode_reward")
20
 
21
- if len(main_rew) != len(opponent_rew):
22
- raise Exception(
23
- "len(main_rew) != len(opponent_rew)",
24
- len(main_rew),
25
- len(opponent_rew),
26
- result["hist_stats"].keys(),
27
- "episode len",
28
- len(opponent_rew),
29
- )
30
-
31
  won = 0
32
  for r_main, r_opponent in zip(main_rew, opponent_rew):
33
  if r_main > r_opponent:
@@ -35,54 +46,85 @@ def create_self_play_callback(win_rate_thr, opponent_policies):
35
  win_rate = won / len(main_rew)
36
 
37
  result["win_rate"] = win_rate
38
- print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")
39
 
40
- # If win rate is good -> Snapshot current policy and play against
41
- # it next, keeping the snapshot fixed and only improving the "learned"
42
- # policy.
43
  if win_rate > self.win_rate_threshold:
44
- self.current_opponent += 1
45
- new_pol_id = f"learned_v{self.current_opponent}"
46
- print(
47
- f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Re-define the mapping function, such that "learned" is forced
51
- # to play against any of the previously played policies
52
- # (excluding "random").
53
  def policy_mapping_fn(agent_id, episode, worker, **kwargs):
54
- # agent_id = [0|1] -> policy depends on episode ID
55
- # This way, we make sure that both policies sometimes play
56
- # (start player) and sometimes agent1 (player to move 2nd).
57
  return (
58
  "learned"
59
  if episode.episode_id % 2 == int(agent_id[-1:])
60
- else np.random.choice(
61
- opponent_policies
62
- + [
63
- f"learned_v{i}"
64
- for i in range(1, self.current_opponent + 1)
65
- ]
66
- )
67
  )
68
 
69
- new_policy = algorithm.add_policy(
70
- policy_id=new_pol_id,
71
- policy_cls=type(algorithm.get_policy("learned")),
72
- policy_mapping_fn=policy_mapping_fn,
73
  )
74
 
75
- # Set the weights of the new policy to the learned policy.
76
- # We'll keep training the learned policy, whereas `new_pol_id` will
77
- # remain fixed.
78
- learned_state = algorithm.get_policy("learned").get_state()
79
- new_policy.set_state(learned_state)
80
- # We need to sync the just copied local weights (from learned policy)
81
- # to all the remote workers as well.
 
 
 
 
 
 
 
82
  algorithm.workers.sync_weights()
 
83
  else:
84
- print("not good enough; will keep learning ...")
 
 
 
 
 
 
 
85
 
86
- result["league_size"] = self.current_opponent + len(opponent_policies) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  return SelfPlayCallback
 
1
+ from collections import deque
2
+
3
  import numpy as np
4
+ from ray.rllib.algorithms.callbacks import DefaultCallbacks
5
+
6
+ from connectfour.training.dummy_policies import (
7
+ AlwaysSameHeuristic,
8
+ BeatLastHeuristic,
9
+ LinearHeuristic,
10
+ RandomHeuristic,
11
+ )
12
 
13
 
14
+ def create_self_play_callback(win_rate_thr, opponent_policies, opponent_count=10):
15
  class SelfPlayCallback(DefaultCallbacks):
16
  win_rate_threshold = win_rate_thr
17
 
18
  def __init__(self):
19
  super().__init__()
20
  self.current_opponent = 0
21
+ self.opponent_policies = deque(opponent_policies, maxlen=opponent_count)
22
+ self.policy_to_remove = None
23
+ self.frozen_policies = {
24
+ "always_same": AlwaysSameHeuristic,
25
+ "linear": LinearHeuristic,
26
+ "beat_last": BeatLastHeuristic,
27
+ "random": RandomHeuristic,
28
+ }
29
 
30
  def on_train_result(self, *, algorithm, result, **kwargs):
31
+ """Called at the end of Algorithm.train().
32
+
33
+ Args:
34
+ algorithm: Current Algorithm instance.
35
+ result: Dict of results returned from Algorithm.train() call.
36
+ You can mutate this object to add additional metrics.
37
+ kwargs: Forward compatibility placeholder.
38
+ """
39
  main_rew = result["hist_stats"].pop("policy_learned_reward")
40
  opponent_rew = result["hist_stats"].pop("episode_reward")
41
 
 
 
 
 
 
 
 
 
 
 
42
  won = 0
43
  for r_main, r_opponent in zip(main_rew, opponent_rew):
44
  if r_main > r_opponent:
 
46
  win_rate = won / len(main_rew)
47
 
48
  result["win_rate"] = win_rate
49
+ print(f"Iter={algorithm.iteration} win-rate={win_rate}")
50
 
 
 
 
51
  if win_rate > self.win_rate_threshold:
52
+ if len(self.opponent_policies) == self.opponent_policies.maxlen:
53
+ self.policy_to_remove = self.opponent_policies[0]
54
+
55
+ new_pol_id = None
56
+ while new_pol_id is None:
57
+ if np.random.choice(range(6)) == 0:
58
+ new_pol_id = np.random.choice(list(self.frozen_policies.keys()))
59
+ else:
60
+ self.current_opponent += 1
61
+ new_pol_id = f"learned_v{self.current_opponent}"
62
+
63
+ if new_pol_id in self.opponent_policies:
64
+ new_pol_id = None
65
+ else:
66
+ self.opponent_policies.append(new_pol_id)
67
+
68
+ print("Non trainable policies", list(self.opponent_policies))
69
 
 
 
 
70
  def policy_mapping_fn(agent_id, episode, worker, **kwargs):
 
 
 
71
  return (
72
  "learned"
73
  if episode.episode_id % 2 == int(agent_id[-1:])
74
+ else np.random.choice(list(self.opponent_policies))
 
 
 
 
 
 
75
  )
76
 
77
+ print(
78
+ f"Iter={algorithm.iteration} Adding new opponent to the mix ({new_pol_id}). League size {len(self.opponent_policies) + 1}"
 
 
79
  )
80
 
81
+ if new_pol_id in list(self.frozen_policies.keys()):
82
+ new_policy = algorithm.add_policy(
83
+ policy_id=new_pol_id,
84
+ policy_cls=self.frozen_policies[new_pol_id],
85
+ policy_mapping_fn=policy_mapping_fn,
86
+ )
87
+ else:
88
+ new_policy = algorithm.add_policy(
89
+ policy_id=new_pol_id,
90
+ policy_cls=type(algorithm.get_policy("learned")),
91
+ policy_mapping_fn=policy_mapping_fn,
92
+ )
93
+ learned_state = algorithm.get_policy("learned").get_state()
94
+ new_policy.set_state(learned_state)
95
  algorithm.workers.sync_weights()
96
+
97
  else:
98
+ print("Not good enough... Keep learning ...")
99
+
100
+ result["league_size"] = len(self.opponent_policies) + 1
101
+
102
+ def on_evaluate_end(self, *, algorithm, evaluation_metrics, **kwargs):
103
+ """Runs when the evaluation is done.
104
+
105
+ Runs at the end of Algorithm.evaluate().
106
 
107
+ Args:
108
+ algorithm: Reference to the algorithm instance.
109
+ evaluation_metrics: Results dict to be returned from algorithm.evaluate().
110
+ You can mutate this object to add additional metrics.
111
+ kwargs: Forward compatibility placeholder.
112
+ """
113
+
114
+ def policy_mapping_fn(agent_id, episode, worker, **kwargs):
115
+ return (
116
+ "learned"
117
+ if episode.episode_id % 2 == int(agent_id[-1:])
118
+ else np.random.choice(list(self.opponent_policies))
119
+ )
120
+
121
+ if self.policy_to_remove is not None:
122
+ print("Remove ", self.policy_to_remove, "from opponent policies")
123
+ algorithm.remove_policy(
124
+ self.policy_to_remove,
125
+ policy_mapping_fn=policy_mapping_fn,
126
+ )
127
+ self.policy_to_remove = None
128
+ algorithm.workers.sync_weights()
129
 
130
  return SelfPlayCallback
connectfour/training/dummy_policies.py CHANGED
@@ -23,6 +23,10 @@ class HeuristicBase(Policy):
23
  """No weights to set."""
24
  pass
25
 
 
 
 
 
26
  @override(Policy)
27
  def compute_actions(
28
  self,
 
23
  """No weights to set."""
24
  pass
25
 
26
+ @override(Policy)
27
+ def export_model(self, export_dir: str, onnx=None) -> None:
28
+ pass
29
+
30
  @override(Policy)
31
  def compute_actions(
32
  self,
connectfour/training/train.py CHANGED
@@ -8,6 +8,7 @@ from ray import air, tune
8
  from ray.rllib.policy.policy import PolicySpec
9
  from ray.rllib.utils.framework import try_import_torch
10
  from ray.tune import CLIReporter, register_env
 
11
 
12
  from connectfour.training.callbacks import create_self_play_callback
13
  from connectfour.training.dummy_policies import (
@@ -29,12 +30,21 @@ def get_cli_args():
29
  python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
30
  python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
31
  python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
 
 
 
32
  """
33
  parser = argparse.ArgumentParser()
34
  parser.add_argument("--num-cpus", type=int, default=0)
35
  parser.add_argument("--num-gpus", type=int, default=0)
36
  parser.add_argument("--num-workers", type=int, default=2)
37
-
 
 
 
 
 
 
38
  parser.add_argument(
39
  "--stop-iters", type=int, default=200, help="Number of iterations to train."
40
  )
@@ -57,13 +67,6 @@ def get_cli_args():
57
  return args
58
 
59
 
60
- def select_policy(agent_id, episode, **kwargs):
61
- if episode.episode_id % 2 == int(agent_id[-1:]):
62
- return "learned"
63
- else:
64
- return random.choice(["always_same", "beat_last", "random", "linear"])
65
-
66
-
67
  if __name__ == "__main__":
68
  args = get_cli_args()
69
 
@@ -80,20 +83,23 @@ if __name__ == "__main__":
80
  # register that way to make the environment under an rllib name
81
  register_env("connect4", lambda config: Connect4Env(env_creator(config)))
82
 
 
 
 
 
 
 
83
  config = (
84
- ppo.PPOConfig()
85
- .environment("connect4")
86
- .framework("torch")
87
- .training(model={"custom_model": Connect4MaskModel})
88
- .callbacks(
89
- create_self_play_callback(
90
- win_rate_thr=args.win_rate_threshold,
91
- opponent_policies=["always_same", "beat_last", "random", "linear"],
92
  )
93
- )
94
- .rollouts(
95
- num_rollout_workers=args.num_workers,
96
- num_envs_per_worker=5,
97
  )
98
  .multi_agent(
99
  policies={
@@ -106,19 +112,88 @@ if __name__ == "__main__":
106
  policy_mapping_fn=select_policy,
107
  policies_to_train=["learned"],
108
  )
 
 
 
 
 
 
 
 
109
  )
110
 
111
- stop = {
112
- "timesteps_total": args.stop_timesteps,
113
- "training_iteration": args.stop_iters,
114
- }
115
-
116
- results = tune.Tuner(
117
- "PPO",
118
- param_space=config.to_dict(),
119
- run_config=air.RunConfig(
120
- stop=stop,
121
- verbose=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  progress_reporter=CLIReporter(
123
  metric_columns={
124
  "training_iteration": "iter",
@@ -128,21 +203,18 @@ if __name__ == "__main__":
128
  "policy_reward_mean/learned": "reward",
129
  "win_rate": "win_rate",
130
  "league_size": "league_size",
131
- },
132
- mode="max",
133
- metric="win_rate",
134
- sort_by_metric=True,
135
- ),
136
- checkpoint_config=air.CheckpointConfig(
137
- checkpoint_at_end=True,
138
- checkpoint_frequency=10,
139
  ),
140
- ),
141
- ).fit()
142
 
143
- print(
144
- "Best checkpoint",
145
- results.get_best_result(metric="win_rate", mode="max").checkpoint,
146
- )
 
 
 
 
 
147
 
148
  ray.shutdown()
 
8
  from ray.rllib.policy.policy import PolicySpec
9
  from ray.rllib.utils.framework import try_import_torch
10
  from ray.tune import CLIReporter, register_env
11
+ from ray.rllib.algorithms.algorithm import Algorithm
12
 
13
  from connectfour.training.callbacks import create_self_play_callback
14
  from connectfour.training.dummy_policies import (
 
30
  python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
31
  python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
32
  python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
33
+ python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.95 --stop-iters 2000 > training.log 2>&1
34
+ python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.96 --stop-iters 10000 > training.log 2>&1
35
+ python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.99 --stop-iters 5000 --from-checkpoint ~/ray_results/PPO/PPO_connect4_8414a_00000_0_2023-04-03_12-44-31/checkpoint_004000 > training.log 2>&1
36
  """
37
  parser = argparse.ArgumentParser()
38
  parser.add_argument("--num-cpus", type=int, default=0)
39
  parser.add_argument("--num-gpus", type=int, default=0)
40
  parser.add_argument("--num-workers", type=int, default=2)
41
+ parser.add_argument(
42
+ "--from-checkpoint",
43
+ type=str,
44
+ default=None,
45
+ help="Full path to a experiment directory to resume tuning from "
46
+ "a previously saved Algorithm state.",
47
+ )
48
  parser.add_argument(
49
  "--stop-iters", type=int, default=200, help="Number of iterations to train."
50
  )
 
67
  return args
68
 
69
 
 
 
 
 
 
 
 
70
  if __name__ == "__main__":
71
  args = get_cli_args()
72
 
 
83
  # register that way to make the environment under an rllib name
84
  register_env("connect4", lambda config: Connect4Env(env_creator(config)))
85
 
86
+ def select_policy(agent_id, episode, **kwargs):
87
+ if episode.episode_id % 2 == int(agent_id[-1:]):
88
+ return "learned"
89
+ else:
90
+ return random.choice(["always_same", "beat_last", "random", "linear"])
91
+
92
  config = (
93
+ (
94
+ ppo.PPOConfig()
95
+ .environment("connect4")
96
+ .framework("torch")
97
+ .training(model={"custom_model": Connect4MaskModel})
98
+ .rollouts(
99
+ num_rollout_workers=args.num_workers,
100
+ num_envs_per_worker=5,
101
  )
102
+ # .checkpointing(checkpoint_trainable_policies_only=True)
 
 
 
103
  )
104
  .multi_agent(
105
  policies={
 
112
  policy_mapping_fn=select_policy,
113
  policies_to_train=["learned"],
114
  )
115
+ .callbacks(
116
+ create_self_play_callback(
117
+ win_rate_thr=args.win_rate_threshold,
118
+ opponent_policies=["always_same", "beat_last", "random", "linear"],
119
+ opponent_count=15,
120
+ )
121
+ )
122
+ .evaluation(evaluation_interval=1)
123
  )
124
 
125
+ if args.from_checkpoint is None:
126
+ stop = {
127
+ "timesteps_total": args.stop_timesteps,
128
+ "training_iteration": args.stop_iters,
129
+ }
130
+
131
+ results = tune.Tuner(
132
+ "PPO",
133
+ param_space=config.to_dict(),
134
+ run_config=air.RunConfig(
135
+ stop=stop,
136
+ verbose=2,
137
+ progress_reporter=CLIReporter(
138
+ metric_columns={
139
+ "training_iteration": "iter",
140
+ "time_total_s": "time_total_s",
141
+ "timesteps_total": "ts",
142
+ "episodes_this_iter": "train_episodes",
143
+ "policy_reward_mean/learned": "reward",
144
+ "win_rate": "win_rate",
145
+ "league_size": "league_size",
146
+ },
147
+ mode="max",
148
+ metric="win_rate",
149
+ sort_by_metric=True,
150
+ ),
151
+ checkpoint_config=air.CheckpointConfig(
152
+ num_to_keep=10,
153
+ checkpoint_at_end=True,
154
+ checkpoint_frequency=10,
155
+ checkpoint_score_order="max",
156
+ ),
157
+ ),
158
+ ).fit()
159
+
160
+ best_checkpoint = results.get_best_result(
161
+ metric="win_rate", mode="max"
162
+ ).checkpoint
163
+ print("Best checkpoint", best_checkpoint)
164
+
165
+ else:
166
+ algo = Algorithm.from_checkpoint(checkpoint=args.from_checkpoint)
167
+
168
+ config = algo.config.copy(False)
169
+ config.checkpointing(export_native_model_files=True)
170
+
171
+ opponent_policies = list(algo.workers.local_worker().policy_map.keys())
172
+ opponent_policies.remove("learned")
173
+ opponent_policies.sort()
174
+
175
+ config.callbacks(
176
+ create_self_play_callback(
177
+ win_rate_thr=args.win_rate_threshold,
178
+ opponent_policies=opponent_policies,
179
+ opponent_count=len(opponent_policies),
180
+ )
181
+ )
182
+ config.evaluation(evaluation_interval=None)
183
+
184
+ analysis = tune.run(
185
+ "PPO",
186
+ config=config.to_dict(),
187
+ restore=args.from_checkpoint,
188
+ checkpoint_freq=10,
189
+ checkpoint_at_end=True,
190
+ keep_checkpoints_num=10,
191
+ mode="max",
192
+ metric="win_rate",
193
+ stop={
194
+ "win_rate": args.win_rate_threshold,
195
+ "training_iteration": args.stop_iters,
196
+ },
197
  progress_reporter=CLIReporter(
198
  metric_columns={
199
  "training_iteration": "iter",
 
203
  "policy_reward_mean/learned": "reward",
204
  "win_rate": "win_rate",
205
  "league_size": "league_size",
206
+ }
 
 
 
 
 
 
 
207
  ),
208
+ )
 
209
 
210
+ algo = Algorithm.from_checkpoint(analysis.best_checkpoint)
211
+ ppo_policy = algo.get_policy("learned")
212
+
213
+ # Save as torch model
214
+ ppo_policy.export_model("models")
215
+ # Save as ONNX model
216
+ ppo_policy.export_model("models", onnx=11)
217
+
218
+ print("Best checkpoint", analysis.best_checkpoint)
219
 
220
  ray.shutdown()
connectfour/training/wrappers.py CHANGED
@@ -1,112 +1,23 @@
1
  from typing import Optional
2
 
3
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
4
  from ray.rllib.utils.annotations import PublicAPI
5
- from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
6
 
7
 
8
  @PublicAPI
9
- class Connect4Env(MultiAgentEnv):
10
- """An interface to the PettingZoo MARL environment library.
11
-
12
- See: https://github.com/Farama-Foundation/PettingZoo
13
-
14
- Inherits from MultiAgentEnv and exposes a given AEC
15
- (actor-environment-cycle) game from the PettingZoo project via the
16
- MultiAgentEnv public API.
17
-
18
- Note that the wrapper has some important limitations:
19
-
20
- 1. All agents have the same action_spaces and observation_spaces.
21
- Note: If, within your aec game, agents do not have homogeneous action /
22
- observation spaces, apply SuperSuit wrappers
23
- to apply padding functionality: https://github.com/Farama-Foundation/
24
- SuperSuit#built-in-multi-agent-only-functions
25
- 2. Environments are positive sum games (-> Agents are expected to cooperate
26
- to maximize reward). This isn't a hard restriction, it just that
27
- standard algorithms aren't expected to work well in highly competitive
28
- games."""
29
-
30
- def __init__(self, env):
31
- super().__init__()
32
- self.env = env
33
- env.reset()
34
-
35
- # Since all agents have the same spaces, do not provide full observation-
36
- # and action-spaces as Dicts, mapping agent IDs to the individual
37
- # agents' spaces. Instead, `self.[action|observation]_space` are the single
38
- # agent spaces.
39
- self._obs_space_in_preferred_format = False
40
- self._action_space_in_preferred_format = False
41
-
42
- # Collect the individual agents' spaces (they should all be the same):
43
- first_obs_space = self.env.observation_space(self.env.agents[0])
44
- first_action_space = self.env.action_space(self.env.agents[0])
45
-
46
- for agent in self.env.agents:
47
- if self.env.observation_space(agent) != first_obs_space:
48
- raise ValueError(
49
- "Observation spaces for all agents must be identical. Perhaps "
50
- "SuperSuit's pad_observations wrapper can help (useage: "
51
- "`supersuit.aec_wrappers.pad_observations(env)`"
52
- )
53
- if self.env.action_space(agent) != first_action_space:
54
- raise ValueError(
55
- "Action spaces for all agents must be identical. Perhaps "
56
- "SuperSuit's pad_action_space wrapper can help (usage: "
57
- "`supersuit.aec_wrappers.pad_action_space(env)`"
58
- )
59
-
60
- # Convert from gym to gymnasium, if necessary.
61
- self.observation_space = convert_old_gym_space_to_gymnasium_space(
62
- first_obs_space
63
- )
64
- self.action_space = convert_old_gym_space_to_gymnasium_space(first_action_space)
65
-
66
- self._agent_ids = set(self.env.agents)
67
 
68
  def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
 
 
69
  info = self.env.reset(seed=seed, options=options)
70
  return (
71
  {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
72
  info or {},
73
  )
74
 
75
- def step(self, action):
76
- self.env.step(action[self.env.agent_selection])
77
- obs_d = {}
78
- rew_d = {}
79
- terminated_d = {}
80
- truncated_d = {}
81
- info_d = {}
82
- while self.env.agents:
83
- obs, rew, terminated, truncated, info = self.env.last()
84
- agent_id = self.env.agent_selection
85
- obs_d[agent_id] = obs
86
- rew_d[agent_id] = rew
87
- terminated_d[agent_id] = terminated
88
- truncated_d[agent_id] = truncated
89
- info_d[agent_id] = info
90
- if (
91
- self.env.terminations[self.env.agent_selection]
92
- or self.env.truncations[self.env.agent_selection]
93
- ):
94
- self.env.step(None)
95
- else:
96
- break
97
-
98
- all_gone = not self.env.agents
99
- terminated_d["__all__"] = all_gone and all(terminated_d.values())
100
- truncated_d["__all__"] = all_gone and all(truncated_d.values())
101
-
102
- return obs_d, rew_d, terminated_d, truncated_d, info_d
103
-
104
- def close(self):
105
- self.env.close()
106
-
107
  def render(self):
 
 
108
  return self.env.render()
109
-
110
- @property
111
- def get_sub_environments(self):
112
- return self.env.unwrapped
 
1
  from typing import Optional
2
 
 
3
  from ray.rllib.utils.annotations import PublicAPI
4
+ from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
5
 
6
 
7
  @PublicAPI
8
+ class Connect4Env(PettingZooEnv):
9
+ """An interface to the PettingZoo MARL environment library"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
12
+ # In base class =>
13
+ # info = self.env.reset(seed=seed, return_info=True, options=options)
14
  info = self.env.reset(seed=seed, options=options)
15
  return (
16
  {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
17
  info or {},
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def render(self):
21
+ # In base class =>
22
+ # return self.env.render(self.render_mode)
23
  return self.env.render()
 
 
 
 
models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ MODEL_PATH = Path(__file__).parent.absolute()
connectfour/checkpoint/algorithm_state.pkl → models/model.onnx RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cbbc198c3406897931f5f18046a88181b8abff1aedbea1d869329731c9a50853
3
- size 66321
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae8bf3e7080eca6ba6c8e68dbceffa03ca8d05f65249c474544a66e039352d2a
3
+ size 361882
poetry.lock CHANGED
@@ -290,6 +290,18 @@ d = ["aiohttp (>=3.7.4)"]
290
  jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
291
  uvloop = ["uvloop (>=0.15.2)"]
292
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  [[package]]
294
  name = "certifi"
295
  version = "2022.12.7"
@@ -456,6 +468,24 @@ files = [
456
  {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
457
  ]
458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  [[package]]
460
  name = "contourpy"
461
  version = "1.0.7"
@@ -692,6 +722,18 @@ files = [
692
  docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
693
  testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
694
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  [[package]]
696
  name = "fonttools"
697
  version = "4.39.3"
@@ -848,16 +890,60 @@ files = [
848
  {file = "gast-0.5.3.tar.gz", hash = "sha256:cfbea25820e653af9c7d1807f659ce0a0a9c64f2439421a7bba4f0983f532dea"},
849
  ]
850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851
  [[package]]
852
  name = "gradio"
853
- version = "3.23.0"
854
  description = "Python library for easily interacting with trained machine learning models"
855
  category = "main"
856
  optional = false
857
  python-versions = ">=3.7"
858
  files = [
859
- {file = "gradio-3.23.0-py3-none-any.whl", hash = "sha256:1f637f80e4b740892c3cc5ec559158afe2b6b97f0a18eca588c235ddedddc7ec"},
860
- {file = "gradio-3.23.0.tar.gz", hash = "sha256:bfb9f59d799271029e6309207c7f42ed8adb297dcb8bbe31c83c3ebf32dbe193"},
861
  ]
862
 
863
  [package.dependencies]
@@ -866,7 +952,7 @@ aiohttp = "*"
866
  altair = ">=4.2.0"
867
  fastapi = "*"
868
  ffmpy = "*"
869
- fsspec = "*"
870
  httpx = "*"
871
  huggingface-hub = ">=0.13.0"
872
  jinja2 = "*"
@@ -888,6 +974,25 @@ typing-extensions = "*"
888
  uvicorn = "*"
889
  websockets = ">=10.0"
890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
891
  [[package]]
892
  name = "grpcio"
893
  version = "1.49.1"
@@ -1138,6 +1243,21 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "jedi", "pytest", "pytest
1138
  torch = ["torch"]
1139
  typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
1140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
  [[package]]
1142
  name = "idna"
1143
  version = "3.4"
@@ -1531,6 +1651,24 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"]
1531
  flake8 = ["flake8"]
1532
  tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"]
1533
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1534
  [[package]]
1535
  name = "markdown-it-py"
1536
  version = "2.2.0"
@@ -2184,6 +2322,111 @@ files = [
2184
  setuptools = "*"
2185
  wheel = "*"
2186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2187
  [[package]]
2188
  name = "orjson"
2189
  version = "3.8.8"
@@ -2471,39 +2714,65 @@ testing = ["pytest", "pytest-benchmark"]
2471
 
2472
  [[package]]
2473
  name = "protobuf"
2474
- version = "3.17.0"
2475
  description = "Protocol Buffers"
2476
  category = "main"
2477
  optional = false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2478
  python-versions = "*"
2479
  files = [
2480
- {file = "protobuf-3.17.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:15351df904347da2081a2eebc42b192c29724eb57dbe56dae440be843f1e4779"},
2481
- {file = "protobuf-3.17.0-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:5356981c1919782b8c2e3ea5c5d85ad5937b8178a025ac9edc2f2ca5b4a717ae"},
2482
- {file = "protobuf-3.17.0-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:eac0a2a7ea99e17175f6e7b53cdc9004ed786c072fbdf933def0e454e14fd323"},
2483
- {file = "protobuf-3.17.0-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4c8d0997fdc0a4cf9de7950d598ce6974b22e8618bbcf1d15e9842010cf8420a"},
2484
- {file = "protobuf-3.17.0-cp35-cp35m-win32.whl", hash = "sha256:9ae321459d4890c3939c536382f75e232c9e91ce506310353c8a15ad5c379e0d"},
2485
- {file = "protobuf-3.17.0-cp35-cp35m-win_amd64.whl", hash = "sha256:295944ef0772498d7bf75f6aa5d4dfcfd02f5ce70f735b406e52e43ac3914d38"},
2486
- {file = "protobuf-3.17.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:850f429bd2399525d339d05bc809f090f16d3d88737bed637d355a5ee8d3b81a"},
2487
- {file = "protobuf-3.17.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:809a96d5a1a74538728710f9104f43ae77f5e48bde274ee321b10a324ba52e4f"},
2488
- {file = "protobuf-3.17.0-cp36-cp36m-win32.whl", hash = "sha256:8a3ac375539055164f31a330770f137875307e6f04c21e2647f2e7139c501295"},
2489
- {file = "protobuf-3.17.0-cp36-cp36m-win_amd64.whl", hash = "sha256:3d338910b10b88b18581cf6877b3938b2e262e8fdc2c1057f5a291787de63183"},
2490
- {file = "protobuf-3.17.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1488f786bd1912f97796cf5def8cacf433735616896cf7ed9dc786cee693dfc8"},
2491
- {file = "protobuf-3.17.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bcaff977db178f0bfde10bab0d23a5f5adf5964adba70c315e45922a1c55eb90"},
2492
- {file = "protobuf-3.17.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:939ce06846ddfec99c0bff510510b3ee45778e7a3aec6544d1f36526e5fecb67"},
2493
- {file = "protobuf-3.17.0-cp37-cp37m-win32.whl", hash = "sha256:3237acce5b666c7b0f45785cc2d0809796d4df3593bd68338aebf25408139188"},
2494
- {file = "protobuf-3.17.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f77afe33bb86c7d34221a86193256d69aa10818620fe4a7513d98211d67d672"},
2495
- {file = "protobuf-3.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:acc9f2091ace3de429eee424ab7ba0bc52a6aa9ffc9909e5c4de259a3f71db46"},
2496
- {file = "protobuf-3.17.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a29631f4f8bcf79b12a59e83d238d888de5034871461d788c74c68218ad75049"},
2497
- {file = "protobuf-3.17.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:05c304396e309661c45e3a97bd2d8da1fc2bab743ed2ca880bcb757271c40c0e"},
2498
- {file = "protobuf-3.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:baea44967071e6a51e705e4e88aebf35f530a14004cc69f60a185e5d7e13de7e"},
2499
- {file = "protobuf-3.17.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3b5c461af5a3cebd796c73370db929b7e24cbaba655eefdc044226bc8a843d6b"},
2500
- {file = "protobuf-3.17.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:44399393c3a8cc04a4cfbdc721dd7f2114497efda582e946a91b8c4290ae5ff5"},
2501
- {file = "protobuf-3.17.0-py2.py3-none-any.whl", hash = "sha256:e32ef0c9f4b548c80d94dfff8b4130ca2ff3d50caaf2455889e3f5b8a01e8038"},
2502
- {file = "protobuf-3.17.0.tar.gz", hash = "sha256:05dfe9319939a8473c21b469f34f6486646e54fb8542637cf7ed8e2fbfe21538"},
2503
- ]
2504
-
2505
- [package.dependencies]
2506
- six = ">=1.9"
2507
 
2508
  [[package]]
2509
  name = "pydantic"
@@ -2705,6 +2974,18 @@ files = [
2705
  [package.extras]
2706
  diagrams = ["jinja2", "railroad-diagrams"]
2707
 
 
 
 
 
 
 
 
 
 
 
 
 
2708
  [[package]]
2709
  name = "pyrsistent"
2710
  version = "0.19.3"
@@ -2997,6 +3278,25 @@ urllib3 = ">=1.21.1,<1.27"
2997
  socks = ["PySocks (>=1.5.6,!=1.5.7)"]
2998
  use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
2999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3000
  [[package]]
3001
  name = "rfc3986"
3002
  version = "1.5.0"
@@ -3035,6 +3335,21 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9
3035
  [package.extras]
3036
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
3037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3038
  [[package]]
3039
  name = "scikit-image"
3040
  version = "0.20.0"
@@ -3231,6 +3546,56 @@ files = [
3231
  [package.extras]
3232
  widechars = ["wcwidth"]
3233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3234
  [[package]]
3235
  name = "tensorboardx"
3236
  version = "2.6"
@@ -3625,6 +3990,24 @@ files = [
3625
  {file = "websockets-10.4.tar.gz", hash = "sha256:eef610b23933c54d5d921c92578ae5f89813438fded840c2e9809d378dc765d3"},
3626
  ]
3627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3628
  [[package]]
3629
  name = "wheel"
3630
  version = "0.40.0"
@@ -3832,4 +4215,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
3832
  [metadata]
3833
  lock-version = "2.0"
3834
  python-versions = ">=3.8,<3.11"
3835
- content-hash = "6ac937208f9e895063d08a5437db35ed4c5d75275bcf70e8503c609f42b8f302"
 
290
  jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
291
  uvloop = ["uvloop (>=0.15.2)"]
292
 
293
+ [[package]]
294
+ name = "cachetools"
295
+ version = "5.3.0"
296
+ description = "Extensible memoizing collections and decorators"
297
+ category = "main"
298
+ optional = false
299
+ python-versions = "~=3.7"
300
+ files = [
301
+ {file = "cachetools-5.3.0-py3-none-any.whl", hash = "sha256:429e1a1e845c008ea6c85aa35d4b98b65d6a9763eeef3e37e92728a12d1de9d4"},
302
+ {file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
303
+ ]
304
+
305
  [[package]]
306
  name = "certifi"
307
  version = "2022.12.7"
 
468
  {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
469
  ]
470
 
471
+ [[package]]
472
+ name = "coloredlogs"
473
+ version = "15.0.1"
474
+ description = "Colored terminal output for Python's logging module"
475
+ category = "main"
476
+ optional = false
477
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
478
+ files = [
479
+ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
480
+ {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
481
+ ]
482
+
483
+ [package.dependencies]
484
+ humanfriendly = ">=9.1"
485
+
486
+ [package.extras]
487
+ cron = ["capturer (>=2.4)"]
488
+
489
  [[package]]
490
  name = "contourpy"
491
  version = "1.0.7"
 
722
  docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
723
  testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
724
 
725
+ [[package]]
726
+ name = "flatbuffers"
727
+ version = "23.3.3"
728
+ description = "The FlatBuffers serialization format for Python"
729
+ category = "main"
730
+ optional = false
731
+ python-versions = "*"
732
+ files = [
733
+ {file = "flatbuffers-23.3.3-py2.py3-none-any.whl", hash = "sha256:5ad36d376240090757e8f0a2cfaf6abcc81c6536c0dc988060375fd0899121f8"},
734
+ {file = "flatbuffers-23.3.3.tar.gz", hash = "sha256:cabd87c4882f37840f6081f094b2c5bc28cefc2a6357732746936d055ab45c3d"},
735
+ ]
736
+
737
  [[package]]
738
  name = "fonttools"
739
  version = "4.39.3"
 
890
  {file = "gast-0.5.3.tar.gz", hash = "sha256:cfbea25820e653af9c7d1807f659ce0a0a9c64f2439421a7bba4f0983f532dea"},
891
  ]
892
 
893
+ [[package]]
894
+ name = "google-auth"
895
+ version = "2.17.1"
896
+ description = "Google Authentication Library"
897
+ category = "main"
898
+ optional = false
899
+ python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
900
+ files = [
901
+ {file = "google-auth-2.17.1.tar.gz", hash = "sha256:8f379b46bad381ad2a0b989dfb0c13ad28d3c2a79f27348213f8946a1d15d55a"},
902
+ {file = "google_auth-2.17.1-py2.py3-none-any.whl", hash = "sha256:357ff22a75b4c0f6093470f21816a825d2adee398177569824e37b6c10069e19"},
903
+ ]
904
+
905
+ [package.dependencies]
906
+ cachetools = ">=2.0.0,<6.0"
907
+ pyasn1-modules = ">=0.2.1"
908
+ rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
909
+ six = ">=1.9.0"
910
+
911
+ [package.extras]
912
+ aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "requests (>=2.20.0,<3.0.0dev)"]
913
+ enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
914
+ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
915
+ reauth = ["pyu2f (>=0.1.5)"]
916
+ requests = ["requests (>=2.20.0,<3.0.0dev)"]
917
+
918
+ [[package]]
919
+ name = "google-auth-oauthlib"
920
+ version = "0.4.6"
921
+ description = "Google Authentication Library"
922
+ category = "main"
923
+ optional = false
924
+ python-versions = ">=3.6"
925
+ files = [
926
+ {file = "google-auth-oauthlib-0.4.6.tar.gz", hash = "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a"},
927
+ {file = "google_auth_oauthlib-0.4.6-py2.py3-none-any.whl", hash = "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73"},
928
+ ]
929
+
930
+ [package.dependencies]
931
+ google-auth = ">=1.0.0"
932
+ requests-oauthlib = ">=0.7.0"
933
+
934
+ [package.extras]
935
+ tool = ["click (>=6.0.0)"]
936
+
937
  [[package]]
938
  name = "gradio"
939
+ version = "3.24.0"
940
  description = "Python library for easily interacting with trained machine learning models"
941
  category = "main"
942
  optional = false
943
  python-versions = ">=3.7"
944
  files = [
945
+ {file = "gradio-3.24.0-py3-none-any.whl", hash = "sha256:cedd67f7cbd17764b3613fb4df274a7c450c74e31a2e3229097d43cb4ffa50c7"},
946
+ {file = "gradio-3.24.0.tar.gz", hash = "sha256:4ac2bf531b3c0ff5ec9e93959f2d1dbc49eac1767bafa2d80f8950a3bc40c4ed"},
947
  ]
948
 
949
  [package.dependencies]
 
952
  altair = ">=4.2.0"
953
  fastapi = "*"
954
  ffmpy = "*"
955
+ gradio-client = ">=0.0.5"
956
  httpx = "*"
957
  huggingface-hub = ">=0.13.0"
958
  jinja2 = "*"
 
974
  uvicorn = "*"
975
  websockets = ">=10.0"
976
 
977
+ [[package]]
978
+ name = "gradio-client"
979
+ version = "0.0.5"
980
+ description = "Python library for easily interacting with trained machine learning models"
981
+ category = "main"
982
+ optional = false
983
+ python-versions = ">=3.7"
984
+ files = [
985
+ {file = "gradio_client-0.0.5-py3-none-any.whl", hash = "sha256:ca4167ebae72d920ebec2be47010cf60e31e0296ad9baac771befb17b87f0eef"},
986
+ {file = "gradio_client-0.0.5.tar.gz", hash = "sha256:dc6479a119314aac0bbf6821da6e946df17f048cc571559379a89590618f7b5d"},
987
+ ]
988
+
989
+ [package.dependencies]
990
+ fsspec = "*"
991
+ huggingface-hub = ">=0.13.0"
992
+ packaging = "*"
993
+ requests = "*"
994
+ websockets = "*"
995
+
996
  [[package]]
997
  name = "grpcio"
998
  version = "1.49.1"
 
1243
  torch = ["torch"]
1244
  typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
1245
 
1246
+ [[package]]
1247
+ name = "humanfriendly"
1248
+ version = "10.0"
1249
+ description = "Human friendly output for text interfaces using Python"
1250
+ category = "main"
1251
+ optional = false
1252
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
1253
+ files = [
1254
+ {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
1255
+ {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
1256
+ ]
1257
+
1258
+ [package.dependencies]
1259
+ pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
1260
+
1261
  [[package]]
1262
  name = "idna"
1263
  version = "3.4"
 
1651
  flake8 = ["flake8"]
1652
  tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"]
1653
 
1654
+ [[package]]
1655
+ name = "markdown"
1656
+ version = "3.4.3"
1657
+ description = "Python implementation of John Gruber's Markdown."
1658
+ category = "main"
1659
+ optional = false
1660
+ python-versions = ">=3.7"
1661
+ files = [
1662
+ {file = "Markdown-3.4.3-py3-none-any.whl", hash = "sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2"},
1663
+ {file = "Markdown-3.4.3.tar.gz", hash = "sha256:8bf101198e004dc93e84a12a7395e31aac6a9c9942848ae1d99b9d72cf9b3520"},
1664
+ ]
1665
+
1666
+ [package.dependencies]
1667
+ importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
1668
+
1669
+ [package.extras]
1670
+ testing = ["coverage", "pyyaml"]
1671
+
1672
  [[package]]
1673
  name = "markdown-it-py"
1674
  version = "2.2.0"
 
2322
  setuptools = "*"
2323
  wheel = "*"
2324
 
2325
+ [[package]]
2326
+ name = "oauthlib"
2327
+ version = "3.2.2"
2328
+ description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
2329
+ category = "main"
2330
+ optional = false
2331
+ python-versions = ">=3.6"
2332
+ files = [
2333
+ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"},
2334
+ {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"},
2335
+ ]
2336
+
2337
+ [package.extras]
2338
+ rsa = ["cryptography (>=3.0.0)"]
2339
+ signals = ["blinker (>=1.4.0)"]
2340
+ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
2341
+
2342
+ [[package]]
2343
+ name = "onnx"
2344
+ version = "1.12.0"
2345
+ description = "Open Neural Network Exchange"
2346
+ category = "main"
2347
+ optional = false
2348
+ python-versions = "*"
2349
+ files = [
2350
+ {file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"},
2351
+ {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"},
2352
+ {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"},
2353
+ {file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"},
2354
+ {file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"},
2355
+ {file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"},
2356
+ {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"},
2357
+ {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"},
2358
+ {file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"},
2359
+ {file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"},
2360
+ {file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"},
2361
+ {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"},
2362
+ {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"},
2363
+ {file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"},
2364
+ {file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"},
2365
+ {file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"},
2366
+ {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"},
2367
+ {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"},
2368
+ {file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"},
2369
+ {file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"},
2370
+ {file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"},
2371
+ ]
2372
+
2373
+ [package.dependencies]
2374
+ numpy = ">=1.16.6"
2375
+ protobuf = ">=3.12.2,<=3.20.1"
2376
+ typing-extensions = ">=3.6.2.1"
2377
+
2378
+ [package.extras]
2379
+ lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"]
2380
+
2381
+ [[package]]
2382
+ name = "onnxruntime"
2383
+ version = "1.14.1"
2384
+ description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
2385
+ category = "main"
2386
+ optional = false
2387
+ python-versions = "*"
2388
+ files = [
2389
+ {file = "onnxruntime-1.14.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:193ef1ac512e530c6e6e259c26e67212e2cd3f2bfaad6ff935ed3f4281053056"},
2390
+ {file = "onnxruntime-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d2853bbb36cb272d99f6c225e5040eb0ddb37a667fce20d186ecdf0a6fac8af8"},
2391
+ {file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e1b173365c6894616b8207e23cbb891da9638c5373668d6653e4081ef5f04d0"},
2392
+ {file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24bf0401c5f92be7230ac660ff07ba06f7c175e99e225d5d48ff09062a3b76e9"},
2393
+ {file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_27_aarch64.whl", hash = "sha256:0a2d09260bbdbe1df678e0a237a5f7b1a44fd11a2f52688d8b6a53a9d03a26db"},
2394
+ {file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_27_x86_64.whl", hash = "sha256:d99d35b9d5c3f46cad1673a39cc753fb57d60784369b59e6f8cd3dfb77df1885"},
2395
+ {file = "onnxruntime-1.14.1-cp310-cp310-win32.whl", hash = "sha256:f400356df1b27d9adc5513319e8a89753e48ef0d6c5084caf5db8e132f46e7e8"},
2396
+ {file = "onnxruntime-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:96a4059dbab162fe5cdb6750f8c70b2106ef2de5d49a7f72085171937d0e36d3"},
2397
+ {file = "onnxruntime-1.14.1-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:fa23df6a349218636290f9fe56d7baaceb1a50cf92255234d495198b47d92327"},
2398
+ {file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc70e44d9e123d126648da24ffb39e56464272a1660a3eb91f4f5b74263be3ba"},
2399
+ {file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:deff8138045a3affb6be064b598e3ec69a88e4d445359c50464ee5379b8eaf19"},
2400
+ {file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_27_aarch64.whl", hash = "sha256:7c02acdc1107cbf698dcbf6dadc6f5b6aa179e7fa9a026251e99cf8613bd3129"},
2401
+ {file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_27_x86_64.whl", hash = "sha256:6efa3b2f4b1eaa6c714c07861993bfd9bb33bd73cdbcaf5b4aadcf1ec13fcaf7"},
2402
+ {file = "onnxruntime-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:72fc0acc82c54bf03eba065ad9025baa438c00c54a2ee0beb8ae4b6085cd3a0d"},
2403
+ {file = "onnxruntime-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4d6f08ea40d63ccf90f203f4a2a498f4e590737dcaf16867075cc8e0a86c5554"},
2404
+ {file = "onnxruntime-1.14.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:c2d9e8f1bc6037f14d8aaa480492792c262fc914936153e40b06b3667bb25549"},
2405
+ {file = "onnxruntime-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e7424d3befdd95b537c90787bbfaa053b2bb19eb60135abb898cb0e099d7d7ad"},
2406
+ {file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9066d275e6e41d0597e234d2d88c074d4325e650c74a9527a52cadbcf42a0fe2"},
2407
+ {file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8224d3c1f2cd0b899cea7b5a39f28b971debe0da30fcbc61382801d97d6f5740"},
2408
+ {file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_27_aarch64.whl", hash = "sha256:f4ac52ff4ac793683ebd1fbd1ee24197e3b4ca825ee68ff739296a820867debe"},
2409
+ {file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_27_x86_64.whl", hash = "sha256:b1dd8cdd3be36c32ddd8f5763841ed571c3e81da59439a622947bd97efee6e77"},
2410
+ {file = "onnxruntime-1.14.1-cp38-cp38-win32.whl", hash = "sha256:95d0f0cd95360c07f1c3ba20962b9bb813627df4bfc1b4b274e1d40044df5ad1"},
2411
+ {file = "onnxruntime-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:de40a558e00fc00f92e298d5be99eb8075dba51368dabcb259670a00f4670e56"},
2412
+ {file = "onnxruntime-1.14.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:c65b587a42a89fceceaad367bd69d071ee5c9c7010b76e2adac5e9efd9356fb5"},
2413
+ {file = "onnxruntime-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6e47ef6a2c6e6dd6ff48bc13f2331d124dff00e1d76627624bb3268c8058f19c"},
2414
+ {file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0afd0f671d068dd99b9d071d88e93a9a57a5ed59af440c0f4d65319ee791603f"},
2415
+ {file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc65e9061349cdf98ce16b37722b557109f16076632fbfed9a3151895cfd3bb7"},
2416
+ {file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_aarch64.whl", hash = "sha256:2ff17c71187391a71e6ccc78ca89aed83bcaed1c085c95267ab1a70897868bdd"},
2417
+ {file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_x86_64.whl", hash = "sha256:9b795189916942ce848192200dde5b1f32799ee6c84fc600969a44d88e8a5404"},
2418
+ {file = "onnxruntime-1.14.1-cp39-cp39-win32.whl", hash = "sha256:17ca3100112af045118750d24643a01ed4e6d86071a8efaef75cc1d434ea64aa"},
2419
+ {file = "onnxruntime-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:b5e8c489329ba0fa0639dfd7ec02d6b07cece1bab52ef83884b537247efbda74"},
2420
+ ]
2421
+
2422
+ [package.dependencies]
2423
+ coloredlogs = "*"
2424
+ flatbuffers = "*"
2425
+ numpy = ">=1.21.6"
2426
+ packaging = "*"
2427
+ protobuf = "*"
2428
+ sympy = "*"
2429
+
2430
  [[package]]
2431
  name = "orjson"
2432
  version = "3.8.8"
 
2714
 
2715
  [[package]]
2716
  name = "protobuf"
2717
+ version = "3.19.6"
2718
  description = "Protocol Buffers"
2719
  category = "main"
2720
  optional = false
2721
+ python-versions = ">=3.5"
2722
+ files = [
2723
+ {file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
2724
+ {file = "protobuf-3.19.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11478547958c2dfea921920617eb457bc26867b0d1aa065ab05f35080c5d9eb6"},
2725
+ {file = "protobuf-3.19.6-cp310-cp310-win32.whl", hash = "sha256:559670e006e3173308c9254d63facb2c03865818f22204037ab76f7a0ff70b5f"},
2726
+ {file = "protobuf-3.19.6-cp310-cp310-win_amd64.whl", hash = "sha256:347b393d4dd06fb93a77620781e11c058b3b0a5289262f094379ada2920a3730"},
2727
+ {file = "protobuf-3.19.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a8ce5ae0de28b51dff886fb922012dad885e66176663950cb2344c0439ecb473"},
2728
+ {file = "protobuf-3.19.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90b0d02163c4e67279ddb6dc25e063db0130fc299aefabb5d481053509fae5c8"},
2729
+ {file = "protobuf-3.19.6-cp36-cp36m-win32.whl", hash = "sha256:30f5370d50295b246eaa0296533403961f7e64b03ea12265d6dfce3a391d8992"},
2730
+ {file = "protobuf-3.19.6-cp36-cp36m-win_amd64.whl", hash = "sha256:0c0714b025ec057b5a7600cb66ce7c693815f897cfda6d6efb58201c472e3437"},
2731
+ {file = "protobuf-3.19.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5057c64052a1f1dd7d4450e9aac25af6bf36cfbfb3a1cd89d16393a036c49157"},
2732
+ {file = "protobuf-3.19.6-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bb6776bd18f01ffe9920e78e03a8676530a5d6c5911934c6a1ac6eb78973ecb6"},
2733
+ {file = "protobuf-3.19.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84a04134866861b11556a82dd91ea6daf1f4925746b992f277b84013a7cc1229"},
2734
+ {file = "protobuf-3.19.6-cp37-cp37m-win32.whl", hash = "sha256:4bc98de3cdccfb5cd769620d5785b92c662b6bfad03a202b83799b6ed3fa1fa7"},
2735
+ {file = "protobuf-3.19.6-cp37-cp37m-win_amd64.whl", hash = "sha256:aa3b82ca1f24ab5326dcf4ea00fcbda703e986b22f3d27541654f749564d778b"},
2736
+ {file = "protobuf-3.19.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2b2d2913bcda0e0ec9a784d194bc490f5dc3d9d71d322d070b11a0ade32ff6ba"},
2737
+ {file = "protobuf-3.19.6-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:d0b635cefebd7a8a0f92020562dead912f81f401af7e71f16bf9506ff3bdbb38"},
2738
+ {file = "protobuf-3.19.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a552af4dc34793803f4e735aabe97ffc45962dfd3a237bdde242bff5a3de684"},
2739
+ {file = "protobuf-3.19.6-cp38-cp38-win32.whl", hash = "sha256:0469bc66160180165e4e29de7f445e57a34ab68f49357392c5b2f54c656ab25e"},
2740
+ {file = "protobuf-3.19.6-cp38-cp38-win_amd64.whl", hash = "sha256:91d5f1e139ff92c37e0ff07f391101df77e55ebb97f46bbc1535298d72019462"},
2741
+ {file = "protobuf-3.19.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c0ccd3f940fe7f3b35a261b1dd1b4fc850c8fde9f74207015431f174be5976b3"},
2742
+ {file = "protobuf-3.19.6-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:30a15015d86b9c3b8d6bf78d5b8c7749f2512c29f168ca259c9d7727604d0e39"},
2743
+ {file = "protobuf-3.19.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:878b4cd080a21ddda6ac6d1e163403ec6eea2e206cf225982ae04567d39be7b0"},
2744
+ {file = "protobuf-3.19.6-cp39-cp39-win32.whl", hash = "sha256:5a0d7539a1b1fb7e76bf5faa0b44b30f812758e989e59c40f77a7dab320e79b9"},
2745
+ {file = "protobuf-3.19.6-cp39-cp39-win_amd64.whl", hash = "sha256:bbf5cea5048272e1c60d235c7bd12ce1b14b8a16e76917f371c718bd3005f045"},
2746
+ {file = "protobuf-3.19.6-py2.py3-none-any.whl", hash = "sha256:14082457dc02be946f60b15aad35e9f5c69e738f80ebbc0900a19bc83734a5a4"},
2747
+ {file = "protobuf-3.19.6.tar.gz", hash = "sha256:5f5540d57a43042389e87661c6eaa50f47c19c6176e8cf1c4f287aeefeccb5c4"},
2748
+ ]
2749
+
2750
+ [[package]]
2751
+ name = "pyasn1"
2752
+ version = "0.4.8"
2753
+ description = "ASN.1 types and codecs"
2754
+ category = "main"
2755
+ optional = false
2756
+ python-versions = "*"
2757
+ files = [
2758
+ {file = "pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d"},
2759
+ {file = "pyasn1-0.4.8.tar.gz", hash = "sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba"},
2760
+ ]
2761
+
2762
+ [[package]]
2763
+ name = "pyasn1-modules"
2764
+ version = "0.2.8"
2765
+ description = "A collection of ASN.1-based protocols modules."
2766
+ category = "main"
2767
+ optional = false
2768
  python-versions = "*"
2769
  files = [
2770
+ {file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"},
2771
+ {file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"},
2772
+ ]
2773
+
2774
+ [package.dependencies]
2775
+ pyasn1 = ">=0.4.6,<0.5.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2776
 
2777
  [[package]]
2778
  name = "pydantic"
 
2974
  [package.extras]
2975
  diagrams = ["jinja2", "railroad-diagrams"]
2976
 
2977
+ [[package]]
2978
+ name = "pyreadline3"
2979
+ version = "3.4.1"
2980
+ description = "A python implementation of GNU readline."
2981
+ category = "main"
2982
+ optional = false
2983
+ python-versions = "*"
2984
+ files = [
2985
+ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"},
2986
+ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
2987
+ ]
2988
+
2989
  [[package]]
2990
  name = "pyrsistent"
2991
  version = "0.19.3"
 
3278
  socks = ["PySocks (>=1.5.6,!=1.5.7)"]
3279
  use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
3280
 
3281
+ [[package]]
3282
+ name = "requests-oauthlib"
3283
+ version = "1.3.1"
3284
+ description = "OAuthlib authentication support for Requests."
3285
+ category = "main"
3286
+ optional = false
3287
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
3288
+ files = [
3289
+ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
3290
+ {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
3291
+ ]
3292
+
3293
+ [package.dependencies]
3294
+ oauthlib = ">=3.0.0"
3295
+ requests = ">=2.0.0"
3296
+
3297
+ [package.extras]
3298
+ rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
3299
+
3300
  [[package]]
3301
  name = "rfc3986"
3302
  version = "1.5.0"
 
3335
  [package.extras]
3336
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
3337
 
3338
+ [[package]]
3339
+ name = "rsa"
3340
+ version = "4.9"
3341
+ description = "Pure-Python RSA implementation"
3342
+ category = "main"
3343
+ optional = false
3344
+ python-versions = ">=3.6,<4"
3345
+ files = [
3346
+ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
3347
+ {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
3348
+ ]
3349
+
3350
+ [package.dependencies]
3351
+ pyasn1 = ">=0.1.3"
3352
+
3353
  [[package]]
3354
  name = "scikit-image"
3355
  version = "0.20.0"
 
3546
  [package.extras]
3547
  widechars = ["wcwidth"]
3548
 
3549
+ [[package]]
3550
+ name = "tensorboard"
3551
+ version = "2.12.0"
3552
+ description = "TensorBoard lets you watch Tensors Flow"
3553
+ category = "main"
3554
+ optional = false
3555
+ python-versions = ">=3.8"
3556
+ files = [
3557
+ {file = "tensorboard-2.12.0-py3-none-any.whl", hash = "sha256:3cbdc32448d7a28dc1bf0b1754760c08b8e0e2e37c451027ebd5ff4896613012"},
3558
+ ]
3559
+
3560
+ [package.dependencies]
3561
+ absl-py = ">=0.4"
3562
+ google-auth = ">=1.6.3,<3"
3563
+ google-auth-oauthlib = ">=0.4.1,<0.5"
3564
+ grpcio = ">=1.48.2"
3565
+ markdown = ">=2.6.8"
3566
+ numpy = ">=1.12.0"
3567
+ protobuf = ">=3.19.6"
3568
+ requests = ">=2.21.0,<3"
3569
+ setuptools = ">=41.0.0"
3570
+ tensorboard-data-server = ">=0.7.0,<0.8.0"
3571
+ tensorboard-plugin-wit = ">=1.6.0"
3572
+ werkzeug = ">=1.0.1"
3573
+ wheel = ">=0.26"
3574
+
3575
+ [[package]]
3576
+ name = "tensorboard-data-server"
3577
+ version = "0.7.0"
3578
+ description = "Fast data loading for TensorBoard"
3579
+ category = "main"
3580
+ optional = false
3581
+ python-versions = ">=3.7"
3582
+ files = [
3583
+ {file = "tensorboard_data_server-0.7.0-py3-none-any.whl", hash = "sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb"},
3584
+ {file = "tensorboard_data_server-0.7.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454"},
3585
+ {file = "tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f"},
3586
+ ]
3587
+
3588
+ [[package]]
3589
+ name = "tensorboard-plugin-wit"
3590
+ version = "1.8.1"
3591
+ description = "What-If Tool TensorBoard plugin."
3592
+ category = "main"
3593
+ optional = false
3594
+ python-versions = "*"
3595
+ files = [
3596
+ {file = "tensorboard_plugin_wit-1.8.1-py3-none-any.whl", hash = "sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe"},
3597
+ ]
3598
+
3599
  [[package]]
3600
  name = "tensorboardx"
3601
  version = "2.6"
 
3990
  {file = "websockets-10.4.tar.gz", hash = "sha256:eef610b23933c54d5d921c92578ae5f89813438fded840c2e9809d378dc765d3"},
3991
  ]
3992
 
3993
+ [[package]]
3994
+ name = "werkzeug"
3995
+ version = "2.2.3"
3996
+ description = "The comprehensive WSGI web application library."
3997
+ category = "main"
3998
+ optional = false
3999
+ python-versions = ">=3.7"
4000
+ files = [
4001
+ {file = "Werkzeug-2.2.3-py3-none-any.whl", hash = "sha256:56433961bc1f12533306c624f3be5e744389ac61d722175d543e1751285da612"},
4002
+ {file = "Werkzeug-2.2.3.tar.gz", hash = "sha256:2e1ccc9417d4da358b9de6f174e3ac094391ea1d4fbef2d667865d819dfd0afe"},
4003
+ ]
4004
+
4005
+ [package.dependencies]
4006
+ MarkupSafe = ">=2.1.1"
4007
+
4008
+ [package.extras]
4009
+ watchdog = ["watchdog"]
4010
+
4011
  [[package]]
4012
  name = "wheel"
4013
  version = "0.40.0"
 
4215
  [metadata]
4216
  lock-version = "2.0"
4217
  python-versions = ">=3.8,<3.11"
4218
+ content-hash = "81eac0c68b289dd9d22d17dd34ae6bcd31f53c8201f229d31ff96d7094ad2392"
pyproject.toml CHANGED
@@ -23,8 +23,11 @@ pygame = "^2.3.0"
23
  torch = "^2.0.0"
24
  libclang = "15.0.6.1"
25
  tensorflow-probability = "^0.19.0"
26
- protobuf = "3.17.0"
27
  scipy = ">=1.8,<1.9.2"
 
 
 
28
 
29
  [tool.poetry.dev-dependencies]
30
  pylint = "*"
 
23
  torch = "^2.0.0"
24
  libclang = "15.0.6.1"
25
  tensorflow-probability = "^0.19.0"
26
+ protobuf = "3.19.6"
27
  scipy = ">=1.8,<1.9.2"
28
+ onnx = "1.12.0"
29
+ tensorboard = "^2.12.0"
30
+ onnxruntime = "^1.14.1"
31
 
32
  [tool.poetry.dev-dependencies]
33
  pylint = "*"