File size: 3,504 Bytes
99c5f63
cf31e9b
5cf0970
 
 
99c5f63
cf31e9b
5cf0970
 
 
 
e5edf10
 
5cf0970
 
 
 
 
 
 
 
 
 
 
 
624cb8c
e5edf10
5cf0970
 
 
 
 
 
 
 
 
 
 
ee17a02
6deae30
36d45d5
 
be3d74a
 
 
 
 
fec2994
be3d74a
 
 
4bc4fd8
 
 
 
624cb8c
42c9093
 
cac961a
 
 
 
 
47fd2b6
 
cac961a
 
 
 
 
 
 
 
 
 
 
 
 
090b788
cac961a
e8a54d9
 
2e52e56
e8a54d9
 
 
80f54d7
e8a54d9
 
cac961a
e8a54d9
6d7de7f
 
 
 
 
 
 
e8a54d9
 
74fd993
e8a54d9
 
 
be3d74a
 
5125cdc
50a5784
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import streamlit as st
import numpy as np
import plotly.express as px, plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, GenerationConfig, AutoModelForCausalLM

def top_token_ids(outputs, threshold=-np.inf):
    """Return, for each generation step, the ids of tokens scoring above a threshold.

    Args:
        outputs: result of ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``. ``outputs['scores']`` is iterated with one
            score tensor per generated step; only the first row of each tensor
            (the first sequence in the batch) is inspected.
        threshold: a token is kept when its score is strictly greater than this.

    Returns:
        A list with one 1-D integer ``np.ndarray`` per step. Each array holds the
        kept token ids sorted by ascending score, so the highest-scoring id is
        last. An array is empty when no token at that step exceeds ``threshold``.
    """
    indexes = []
    for step_scores in outputs['scores']:
        # Threshold and sort the SAME row. Flattening the whole tensor would mix
        # batch rows and produce ids that are out of range for row 0.
        # Move to host float32 so NumPy can handle GPU / half-precision scores.
        row = step_scores[0].detach().cpu().float().numpy()
        candidates = np.flatnonzero(row > threshold)
        # Ascending order (highest score last), presumably so the best token is
        # drawn at the top of a horizontal bar chart. A stable sort keeps tied
        # scores in id order.
        order = np.argsort(row[candidates], kind='stable')
        indexes.append(candidates[order])
    return indexes

def plot_word_scores(top_token_ids, outputs, tokenizer, boolq=False, width=600):
    """Build a figure with one horizontal bar chart of candidate-token scores per step.

    Args:
        top_token_ids: per-step sequences of token ids to plot (e.g. the list
            returned by the ``top_token_ids`` helper). Subplot ``i`` shows step
            ``i``. Must be non-empty: ``make_subplots`` needs at least one row.
        outputs: ``generate`` output. ``outputs['scores'][step][0]`` is read as
            the scores of the first sequence at that step.
        tokenizer: tokenizer used to turn ids back into token strings.
        boolq: unused; kept so existing callers keep working.
        width: figure width in pixels.

    Returns:
        A plotly figure with ``len(top_token_ids)`` stacked subplots and a height
        of 300 px per subplot.
    """
    marker = '\u2581'  # SentencePiece word-boundary prefix ("▁")
    n_steps = len(top_token_ids)
    fig = make_subplots(rows=n_steps, cols=1)
    for step, candidates in enumerate(top_token_ids):
        # Only strip the word-boundary marker. Continuation pieces and special
        # tokens such as "</s>" have no prefix and must keep every character.
        # A bare marker is kept so it does not become an empty label.
        labels = [
            token[1:] if token.startswith(marker) and len(token) > 1 else token
            for token in tokenizer.convert_ids_to_tokens(candidates)
        ]
        # Hand plotly a host float32 array: GPU or half-precision tensors cannot
        # be converted implicitly.
        step_scores = outputs['scores'][step][0].detach().cpu().float().numpy()
        fig.append_trace(
            go.Bar(
                y=labels,
                x=step_scores[candidates],
                orientation='h',
            ),
            row=step + 1,
            col=1,
        )
    fig.update_layout(
        width=width,
        height=300 * n_steps,
        showlegend=False,
    )
    return fig

st.title('How do LLM choose their words?')

instruction = st.text_area(label='Write an instruction:', placeholder='Where is Venice located?')

col1, col2 = st.columns(2)

with col1:
    model_checkpoint = st.selectbox(
        "Model:",
        ("google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl")
    )

with col2:
    temperature = st.slider('Temperature:', min_value=0.0, max_value=1.0, value=0.5)
    top_p = st.slider('Top p:', min_value=0.5, max_value=1.0, value=0.99)
    max_tokens = st.slider('Max output length: ', min_value=1, max_value=64)


@st.cache_resource
def _load_tokenizer_and_model(checkpoint):
    """Load the tokenizer and the T5 model for ``checkpoint``.

    Streamlit re-executes this script on every widget interaction.
    ``st.cache_resource`` keeps one loaded copy per checkpoint instead of
    re-reading the weights from disk on each rerun.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    mdl = T5ForConditionalGeneration.from_pretrained(
        checkpoint,
        load_in_8bit=False,
        device_map="auto",
        offload_folder="offload",
    )
    return tok, mdl


tokenizer, model = _load_tokenizer_and_model(model_checkpoint)

prompts = [
    f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
    ### Instruction: {instruction}
    ### Response:"""
]

# The tokenizer returns CPU tensors, but with device_map="auto" the model may
# live on a GPU, so move the inputs to the model's device.
inputs = tokenizer(
    prompts[0],
    return_tensors="pt",
).to(model.device)
input_ids = inputs["input_ids"]

# transformers requires a strictly positive temperature when sampling. Clamp the
# slider's 0.0 end to its smallest non-zero step (0.01). That is practically
# greedy decoding, while the sampling path (with its top-k / top-p filtering)
# stays in use.
MIN_TEMPERATURE = 0.01

generation_config = GenerationConfig(
    do_sample=True,
    temperature=max(temperature, MIN_TEMPERATURE),
    top_p=top_p,  # value chosen with the "Top p" slider
    top_k=100,
    repetition_penalty=1.5,
    max_new_tokens=max_tokens,
)

if instruction.strip():
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # Special tokens are deliberately kept in the displayed text.
    output_text = tokenizer.decode(
        outputs['sequences'][0],
        skip_special_tokens=False,
    ).strip()

    st.write(output_text)

    fig = plot_word_scores(top_token_ids(outputs, threshold=-10.0), outputs, tokenizer)
    st.plotly_chart(fig, theme=None, use_container_width=False)