imwithye commited on
Commit
67ab4fa
Β·
1 Parent(s): 1aa9298
Files changed (2) hide show
  1. README.md +12 -12
  2. rlcube/rlcube/train/train.py +1 -0
README.md CHANGED
@@ -14,21 +14,21 @@ Solve the Rubik's Cube using Reinforcement Learning! πŸš€
14
  ## πŸ‹οΈβ€β™‚οΈ Train the Model
15
 
16
  1. Navigate to the `rlcube` directory:
17
- ```
18
- cd rlcube
19
- ```
20
  2. Install dependencies:
21
- ```
22
- uv sync
23
- ```
24
  3. Activate the virtual environment:
25
- ```
26
- source .venv/bin/activate
27
- ```
28
  4. Start training:
29
- ```
30
- python -m rlcube.train.train
31
- ```
32
 
33
  After training, your model will be saved in the `models` folder.
34
  Please rename the trained file to `model_final.pth` so it can be used by the API. 🎯
 
14
  ## πŸ‹οΈβ€β™‚οΈ Train the Model
15
 
16
  1. Navigate to the `rlcube` directory:
17
+ ```
18
+ cd rlcube
19
+ ```
20
  2. Install dependencies:
21
+ ```
22
+ uv sync
23
+ ```
24
  3. Activate the virtual environment:
25
+ ```
26
+ source .venv/bin/activate
27
+ ```
28
  4. Start training:
29
+ ```
30
+ python -m rlcube.train.train
31
+ ```
32
 
33
  After training, your model will be saved in the `models` folder.
34
  Please rename the trained file to `model_final.pth` so it can be used by the API. 🎯
rlcube/rlcube/train/train.py CHANGED
@@ -59,6 +59,7 @@ def train(epochs: int = 100):
59
  target_values = target_values.detach()
60
 
61
  indices = indices.reshape(-1)
 
62
  weights = 1 / D.reshape(-1).detach()
63
 
64
  loss_v = value_loss_fn(values, target_values).reshape(-1) * weights
 
59
  target_values = target_values.detach()
60
 
61
  indices = indices.reshape(-1)
62
+ indices = indices * masks.reshape(-1)
63
  weights = 1 / D.reshape(-1).detach()
64
 
65
  loss_v = value_loss_fn(values, target_values).reshape(-1) * weights