Fraser committed on
Commit
4c770f6
1 Parent(s): 9361f12
Files changed (2)
  1. app.py +94 -0
  2. train.py +1 -1
app.py ADDED
@@ -0,0 +1,94 @@
+ import streamlit as st
+ import jax.numpy as jnp
+ from transformers import AutoTokenizer
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
+ from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding
+
+
+ st.title('T5-VAE')
+ st.text('''
+ Try interpolating between lines of Python code using this T5-VAE.
+ ''')
+
+
+ @st.cache(allow_output_mutation=True)
+ def get_model():
+     tokenizer = AutoTokenizer.from_pretrained("t5-base")
+     model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
+     assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
+     return model, tokenizer
+
+
+ model, tokenizer = get_model()
+
+
+ def add_decoder_input_ids(examples):
+     arr_input_ids = jnp.array(examples["input_ids"])
+     pad = tokenizer.pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
+     arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
+     examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, tokenizer.pad_token_id, model.config.decoder_start_token_id)
+
+     arr_attention_mask = jnp.array(examples['attention_mask'])
+     ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
+     examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
+
+     for k in ['decoder_input_ids', 'decoder_attention_mask']:
+         examples[k] = examples[k].tolist()
+
+     return examples
+
+
+ def prepare_inputs(inputs):
+     for k, v in inputs.items():
+         inputs[k] = jnp.array(v)
+     return add_decoder_input_ids(inputs)
+
+
+ def get_latent(text):
+     return model(**prepare_inputs(tokenizer([text]))).latent_codes[0]
+
+
+ def tokens_from_latent(latent_codes):
+     model.config.is_encoder_decoder = True
+     output_ids = model.generate(
+         latent_codes=jnp.array([latent_codes]),
+         bos_token_id=model.config.decoder_start_token_id,
+         min_length=1,
+         max_length=32,
+     )
+     return output_ids
+
+
+ def slerp(ratio, t1, t2):
+     '''
+     Perform a spherical interpolation between 2 vectors.
+     Most of the volume of a high-dimensional orange is in the skin, not the pulp.
+     This also applies for multivariate Gaussian distributions.
+     To that end we can interpolate between samples by following the surface of an n-dimensional sphere rather than a straight line.
+
+     Args:
+         ratio: Interpolation ratio.
+         t1: Tensor1
+         t2: Tensor2
+     '''
+     low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
+     high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
+     omega = jnp.arccos((low_norm * high_norm).sum(1))
+     so = jnp.sin(omega)
+     res = (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2
+     return res
+
+
+ def decode(ratio, txt_1, txt_2):
+     if not txt_1 or not txt_2:
+         return ''
+     lt_1, lt_2 = get_latent(txt_1), get_latent(txt_2)
+     lt_new = slerp(ratio, lt_1, lt_2)
+     tkns = tokens_from_latent(lt_new)
+     return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
+
+
+ in_1 = st.text_input("A line of Python code.", "x = 1")
+ in_2 = st.text_input("Another line of Python code.", "x = 9")
+ r = st.slider('Interpolation Ratio')
+ st.write(decode(r, in_1, in_2))
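
For reference, the slerp helper in app.py is a standard spherical interpolation: it blends two latents along the arc of a hypersphere rather than the straight chord between them. A minimal, self-contained sketch of the same formula on toy arrays (only jax.numpy needed, no model or tokenizer; the toy shapes mirror the row-wise layout the helper expects) might look like this:

import jax.numpy as jnp

def toy_slerp(ratio, t1, t2):
    # Normalise each row, measure the angle between the two directions,
    # then blend along the arc instead of the straight line.
    low_norm = t1 / jnp.linalg.norm(t1, axis=1, keepdims=True)
    high_norm = t2 / jnp.linalg.norm(t2, axis=1, keepdims=True)
    omega = jnp.arccos((low_norm * high_norm).sum(1))
    so = jnp.sin(omega)
    return (jnp.sin((1.0 - ratio) * omega) / so)[0] * t1 + (jnp.sin(ratio * omega) / so)[0] * t2

t1 = jnp.array([[1.0, 0.0, 0.0]])  # toy latent, stands in for get_latent("x = 1")
t2 = jnp.array([[0.0, 1.0, 0.0]])  # toy latent, stands in for get_latent("x = 9")
print(toy_slerp(0.5, t1, t2))      # ~[[0.707, 0.707, 0.0]]: halfway along the arc

The app itself is launched with "streamlit run app.py" once the t5_vae_flax_alt package and the flax-community/t5-vae-python weights are available.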
train.py CHANGED
@@ -363,7 +363,7 @@ def main():
          model = FlaxT5VaeForAutoencoding.from_pretrained(
              model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
          )
-         assert model.params['t5']['shared'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
+         assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
      else:
          vocab_size = len(tokenizer)
          config.t5.vocab_size = vocab_size
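
The one-line change above repairs the vocabulary check: in the Flax parameter tree, 'shared' is presumably a nested dict (so it has no .shape), and the embedding matrix sits one level deeper under ['t5']['shared']['embedding']. A hypothetical stand-alone sketch of the same check, using made-up stand-ins for the real parameter tree and tokenizer size:

import jax.numpy as jnp

vocab_size = 32100  # stand-in for len(tokenizer)
params = {'t5': {'shared': {'embedding': jnp.zeros((vocab_size, 768))}}}  # made-up tree with the assumed layout

# The dict under 'shared' has no .shape; the (vocab_size, d_model) embedding matrix inside it does.
assert params['t5']['shared']['embedding'].shape[0] == vocab_size, \
    "T5 Tokenizer doesn't match T5Vae embedding size."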