imwithye committed on
Commit
970c966
·
1 Parent(s): e71fda3
rlcube/cube2.ipynb CHANGED
@@ -49,7 +49,7 @@
49
  },
50
  {
51
  "cell_type": "code",
52
- "execution_count": 16,
53
  "id": "defde44e",
54
  "metadata": {},
55
  "outputs": [
@@ -57,17 +57,43 @@
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
60
- "[2, 3, 7, 6, 8, 6, 3, 2, 2, 5]\n",
61
- "tensor([[ 1.1924],\n",
62
- " [ 0.0826],\n",
63
- " [ 1.0202],\n",
64
- " [ 0.0826],\n",
65
- " [ 1.1121],\n",
66
- " [-0.0302],\n",
67
- " [-1.5963],\n",
68
- " [-0.0302],\n",
69
- " [-1.3707],\n",
70
- " [-2.4068]], grad_fn=<AddmmBackward0>)\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ]
72
  }
73
  ],
@@ -87,44 +113,55 @@
87
  "obs = torch.tensor(np.array(obs), dtype=torch.float32)\n",
88
  "values, policies = net(obs)\n",
89
  "print(actions)\n",
90
- "print(values)"
 
 
 
 
 
 
91
  ]
92
  },
93
  {
94
  "cell_type": "code",
95
- "execution_count": 18,
96
- "id": "cae20b12",
97
  "metadata": {},
98
  "outputs": [
99
  {
100
- "name": "stderr",
101
- "output_type": "stream",
102
- "text": [
103
- " 14%|█▍ | 43/300 [00:00<00:02, 127.98it/s]"
104
- ]
105
- },
106
- {
107
- "name": "stdout",
108
- "output_type": "stream",
109
- "text": [
110
- "[4, 3, 7, 11]\n"
111
- ]
112
- },
113
- {
114
- "name": "stderr",
115
- "output_type": "stream",
116
- "text": [
117
- "\n"
118
- ]
 
119
  }
120
  ],
121
  "source": [
122
- "from rlcube.models.search import MonteCarloTree\n",
123
- "\n",
124
- "tree = MonteCarloTree(env.obs(), max_simulations=300)\n",
125
- "if tree.is_solved:\n",
126
- " print([action for _, action in tree.solved_path])"
127
  ]
 
 
 
 
 
 
 
 
128
  }
129
  ],
130
  "metadata": {
 
49
  },
