File size: 5,835 Bytes
010edb7
 
 
 
 
 
 
 
 
3eacaec
b7e2104
6b30d5d
2eb6e76
010edb7
6b30d5d
bdef5c4
 
 
 
 
 
 
 
 
c4e154c
bdef5c4
 
 
 
b7e2104
bdef5c4
b7e2104
bdef5c4
b7e2104
bdef5c4
 
 
c4e154c
b7e2104
 
c4e154c
b7e2104
 
 
 
bdef5c4
 
 
 
 
 
 
 
 
010edb7
05dd656
010edb7
 
 
 
 
 
 
 
2eb6e76
010edb7
 
 
 
 
 
 
 
 
 
 
 
 
3eacaec
010edb7
 
bbb7e28
010edb7
 
 
05dd656
010edb7
 
 
 
3eacaec
 
010edb7
 
 
 
 
 
bbb7e28
010edb7
bbb7e28
010edb7
 
 
bbb7e28
05dd656
010edb7
 
 
 
 
 
 
 
3eacaec
 
 
010edb7
3eacaec
010edb7
3eacaec
bbb7e28
 
 
 
3eacaec
 
bdef5c4
3eacaec
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
import vec2text
import torch
from umap import UMAP
import plotly.express as px
import numpy as np
from streamlit_plotly_events import plotly_events
import utils
import pandas as pd
from scipy.spatial import distance
from resources import get_gtr_embeddings
from transformers import PreTrainedModel, PreTrainedTokenizer
# Display name of the dimensionality-reduction method; interpolated into the
# scatter-plot axis labels in plot().
dimensionality_reduction_model_name = "PCA"

