AgentVerse commited on
Commit
6ba08e7
1 Parent(s): df98a45

fix: fix bug of image switch

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -62,12 +62,8 @@ class GUI:
62
  def get_avatar(self, idx):
63
  if idx == -1:
64
  img = cv2.imread("./imgs/db_diag/-1.png")
65
- elif self.task == "prisoner_dilemma":
66
  img = cv2.imread(f"./imgs/prison/{idx}.png")
67
- elif self.task == "db_diag":
68
- img = cv2.imread(f"./imgs/db_diag/{idx}.png")
69
- elif "sde" in self.task:
70
- img = cv2.imread(f"./imgs/sde/{idx}.png")
71
  else:
72
  img = cv2.imread(f"./imgs/{idx}.png")
73
  base64_str = cv2.imencode(".png", img)[1].tostring()
@@ -178,6 +174,13 @@ class GUI:
178
  self.backend = TaskSolving.from_task(task_dropdown, self.tasks_dir)
179
  else:
180
  self.backend = Simulation.from_task(task_dropdown, self.tasks_dir)
 
 
 
 
 
 
 
181
  self.backend.reset()
182
  self.turns_remain = self.backend.environment.max_turns
183
 
@@ -218,7 +221,7 @@ class GUI:
218
  # if len(data) != self.stu_num:
219
  if len(data) != self.stu_num + 1:
220
  raise gr.Error("data length is not equal to the total number of students.")
221
- if self.task == "prisoner_dilemma":
222
  img = cv2.imread("./imgs/speaking.png", cv2.IMREAD_UNCHANGED)
223
  if (
224
  len(self.messages) < 2
@@ -327,7 +330,9 @@ class GUI:
327
  """
328
 
329
  # data = self.backend.next_data()
 
330
  return_message = self.backend.next()
 
331
  data = self.return_format(return_message)
332
 
333
  # data.sort(key=lambda item: item["sender"])
@@ -460,7 +465,7 @@ class GUI:
460
  task_dropdown = gr.Dropdown(
461
  choices=[
462
  "simulation/nlp_classroom_9players",
463
- "simulation/prisoner_dilemma"
464
  ],
465
  value="simulation/nlp_classroom_9players",
466
  label="Task",
@@ -590,6 +595,5 @@ class GUI:
590
  demo.queue(concurrency_count=5, max_size=20).launch()
591
  # demo.launch()
592
 
593
-
594
 
595
  GUI().launch()
 
62
  def get_avatar(self, idx):
63
  if idx == -1:
64
  img = cv2.imread("./imgs/db_diag/-1.png")
65
+ elif self.task == "simulation/prisoner_dilemma":
66
  img = cv2.imread(f"./imgs/prison/{idx}.png")
 
 
 
 
67
  else:
68
  img = cv2.imread(f"./imgs/{idx}.png")
69
  base64_str = cv2.imencode(".png", img)[1].tostring()
 
174
  self.backend = TaskSolving.from_task(task_dropdown, self.tasks_dir)
175
  else:
176
  self.backend = Simulation.from_task(task_dropdown, self.tasks_dir)
177
+ self.agent_id = {
178
+ self.backend.agents[idx].name: idx
179
+ for idx in range(len(self.backend.agents))
180
+ }
181
+
182
+ self.task = task_dropdown
183
+ self.stu_num = len(self.agent_id) - 1
184
  self.backend.reset()
185
  self.turns_remain = self.backend.environment.max_turns
186
 
 
221
  # if len(data) != self.stu_num:
222
  if len(data) != self.stu_num + 1:
223
  raise gr.Error("data length is not equal to the total number of students.")
224
+ if self.task == "simulation/prisoner_dilemma":
225
  img = cv2.imread("./imgs/speaking.png", cv2.IMREAD_UNCHANGED)
226
  if (
227
  len(self.messages) < 2
 
330
  """
331
 
332
  # data = self.backend.next_data()
333
+
334
  return_message = self.backend.next()
335
+
336
  data = self.return_format(return_message)
337
 
338
  # data.sort(key=lambda item: item["sender"])
 
465
  task_dropdown = gr.Dropdown(
466
  choices=[
467
  "simulation/nlp_classroom_9players",
468
+ "simulation/prisoner_dilemma",
469
  ],
470
  value="simulation/nlp_classroom_9players",
471
  label="Task",
 
595
  demo.queue(concurrency_count=5, max_size=20).launch()
596
  # demo.launch()
597
 
 
598
 
599
  GUI().launch()