Балаганский Никита Николаевич commited on
Commit
a905c49
1 Parent(s): d949a12

add prompt

Browse files
Files changed (1) hide show
  1. generator.py +9 -1
generator.py CHANGED
@@ -72,6 +72,14 @@ class Generator:
72
 
73
  })
74
  inp_len = len(input_ids[0])
 
 
 
 
 
 
 
 
75
  for i in range(max_length):
76
  is_caif_step = (
77
  i % caif_period == 0 and self.caif_sampler is not None
@@ -98,7 +106,7 @@ class Generator:
98
  scatter_data = go.Scatter({
99
  "x": list(range(len(gen_history))),
100
  "y": gen_history,
101
- "hovertext": [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]]
102
  })
103
  fig = go.Figure([scatter_data], layout=layout)
104
  plot.plotly_chart(fig, use_container_width=True)
 
72
 
73
  })
74
  inp_len = len(input_ids[0])
75
+ if self.caif_sampler is not None:
76
+ current_decoded = self.tokenizer.decode(input_ids[0])
77
+ probs = torch.exp(
78
+ self.caif_sampler.get_classifier_log_probs(
79
+ current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
80
+ )
81
+ ).item()
82
+ gen_history += [probs]
83
  for i in range(max_length):
84
  is_caif_step = (
85
  i % caif_period == 0 and self.caif_sampler is not None
 
106
  scatter_data = go.Scatter({
107
  "x": list(range(len(gen_history))),
108
  "y": gen_history,
109
+ "hovertext": ["[PROMPT]"] + [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]]
110
  })
111
  fig = go.Figure([scatter_data], layout=layout)
112
  plot.plotly_chart(fig, use_container_width=True)