import streamlit as st
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding


st.set_page_config(
    page_title="T5-VAE",
    page_icon="😐",
    layout="wide",
    initial_sidebar_state="expanded"
)


st.title('T5-VAE 🙁😐🙂')

st.markdown('''
This is a variational autoencoder trained on text.

It allows interpolating on text at a high level. Try it out!

See how it works [here](http://fras.uk/ml/large%20prior-free%20models/transformer-vae/2020/08/13/Transformers-as-Variational-Autoencoders.html).
''')

st.markdown('''
### [t5-vae-python](https://huggingface.co/flax-community/t5-vae-python)

This model is trained on lines of Python code from GitHub ([dataset](https://huggingface.co/datasets/Fraser/python-lines)).
''')


@st.cache(allow_output_mutation=True)
def get_model():
    # Load the shared T5 tokenizer and the Python-lines T5-VAE checkpoint once;
    # Streamlit caches them so script reruns reuse the same objects.
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
    assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
    return model, tokenizer


model, tokenizer = get_model()


def add_decoder_input_ids(examples):
    # Build decoder inputs by right-shifting the pad-extended encoder input ids,
    # then extend the attention mask to match.
    arr_input_ids = jnp.array(examples["input_ids"])
    pad = tokenizer.pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
    arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
    examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, tokenizer.pad_token_id, model.config.decoder_start_token_id)

    arr_attention_mask = jnp.array(examples['attention_mask'])
    ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
    examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)

    for k in ['decoder_input_ids', 'decoder_attention_mask']:
        examples[k] = examples[k].tolist()

    return examples
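
# Illustrative sketch of the transformation above (assuming the default T5 config,
# where pad_token_id = 0 and decoder_start_token_id = 0): input_ids [[37, 5, 1]]
# with attention_mask [[1, 1, 1]] become decoder_input_ids [[0, 37, 5, 1]] and
# decoder_attention_mask [[1, 1, 1, 1]].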


def prepare_inputs(inputs):
    for k, v in inputs.items():
        inputs[k] = jnp.array(v)
    return add_decoder_input_ids(inputs)


def get_latent(text):
    # Encode a single string and return its latent code.
    return model(**prepare_inputs(tokenizer([text]))).latent_codes[0]


def tokens_from_latent(latent_codes):
    # Flag the config as encoder-decoder before calling generate() with latent codes.
    model.config.is_encoder_decoder = True
    output_ids = model.generate(
        latent_codes=jnp.array([latent_codes]),
        bos_token_id=model.config.decoder_start_token_id,
        min_length=1,
        max_length=32,
    )
    return output_ids


def slerp(ratio, t1, t2):
    '''
    Perform a spherical interpolation between 2 vectors.

    Most of the volume of a high-dimensional orange is in the skin, not the pulp.
    The same holds for multivariate Gaussian distributions, so we interpolate
    between samples by following the surface of an n-dimensional sphere rather
    than a straight line.

    Args:
        ratio: Interpolation ratio.
        t1: First tensor.
        t2: Second tensor.
    '''
    low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
    high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
    omega = jnp.arccos((low_norm * high_norm).sum(1))
    so = jnp.sin(omega)
    res = (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2
    return res
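
# Quick sanity check with hypothetical 2-D latents: slerp(0.5, jnp.array([[1., 0.]]),
# jnp.array([[0., 1.]])) follows the unit circle and returns roughly [[0.707, 0.707]],
# where plain linear interpolation would give [[0.5, 0.5]].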


def decode(cnt, ratio, txt_1, txt_2):
    if not txt_1 or not txt_2:
        return ''
    cnt.write('Getting latents...')
    lt_1, lt_2 = get_latent(txt_1), get_latent(txt_2)
    lt_new = slerp(ratio, lt_1, lt_2)
    cnt.write('Decoding latent...')
    tkns = tokens_from_latent(lt_new)
    return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)


in_1 = st.text_input("A line of Python code.", "x = a - 1")
in_2 = st.text_input("Another line of Python code.", "x = a + 10 * 2")
r = st.slider('Python Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5)
container = st.empty()
container.write('Loading...')
out = decode(container, r, in_1, in_2)
container.empty()
st.write('Output: ' + out)


st.markdown('''
### [t5-vae-wiki](https://huggingface.co/flax-community/t5-vae-wiki)

This model is trained on just 5% of the sentences on Wikipedia.

We'll release another model trained on the full [dataset](https://github.com/ChunyuanLI/Optimus/blob/master/download_datasets.md) soon.
''')


@st.cache(allow_output_mutation=True)
def get_wiki_model():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-wiki")
    assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
    return model, tokenizer


model, tokenizer = get_wiki_model()


in_1 = st.text_input("A sentence.", "Children are looking for the water to be clear.")
in_2 = st.text_input("Another sentence.", "There are two people playing soccer.")
r = st.slider('English Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5)
container = st.empty()
container.write('Loading...')
out = decode(container, r, in_1, in_2)
container.empty()
st.write('Output: ' + out)


st.markdown('''
Try arithmetic in latent space.

Latent codes are found for each sentence and arithmetic is done on them.

Here it computes the sum `C + (B - A) = ?`
''')


def arithmetic(cnt, txt_a, txt_b, txt_c):
    if not txt_a or not txt_b or not txt_c:
        return ''
    cnt.write('Getting latents...')
    lt_a, lt_b, lt_c = get_latent(txt_a), get_latent(txt_b), get_latent(txt_c)
    lt_d = lt_c + (lt_b - lt_a)
    cnt.write('Decoding C + (B - A)...')
    tkns = tokens_from_latent(lt_d)
    return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)


in_a = st.text_input("A", "A girl makes a silly face.")
in_b = st.text_input("B", "Two girls are playing soccer.")
in_c = st.text_input("C", "A girl is looking through a microscope.")

st.markdown('''
A is to B as C is to...
''')
container = st.empty()
container.write('Loading...')
out = arithmetic(container, in_a, in_b, in_c)
container.empty()
st.write('Output: ' + out)