Балаганский Никита Николаевич
commited on
Commit
•
9852b1b
1
Parent(s):
d0289f9
fix
Browse files- generator.py +18 -1
- sampling.py +2 -0
generator.py
CHANGED
@@ -4,6 +4,8 @@ import torch
|
|
4 |
import transformers
|
5 |
import streamlit as st
|
6 |
|
|
|
|
|
7 |
|
8 |
class Generator:
|
9 |
def __init__(self, lm_model_name, device, entropy=None):
|
@@ -55,6 +57,8 @@ class Generator:
|
|
55 |
num_samples,
|
56 |
)
|
57 |
text = st.empty()
|
|
|
|
|
58 |
for i in range(max_length):
|
59 |
is_caif_step = (
|
60 |
i % caif_period == 0 and self.caif_sampler is not None
|
@@ -70,7 +74,20 @@ class Generator:
|
|
70 |
progress_bar.progress((i+1)/max_length)
|
71 |
if ended_sequences.all():
|
72 |
break
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
return (
|
76 |
[
|
|
|
4 |
import transformers
|
5 |
import streamlit as st
|
6 |
|
7 |
+
from plotly import graph_objects as go
|
8 |
+
|
9 |
|
10 |
class Generator:
|
11 |
def __init__(self, lm_model_name, device, entropy=None):
|
|
|
57 |
num_samples,
|
58 |
)
|
59 |
text = st.empty()
|
60 |
+
plot = st.empty()
|
61 |
+
gen_history = []
|
62 |
for i in range(max_length):
|
63 |
is_caif_step = (
|
64 |
i % caif_period == 0 and self.caif_sampler is not None
|
|
|
74 |
progress_bar.progress((i+1)/max_length)
|
75 |
if ended_sequences.all():
|
76 |
break
|
77 |
+
current_decoded = self.tokenizer.decode(input_ids[0])
|
78 |
+
if self.caif_sampler is not None:
|
79 |
+
probs = torch.exp(
|
80 |
+
self.caif_sampler.get_classifier_log_probs(
|
81 |
+
current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
|
82 |
+
)
|
83 |
+
).item()
|
84 |
+
gen_history += [probs]
|
85 |
+
scatter_data = go.Scatter({
|
86 |
+
"x": list(range(len(gen_history))),
|
87 |
+
"y": gen_history
|
88 |
+
})
|
89 |
+
plot.plotly_chart(scatter_data, use_container_width=True)
|
90 |
+
text.text(current_decoded)
|
91 |
|
92 |
return (
|
93 |
[
|
sampling.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
from torch.nn import functional as F
|
3 |
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
import torch
|
4 |
from torch.nn import functional as F
|
5 |
|