def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Render the "vector arithmetic" page: Sentence 1 - Sentence 2 + Sentence 3.

    The three input sentences are embedded with ``get_gtr_embeddings``. They are
    combined as ``v1 - v2 + v3``, and the resulting vector is decoded back to text
    with ``vec2text.invert_embeddings`` (20 correction steps). The decoded text is
    shown in a read-only "Sentence 4" field.

    Args:
        embeddings: Not used by this page.
        corrector: vec2text corrector passed to ``vec2text.invert_embeddings``.
        encoder: Model used to embed the sentences. Its ``device`` is where the
            embeddings and the combined vector are placed.
        tokenizer: Tokenizer passed to ``get_gtr_embeddings`` together with ``encoder``.
    """
    st.title('"A man is to king, what woman is to queen"')
    st.markdown("A well known phenomenon in semantic vectors is the way we can do vector operations like addition and subtraction to find spatial relations in the vector space.")
    st.markdown(
        'In word embedding models, we have found that the relationship between words can be captured mathematically, '
        'such that "king" is to "man" as "queen" is to "woman," demonstrating that vector arithmetic can encode analogies and semantic relationships in high-dimensional space ([Mikolov et al., 2013](https://arxiv.org/abs/1301.3781)).'
    )
    st.markdown("This application lets you freely explore to which extent that property applies to embedding inversion models given the other factors of inaccuracy")

    generated_sentence = ""
    device = encoder.device

    with st.form(key="foo"):
        submit_button = st.form_submit_button("Synthesize")

        sent1 = st.text_input("Sentence 1", value="I am a king")
        st.latex("-")
        sent2 = st.text_input("Sentence 2", value="I am a man")
        st.latex("+")
        sent3 = st.text_input("Sentence 3", value="I am a woman")
        st.latex("=")

        if submit_button:
            # Embed all three sentences in one call; the result is unpacked along
            # its first axis into one vector per sentence.
            v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer, device=device).to(device)
            v4 = v1 - v2 + v3
            # invert_embeddings works on a batch: add a leading batch axis and
            # unpack the single returned string.
            generated_sentence, = vec2text.invert_embeddings(
                embeddings=v4.unsqueeze(0).to(device),
                corrector=corrector,
                num_steps=20,
            )
            generated_sentence = generated_sentence.strip()

        # Read-only output field; stays empty until the form is submitted.
        st.text_input("Sentence 4", value=generated_sentence, disabled=True)

    # st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')

def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector, device):
    """Render the interactive 2-D projection page and invert a chosen point back to text.

    The left column does the following:

    - Shows a scatter plot of ``vectors_2d``, with hover text from ``df["title"]``.
    - Lets the user pick coordinates by clicking a point or typing them.
    - On submit, maps the 2-D point back to embedding space with
      ``reducer.inverse_transform`` and decodes it with
      ``vec2text.invert_embeddings``.

    The right column shows:

    - The title whose 2-D point matches the chosen coordinates
      (``utils.find_exact_match`` at 3 decimals).
    - The synthesized text.
    - When both are available, the Euclidean and cosine distances between the
      matched and the reconstructed embeddings.

    Args:
        df: Data frame with a ``title`` column. Assumed row-aligned with
            ``vectors_2d`` and ``embeddings`` (TODO confirm with caller).
        embeddings: Original high-dimensional embeddings, indexed like ``df``.
        vectors_2d: 2-D projection; columns 0 and 1 are plotted.
        reducer: Fitted dimensionality reducer exposing ``inverse_transform``.
        corrector: vec2text corrector passed to ``vec2text.invert_embeddings``.
        device: Torch device the reconstructed embedding is moved to.
    """
    fig = px.scatter(
        x=vectors_2d[:, 0],
        y=vectors_2d[:, 1],
        opacity=0.6,
        hover_data={"Title": df["title"]},
        labels={
            'x': f'{dimensionality_reduction_model_name} Component 1',
            'y': f'{dimensionality_reduction_model_name} Component 2',
        },
        # Name the configured reducer instead of hard-coding "UMAP", so the title
        # agrees with the axis labels.
        title=f"{dimensionality_reduction_model_name} Scatter Plot of Reddit Titles",
        color_discrete_sequence=["#ff504c"],  # single red-orange colour for all points
    )

    # Transparent backgrounds and no fixed template, so the chart blends with the
    # page's light/dark theme.
    fig.update_layout(
        template=None,
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )

    x, y = 0.0, 0.0
    vec = np.array([x, y]).astype("float32")
    inferred_embedding = None
    inversion_output_text = None

    # 60/40 split: scatter plot + form on the left, results card on the right.
    col1, col2 = st.columns([0.6, 0.4])

    with col1:
        selected_points = plotly_events(fig, click_event=True, hover_event=False)
        with st.form(key="form1_main"):
            if selected_points:
                clicked_point = selected_points[0]
                # Cast so the number inputs below always receive floats, matching
                # their "%.10f" format even if an integral value comes back.
                x = float(clicked_point['x'])
                y = float(clicked_point['y'])

            x = st.number_input("X Coordinate", value=x, format="%.10f")
            y = st.number_input("Y Coordinate", value=y, format="%.10f")
            vec = np.array([x, y]).astype("float32")

            submit_button = st.form_submit_button("Synthesize")

            if submit_button:
                # Map the single 2-D point (a batch of one) back into embedding space.
                inferred_embedding = reducer.inverse_transform(np.array([[x, y]])).astype("float32")

                inversion_output_text, = vec2text.invert_embeddings(
                    embeddings=torch.tensor(inferred_embedding).to(device),
                    corrector=corrector,
                    num_steps=20,
                )
            else:
                st.text("Click on a point in the scatterplot to see its coordinates.")

    with col2:
        closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
        # Negative index means "no point at these coordinates"; use one check everywhere.
        has_match = closest_sentence_index >= 0
        selected_sentence = df.title.iloc[closest_sentence_index] if has_match else None
        selected_sentence_embedding = embeddings[closest_sentence_index] if has_match else None

        st.markdown(
            f"### Selected text:\n```console\n{selected_sentence}\n```"
        )

        st.markdown(
            f"### Synthesized text:\n```console\n{inversion_output_text}\n```"
        )

        if inferred_embedding is not None and has_match:
            couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()
            st.markdown("### Inferred embedding distance:")
            # scipy may return numpy scalars (e.g. float32 for float32 inputs);
            # pass plain Python floats to the widgets.
            st.number_input("Euclidean", value=float(distance.euclidean(*couple)), disabled=True)
            st.number_input("Cosine", value=float(distance.cosine(*couple)), disabled=True)