marksverdhei commited on
Commit
b7e2104
·
1 Parent(s): bdef5c4

Make it work

Browse files
Files changed (1) hide show
  1. views.py +12 -5
views.py CHANGED
@@ -8,7 +8,7 @@ from streamlit_plotly_events import plotly_events
8
  import utils
9
  import pandas as pd
10
  from scipy.spatial import distance
11
-
12
  dimensionality_reduction_model_name = "PCA"
13
 
14
  def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
@@ -26,15 +26,22 @@ def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
26
  with st.form(key="foo") as form:
27
  submit_button = st.form_submit_button("Synthesize")
28
 
29
- sent1 = st.text_input("Sentence 1")
30
  st.latex("-")
31
- sent2 = st.text_input("Sentence 2")
32
  st.latex("+")
33
- sent3 = st.text_input("Sentence 3")
34
  st.latex("=")
35
 
36
  if submit_button:
37
- generated_sentence = "HI"
 
 
 
 
 
 
 
38
 
39
  sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
40
 
 
8
  import utils
9
  import pandas as pd
10
  from scipy.spatial import distance
11
+ from resources import get_gtr_embeddings
12
  dimensionality_reduction_model_name = "PCA"
13
 
14
  def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
 
26
  with st.form(key="foo") as form:
27
  submit_button = st.form_submit_button("Synthesize")
28
 
29
+ sent1 = st.text_input("Sentence 1", value="I am a king")
30
  st.latex("-")
31
+ sent2 = st.text_input("Sentence 2", value="I am a man")
32
  st.latex("+")
33
+ sent3 = st.text_input("Sentence 3", value="I am a woman")
34
  st.latex("=")
35
 
36
  if submit_button:
37
+ v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer).to("cpu")
38
+ v4 = v1 - v2 + v3
39
+ generated_sentence, = vec2text.invert_embeddings(
40
+ embeddings=v4.unsqueeze(0).cuda(),
41
+ corrector=corrector,
42
+ num_steps=20,
43
+ )
44
+ generated_sentence = generated_sentence.strip()
45
 
46
  sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
47