Commit
·
b7e2104
1
Parent(s):
bdef5c4
Make it work
Browse files
views.py
CHANGED
@@ -8,7 +8,7 @@ from streamlit_plotly_events import plotly_events
|
|
8 |
import utils
|
9 |
import pandas as pd
|
10 |
from scipy.spatial import distance
|
11 |
-
|
12 |
dimensionality_reduction_model_name = "PCA"
|
13 |
|
14 |
def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
@@ -26,15 +26,22 @@ def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
|
26 |
with st.form(key="foo") as form:
|
27 |
submit_button = st.form_submit_button("Synthesize")
|
28 |
|
29 |
-
sent1 = st.text_input("Sentence 1")
|
30 |
st.latex("-")
|
31 |
-
sent2 = st.text_input("Sentence 2")
|
32 |
st.latex("+")
|
33 |
-
sent3 = st.text_input("Sentence 3")
|
34 |
st.latex("=")
|
35 |
|
36 |
if submit_button:
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
|
40 |
|
|
|
8 |
import utils
|
9 |
import pandas as pd
|
10 |
from scipy.spatial import distance
|
11 |
+
from resources import get_gtr_embeddings
|
12 |
dimensionality_reduction_model_name = "PCA"
|
13 |
|
14 |
def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
|
|
26 |
with st.form(key="foo") as form:
|
27 |
submit_button = st.form_submit_button("Synthesize")
|
28 |
|
29 |
+
sent1 = st.text_input("Sentence 1", value="I am a king")
|
30 |
st.latex("-")
|
31 |
+
sent2 = st.text_input("Sentence 2", value="I am a man")
|
32 |
st.latex("+")
|
33 |
+
sent3 = st.text_input("Sentence 3", value="I am a woman")
|
34 |
st.latex("=")
|
35 |
|
36 |
if submit_button:
|
37 |
+
v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer).to("cpu")
|
38 |
+
v4 = v1 - v2 + v3
|
39 |
+
generated_sentence, = vec2text.invert_embeddings(
|
40 |
+
embeddings=v4.unsqueeze(0).cuda(),
|
41 |
+
corrector=corrector,
|
42 |
+
num_steps=20,
|
43 |
+
)
|
44 |
+
generated_sentence = generated_sentence.strip()
|
45 |
|
46 |
sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
|
47 |
|