50
  {
51
  "cell_type": "code",
52
+ "execution_count": null,
53
  "id": "defde44e",
54
  "metadata": {},
55
  "outputs": [
 
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
60
+ "[11, 11, 3, 10, 9, 4, 5, 3, 11, 11]\n",
61
+ "tensor([[ 1.2608],\n",
62
+ " [ 0.2146],\n",
63
+ " [-0.8424],\n",
64
+ " [-0.6595],\n",
65
+ " [-0.4404],\n",
66
+ " [-1.2381],\n",
67
+ " [-0.4404],\n",
68
+ " [-1.6949],\n",
69
+ " [-3.1237],\n",
70
+ " [-2.8188]], grad_fn=<AddmmBackward0>)\n"
71
+ ]
72
+ },
73
+ {
74
+ "name": "stderr",
75
+ "output_type": "stream",
76
+ "text": [
77
+ " 9%|▉ | 469/5000 [00:04<00:48, 94.14it/s] \n"
78
+ ]
79
+ },
80
+ {
81
+ "ename": "KeyboardInterrupt",
82
+ "evalue": "",
83
+ "output_type": "error",
84
+ "traceback": [
85
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
86
+ "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
87
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28mprint\u001b[39m(values)\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrlcube\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodels\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01msearch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MonteCarloTree\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m tree = \u001b[43mMonteCarloTree\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_simulations\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m5000\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m tree.is_solved:\n\u001b[32m 22\u001b[39m \u001b[38;5;28mprint\u001b[39m([action \u001b[38;5;28;01mfor\u001b[39;00m _, action \u001b[38;5;129;01min\u001b[39;00m tree.solved_path])\n",
88
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:59\u001b[39m, in \u001b[36mMonteCarloTree.__init__\u001b[39m\u001b[34m(self, obs, max_simulations)\u001b[39m\n\u001b[32m 57\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m 58\u001b[39m \u001b[38;5;28mself\u001b[39m.solved_path = []\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_build\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
89
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:80\u001b[39m, in \u001b[36mMonteCarloTree._build\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m12\u001b[39m):\n\u001b[32m 79\u001b[39m obs = adjacent_obs[i]\n\u001b[32m---> \u001b[39m\u001b[32m80\u001b[39m child = \u001b[43mNode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 81\u001b[39m node.children[i] = child\n\u001b[32m 82\u001b[39m \u001b[38;5;28mself\u001b[39m.nodes.append(child)\n",
90
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:21\u001b[39m, in \u001b[36mNode.__init__\u001b[39m\u001b[34m(self, obs, parent)\u001b[39m\n\u001b[32m 18\u001b[39m value = value.detach()\n\u001b[32m 19\u001b[39m policy = torch.softmax(policy.detach(), dim=\u001b[32m1\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[43mCube2Env\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfrom_obs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m.is_solved()\n\u001b[32m 22\u001b[39m \u001b[38;5;28mself\u001b[39m.value = torch.tensor(\u001b[32m1\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_solved \u001b[38;5;28;01melse\u001b[39;00m value.view(-\u001b[32m1\u001b[39m)\n\u001b[32m 23\u001b[39m \u001b[38;5;28mself\u001b[39m.policy = policy.view(-\u001b[32m1\u001b[39m)\n",
91
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:30\u001b[39m, in \u001b[36mCube2Env.from_obs\u001b[39m\u001b[34m(obs)\u001b[39m\n\u001b[32m 28\u001b[39m idx = i * \u001b[32m4\u001b[39m + j\n\u001b[32m 29\u001b[39m state[i, j] = np.argmax(obs[idx])\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m env = \u001b[43mCube2Env\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 31\u001b[39m env.reset(state=state)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env\n",
92
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:16\u001b[39m, in \u001b[36mCube2Env.__init__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 14\u001b[39m \u001b[38;5;28msuper\u001b[39m(Cube2Env, \u001b[38;5;28mself\u001b[39m).\u001b[34m__init__\u001b[39m()\n\u001b[32m 15\u001b[39m \u001b[38;5;28mself\u001b[39m.action_space = gym.spaces.Discrete(\u001b[32m12\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[38;5;28mself\u001b[39m.observation_space = \u001b[43mgym\u001b[49m\u001b[43m.\u001b[49m\u001b[43mspaces\u001b[49m\u001b[43m.\u001b[49m\u001b[43mBox\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mlow\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m=\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m24\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m6\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mint8\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[38;5;28mself\u001b[39m.state = np.zeros((\u001b[32m6\u001b[39m, \u001b[32m4\u001b[39m), dtype=np.int8)\n\u001b[32m 20\u001b[39m \u001b[38;5;28mself\u001b[39m.reset()\n",
93
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:149\u001b[39m, in \u001b[36mBox.__init__\u001b[39m\u001b[34m(self, low, high, shape, dtype, seed)\u001b[39m\n\u001b[32m 147\u001b[39m \u001b[38;5;66;03m# Cast `low` and `high` to ndarray for the dtype min and max for out of range tests\u001b[39;00m\n\u001b[32m 148\u001b[39m \u001b[38;5;28mself\u001b[39m.low, \u001b[38;5;28mself\u001b[39m.bounded_below = \u001b[38;5;28mself\u001b[39m._cast_low(low, dtype_min)\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m \u001b[38;5;28mself\u001b[39m.high, \u001b[38;5;28mself\u001b[39m.bounded_above = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cast_high\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 151\u001b[39m \u001b[38;5;66;03m# recheck shape for case where shape and (low or high) are provided\u001b[39;00m\n\u001b[32m 152\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.low.shape != shape:\n",
94
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:251\u001b[39m, in \u001b[36mBox._cast_high\u001b[39m\u001b[34m(self, high, dtype_max)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_cast_high\u001b[39m(\u001b[38;5;28mself\u001b[39m, high, dtype_max) -> \u001b[38;5;28mtuple\u001b[39m[np.ndarray, np.ndarray]:\n\u001b[32m 242\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Casts the input Box high value to ndarray with provided dtype.\u001b[39;00m\n\u001b[32m 243\u001b[39m \n\u001b[32m 244\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 249\u001b[39m \u001b[33;03m The updated high value and for what values the input is bounded (above)\u001b[39;00m\n\u001b[32m 250\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m251\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mis_float_integer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 252\u001b[39m bounded_above = np.full(\u001b[38;5;28mself\u001b[39m.shape, high, dtype=\u001b[38;5;28mfloat\u001b[39m) < np.inf\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m np.isnan(high):\n",
95
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:32\u001b[39m, in \u001b[36mis_float_integer\u001b[39m\u001b[34m(var)\u001b[39m\n\u001b[32m 28\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(np.min(arr))\n\u001b[32m 29\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(arr)\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mis_float_integer\u001b[39m(var: Any) -> \u001b[38;5;28mbool\u001b[39m:\n\u001b[32m 33\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Checks if a scalar variable is an integer or float (does not include bool).\"\"\"\u001b[39;00m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.integer) \u001b[38;5;129;01mor\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.floating)\n",
96
+ "\u001b[31mKeyboardInterrupt\u001b[39m: "
97
  ]
98
  }
99
  ],
 
113
  "obs = torch.tensor(np.array(obs), dtype=torch.float32)\n",
114
  "values, policies = net(obs)\n",
115
  "print(actions)\n",
116
+ "print(values)\n",
117
+ "\n",
118
+ "from rlcube.models.search import MonteCarloTree\n",
119
+ "\n",
120
+ "tree = MonteCarloTree(env.obs(), max_simulations=1000)\n",
121
+ "if tree.is_solved:\n",
122
+ " print([action for _, action in tree.solved_path])"
123
  ]
124
  },
125
  {
126
  "cell_type": "code",
127
+ "execution_count": 6,
128
+ "id": "a91732d7",
129
  "metadata": {},
130
  "outputs": [
131
  {
132
+ "data": {
133
+ "text/plain": [
134
+ "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
135
+ " {0: 400,\n",
136
+ " 1: 0,\n",
137
+ " 2: 0,\n",
138
+ " 3: 0,\n",
139
+ " 4: 0,\n",
140
+ " 5: 0,\n",
141
+ " 6: 0,\n",
142
+ " 7: 0,\n",
143
+ " 8: 0,\n",
144
+ " 9: 0,\n",
145
+ " 10: 44,\n",
146
+ " 11: 0})"
147
+ ]
148
+ },
149
+ "execution_count": 6,
150
+ "metadata": {},
151
+ "output_type": "execute_result"
152
  }
153
  ],
154
  "source": [
155
+ "tree.root.N"
 
 
 
 
156
  ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "99d79934",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": []
165
  }
166
  ],
167
  "metadata": {
rlcube/rlcube/models/models.py CHANGED
@@ -76,7 +76,7 @@ class DNN(nn.Module):
76
  torch.save(self.state_dict(), filepath)
77
 
78
  def load(self, filepath: str):
79
- self.load_state_dict(torch.load(filepath))
80
 
81
 
82
  class DNN2(nn.Module):
 
76
  torch.save(self.state_dict(), filepath)
77
 
78
  def load(self, filepath: str):
79
+ self.load_state_dict(torch.load(filepath, map_location=torch.device("cpu")))
80
 
81
 
82
  class DNN2(nn.Module):
src/components/ui-controls.tsx CHANGED
@@ -27,7 +27,7 @@ export const UIControls = () => {
27
  } = useControlContext();
28
 
29
  const scramble = () => {
30
- const scrambleSteps = Array.from({ length: 20 }, () => Actions[Math.floor(Math.random() * Actions.length)]);
31
  rubiksCubeRef?.current?.rotate(scrambleSteps);
32
  };
33
 
 
27
  } = useControlContext();
28
 
29
  const scramble = () => {
30
+ const scrambleSteps = Array.from({ length: 5 }, () => Actions[Math.floor(Math.random() * Actions.length)]);
31
  rubiksCubeRef?.current?.rotate(scrambleSteps);
32
  };
33
 
src/contexts/control-context.tsx CHANGED
@@ -23,7 +23,7 @@ export const ControlContext = createContext<ControlContextType>({
23
  setShowRotationIndicators: () => {},
24
  cubeRoughness: 0.5,
25
  setCubeRoughness: () => {},
26
- cubeSpeed: 2,
27
  setCubeSpeed: () => {},
28
  background: 'sunset',
29
  setBackground: () => {},
@@ -38,7 +38,7 @@ export const useControlContext = () => {
38
  export const ControlProvider = ({ children }: { children: React.ReactNode }) => {
39
  const [showRotationIndicators, setShowRotationIndicators] = useState(false);
40
  const [cubeRoughness, setCubeRoughness] = useState(0.5);
41
- const [cubeSpeed, setCubeSpeed] = useState(2);
42
  const [background, setBackground] = useState<PresetsType>('sunset');
43
  const [rubiksCubeRef, setRubiksCubeRef] = useState<RefObject<RubiksCubeRef | null> | undefined>(undefined);
44
 
 
23
  setShowRotationIndicators: () => {},
24
  cubeRoughness: 0.5,
25
  setCubeRoughness: () => {},
26
+ cubeSpeed: 8,
27
  setCubeSpeed: () => {},
28
  background: 'sunset',
29
  setBackground: () => {},
 
38
  export const ControlProvider = ({ children }: { children: React.ReactNode }) => {
39
  const [showRotationIndicators, setShowRotationIndicators] = useState(false);
40
  const [cubeRoughness, setCubeRoughness] = useState(0.5);
41
+ const [cubeSpeed, setCubeSpeed] = useState(8);
42
  const [background, setBackground] = useState<PresetsType>('sunset');
43
  const [rubiksCubeRef, setRubiksCubeRef] = useState<RefObject<RubiksCubeRef | null> | undefined>(undefined);
44