Балаганский Никита Николаевич commited on
Commit
9852b1b
1 Parent(s): d0289f9
Files changed (2) hide show
  1. generator.py +18 -1
  2. sampling.py +2 -0
generator.py CHANGED
@@ -4,6 +4,8 @@ import torch
4
  import transformers
5
  import streamlit as st
6
 
 
 
7
 
8
  class Generator:
9
  def __init__(self, lm_model_name, device, entropy=None):
@@ -55,6 +57,8 @@ class Generator:
55
  num_samples,
56
  )
57
  text = st.empty()
 
 
58
  for i in range(max_length):
59
  is_caif_step = (
60
  i % caif_period == 0 and self.caif_sampler is not None
@@ -70,7 +74,20 @@ class Generator:
70
  progress_bar.progress((i+1)/max_length)
71
  if ended_sequences.all():
72
  break
73
- text.text(self.tokenizer.decode(input_ids[0]))
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  return (
76
  [
 
4
  import transformers
5
  import streamlit as st
6
 
7
+ from plotly import graph_objects as go
8
+
9
 
10
  class Generator:
11
  def __init__(self, lm_model_name, device, entropy=None):
 
57
  num_samples,
58
  )
59
  text = st.empty()
60
+ plot = st.empty()
61
+ gen_history = []
62
  for i in range(max_length):
63
  is_caif_step = (
64
  i % caif_period == 0 and self.caif_sampler is not None
 
74
  progress_bar.progress((i+1)/max_length)
75
  if ended_sequences.all():
76
  break
77
+ current_decoded = self.tokenizer.decode(input_ids[0])
78
+ if self.caif_sampler is not None:
79
+ probs = torch.exp(
80
+ self.caif_sampler.get_classifier_log_probs(
81
+ current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
82
+ )
83
+ ).item()
84
+ gen_history += [probs]
85
+ scatter_data = go.Scatter({
86
+ "x": list(range(len(gen_history))),
87
+ "y": gen_history
88
+ })
89
+ plot.plotly_chart(scatter_data, use_container_width=True)
90
+ text.text(current_decoded)
91
 
92
  return (
93
  [
sampling.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from torch.nn import functional as F
3
 
 
1
+ from typing import List
2
+
3
  import torch
4
  from torch.nn import functional as F
5