Fraser committed on
Commit 0e403da
1 Parent(s): b97edac
Files changed (6)
  1. .gitignore +3 -0
  2. .gitmodules +3 -0
  3. Makefile +6 -0
  4. requirements.txt +8 -0
  5. streamlit_app.py +95 -0
  6. t5_vae_flax +1 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv
+ .vscode
+ *.pyc
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "t5_vae_flax"]
+ 	path = t5_vae_flax
+ 	url = https://github.com/Fraser-Greenlee/t5-vae-flax.git
Makefile ADDED
@@ -0,0 +1,6 @@
+
+ run:
+ 	streamlit run streamlit_app.py
+
+ test-unit:
+ 	streamlit hello
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ streamlit
+ watchdog
+
+ wheel
+ requests
+ flax
+
+ -e git+https://github.com/huggingface/transformers.git#egg=transformers[flax]
streamlit_app.py ADDED
@@ -0,0 +1,95 @@
+ import streamlit as st
+ import jax.numpy as jnp
+ from transformers import AutoTokenizer
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
+ from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
+
+
+ st.title('T5-VAE')
+ st.text('''
+ Try interpolating between lines of Python code using this T5-VAE.
+ ''')
+
+
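+ # Load the tokenizer and model once and cache them across Streamlit reruns.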
+ @st.cache(allow_output_mutation=True)
+ def get_model():
+     tokenizer = AutoTokenizer.from_pretrained("t5-base")
+     model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
+     assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
+     return model, tokenizer
+
+
+ model, tokenizer = get_model()
+
+
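+ # Append a pad token and right-shift the input ids to build the
+ # decoder inputs, mirroring T5's teacher-forcing setup.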
+ def add_decoder_input_ids(examples):
+     arr_input_ids = jnp.array(examples["input_ids"])
+     pad = tokenizer.pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
+     arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
+     examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, tokenizer.pad_token_id, model.config.decoder_start_token_id)
+
+     arr_attention_mask = jnp.array(examples['attention_mask'])
+     ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
+     examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
+
+     for k in ['decoder_input_ids', 'decoder_attention_mask']:
+         examples[k] = examples[k].tolist()
+
+     return examples
+
+
+ def prepare_inputs(inputs):
+     for k, v in inputs.items():
+         inputs[k] = jnp.array(v)
+     return add_decoder_input_ids(inputs)
+
+
+ def get_latent(text):
+     return model(**prepare_inputs(tokenizer([text]))).latent_codes[0]
+
+
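+ # Decode a latent code back into token ids via the model's generate() loop.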
+ def tokens_from_latent(latent_codes):
+     model.config.is_encoder_decoder = True
+     output_ids = model.generate(
+         latent_codes=jnp.array([latent_codes]),
+         bos_token_id=model.config.decoder_start_token_id,
+         min_length=1,
+         max_length=32,
+     )
+     return output_ids
+
+
+ def slerp(ratio, t1, t2):
+     '''
+     Perform a spherical interpolation between 2 vectors.
+     Most of the volume of a high-dimensional orange is in the skin, not the pulp.
+     This also applies to multivariate Gaussian distributions.
+     To that end we can interpolate between samples by following the surface of an n-dimensional sphere rather than a straight line.
+
+     Args:
+         ratio: Interpolation ratio.
+         t1: Tensor1
+         t2: Tensor2
+     '''
+     low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
+     high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
+     omega = jnp.arccos((low_norm * high_norm).sum(1))
+     so = jnp.sin(omega)
+     res = (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2
+     return res
+
+
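+ # Encode both snippets, slerp between their latents, then decode the result.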
+ def decode(ratio, txt_1, txt_2):
+     if not txt_1 or not txt_2:
+         return ''
+     lt_1, lt_2 = get_latent(txt_1), get_latent(txt_2)
+     lt_new = slerp(ratio, lt_1, lt_2)
+     tkns = tokens_from_latent(lt_new)
+     return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
+
+
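+ # Two text inputs plus a ratio slider drive the interpolation demo below.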
+ # TODO while loop here?
+ st.text_input("x = 3", key="in_1")
+ st.text_input("y += 'hello'", key="in_2")
+ r = st.slider('Interpolation Ratio')
+ st.write(decode(r, st.session_state.in_1, st.session_state.in_2))
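For reference, a minimal standalone sketch of the slerp used above, run on toy (1, 4) arrays so the endpoints are easy to check. The batched 2-D shape is an assumption matching the function's axis=1 norms; no model is required.

import jax.numpy as jnp

def slerp(ratio, t1, t2):
    # Normalise, measure the angle between the two vectors, then move
    # along the great circle that connects them.
    low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
    high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
    omega = jnp.arccos((low_norm * high_norm).sum(1))
    so = jnp.sin(omega)
    return (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2

t1 = jnp.array([[1.0, 0.0, 0.0, 0.0]])  # toy "latent" 1, shape (1, 4)
t2 = jnp.array([[0.0, 1.0, 0.0, 0.0]])  # toy "latent" 2
print(slerp(0.0, t1, t2))  # recovers t1
print(slerp(0.5, t1, t2))  # halfway along the arc, ~0.707 in both axes
print(slerp(1.0, t1, t2))  # recovers t2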
t5_vae_flax ADDED
@@ -0,0 +1 @@
+ Subproject commit 0c030dca4751e6def730968a2f33fe093a608cdb