Spaces:
Running
Running
Lev McKinney
commited on
Commit
β’
c35da92
1
Parent(s):
7a724e0
upgraded app to use tuned_lens=0.1.0
Browse files- README.md +1 -0
- app.py +21 -15
- requirements.txt +1 -1
README.md
CHANGED
@@ -3,6 +3,7 @@ title: Tuned Lens
|
|
3 |
emoji: π
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
|
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
license: mit
|
|
|
3 |
emoji: π
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
+
port: 7860
|
7 |
sdk: docker
|
8 |
pinned: false
|
9 |
license: mit
|
app.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
import torch
|
2 |
from tuned_lens.nn.lenses import TunedLens, LogitLens
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
-
from tuned_lens.plotting import
|
5 |
import gradio as gr
|
6 |
from plotly import graph_objects as go
|
7 |
|
8 |
device = torch.device("cpu")
|
9 |
print(f"Using device {device} for inference")
|
10 |
-
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped
|
11 |
model = model.to(device)
|
12 |
-
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped
|
13 |
-
tuned_lens = TunedLens.
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
lens_options_dict = {
|
17 |
"Tuned Lens": tuned_lens,
|
@@ -20,13 +23,15 @@ lens_options_dict = {
|
|
20 |
|
21 |
statistic_options_dict = {
|
22 |
"Entropy": "entropy",
|
23 |
-
"Cross Entropy": "
|
24 |
"Forward KL": "forward_kl",
|
25 |
}
|
26 |
|
27 |
|
28 |
def make_plot(lens, text, statistic, token_cutoff):
|
29 |
input_ids = tokenizer.encode(text, return_tensors="pt")
|
|
|
|
|
30 |
|
31 |
if len(input_ids[0]) == 0:
|
32 |
return go.Figure(layout=dict(title="Please enter some text."))
|
@@ -34,18 +39,19 @@ def make_plot(lens, text, statistic, token_cutoff):
|
|
34 |
if token_cutoff < 1:
|
35 |
return go.Figure(layout=dict(title="Please provide valid token cut off."))
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
layer_stride=2,
|
42 |
input_ids=input_ids,
|
43 |
-
|
44 |
-
|
|
|
45 |
)
|
46 |
|
47 |
-
return
|
48 |
-
|
|
|
49 |
|
50 |
preamble = """
|
51 |
# The Tuned Lens π
|
|
|
1 |
import torch
|
2 |
from tuned_lens.nn.lenses import TunedLens, LogitLens
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
from tuned_lens.plotting import PredictionTrajectory
|
5 |
import gradio as gr
|
6 |
from plotly import graph_objects as go
|
7 |
|
8 |
device = torch.device("cpu")
|
9 |
print(f"Using device {device} for inference")
|
10 |
+
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
|
11 |
model = model.to(device)
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
13 |
+
tuned_lens = TunedLens.from_model_and_pretrained(
|
14 |
+
model=model,
|
15 |
+
map_location=device,
|
16 |
+
)
|
17 |
+
logit_lens = LogitLens.from_model(model)
|
18 |
|
19 |
lens_options_dict = {
|
20 |
"Tuned Lens": tuned_lens,
|
|
|
23 |
|
24 |
statistic_options_dict = {
|
25 |
"Entropy": "entropy",
|
26 |
+
"Cross Entropy": "cross_entropy",
|
27 |
"Forward KL": "forward_kl",
|
28 |
}
|
29 |
|
30 |
|
31 |
def make_plot(lens, text, statistic, token_cutoff):
|
32 |
input_ids = tokenizer.encode(text, return_tensors="pt")
|
33 |
+
input_ids = [tokenizer.bos_token_id] + input_ids
|
34 |
+
targets = input_ids[1:] + [tokenizer.eos_token_id]
|
35 |
|
36 |
if len(input_ids[0]) == 0:
|
37 |
return go.Figure(layout=dict(title="Please enter some text."))
|
|
|
39 |
if token_cutoff < 1:
|
40 |
return go.Figure(layout=dict(title="Please provide valid token cut off."))
|
41 |
|
42 |
+
start_pos=max(len(input_ids[0]) - token_cutoff, 0),
|
43 |
+
pred_traj = PredictionTrajectory.from_lens_and_model(
|
44 |
+
lens=lens,
|
45 |
+
model=model,
|
|
|
46 |
input_ids=input_ids,
|
47 |
+
tokenizer=tokenizer,
|
48 |
+
targets=targets,
|
49 |
+
start_pos=start_pos,
|
50 |
)
|
51 |
|
52 |
+
return getattr(pred_traj, statistic)().figure(
|
53 |
+
title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
|
54 |
+
)
|
55 |
|
56 |
preamble = """
|
57 |
# The Tuned Lens π
|
requirements.txt
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
tuned_lens==0.0
|
2 |
gradio
|
|
|
1 |
+
tuned_lens==0.1.0
|
2 |
gradio
|