VoyagerYuan
commited on
Commit
·
10ea276
1
Parent(s):
e13d6a0
Update app.py
Browse files
app.py
CHANGED
@@ -234,12 +234,8 @@ def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
|
|
234 |
loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
|
235 |
interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
|
236 |
|
237 |
-
for round
|
238 |
-
|
239 |
-
break
|
240 |
-
states = [batch.to(device) for _ in range(NUM_SENDERS)]
|
241 |
-
# for round in range(num_rounds):
|
242 |
-
# states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
|
243 |
loss, recon_loss, kld_loss, interactions = game.play_round(states)
|
244 |
losses.append(loss)
|
245 |
recon_losses.append(recon_loss)
|
|
|
234 |
loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
|
235 |
interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
|
236 |
|
237 |
+
for round in range(num_rounds):
|
238 |
+
states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
|
|
|
|
|
|
|
|
|
239 |
loss, recon_loss, kld_loss, interactions = game.play_round(states)
|
240 |
losses.append(loss)
|
241 |
recon_losses.append(recon_loss)
|