AlanFeder commited on
Commit
aa4694e
1 Parent(s): ea6fe12

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .gitignore +214 -0
  2. README.md +2 -8
  3. emb_sim.py +158 -0
  4. requirements.txt +66 -0
.gitignore ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ .venv2
126
+ .venv3
127
+ .venv*
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ **/.DS_Store
166
+
167
+ old_stuff/
168
+
169
+ **/data/**/*.*
170
+ !.gitkeep
171
+
172
+ # VisualStudioCode.gitignore
173
+ .vscode/*
174
+ !.vscode/settings.json
175
+ !.vscode/tasks.json
176
+ !.vscode/launch.json
177
+ !.vscode/extensions.json
178
+ !.vscode/*.code-snippets
179
+
180
+ # Local History for Visual Studio Code
181
+ .history/
182
+
183
+ # Built Visual Studio Code Extensions
184
+ *.vsix
185
+
186
+ # macOS.gitignore
187
+ # General
188
+ .DS_Store
189
+ .AppleDouble
190
+ .LSOverride
191
+
192
+ # Icon must end with two \r
193
+ Icon
194
+
195
+ # Thumbnails
196
+ ._*
197
+
198
+ # Files that might appear in the root of a volume
199
+ .DocumentRevisions-V100
200
+ .fseventsd
201
+ .Spotlight-V100
202
+ .TemporaryItems
203
+ .Trashes
204
+ .VolumeIcon.icns
205
+ .com.apple.timemachine.donotpresent
206
+
207
+ # Directories potentially created on remote AFP share
208
+ .AppleDB
209
+ .AppleDesktop
210
+ Network Trash Folder
211
+ Temporary Items
212
+ .apdisk
213
+
214
+ logs/
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Embeddings UBalt
3
- emoji: 😻
4
- colorFrom: blue
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 4.42.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Embeddings-UBalt
3
+ app_file: emb_sim.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.42.0
 
 
6
  ---
 
 
emb_sim.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.decomposition import PCA
5
+ from matplotlib.colors import LinearSegmentedColormap
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from openai import OpenAI, AuthenticationError, RateLimitError
8
+ from dotenv import load_dotenv
9
+ import os
10
+
11
+
12
+ load_dotenv()
13
+ openai_api_key = os.getenv("OPENAI_API_KEY")
14
+ oai_client = OpenAI(api_key=openai_api_key)
15
+
16
+
17
+ def get_openai_embedding(word):
18
+ response = oai_client.embeddings.create(input=word, model="text-embedding-3-small")
19
+ return response.data[0].embedding
20
+
21
+
22
+ def calculate_embeddings(words):
23
+ # Get word embeddings
24
+ embeddings = [get_openai_embedding(word) for word in words]
25
+ return embeddings
26
+
27
+
28
+ def process_array(arr):
29
+ # Ensure the input is a square array
30
+ if arr.shape[0] != arr.shape[1]:
31
+ raise ValueError("Input must be a square array")
32
+
33
+ n = arr.shape[0]
34
+
35
+ # Step 1: Keep only the upper triangle (excluding diagonal)
36
+ upper_triangle = np.triu(arr, k=1)
37
+
38
+ # Step 2: Reverse horizontally
39
+ reversed_upper_triangle = np.fliplr(upper_triangle)
40
+
41
+ # Step 3: Drop the final row and column
42
+ result = reversed_upper_triangle[:-1, :-1]
43
+
44
+ # Step 4: Mask the zeros
45
+ masked_result = np.ma.masked_where(result == 0, result)
46
+
47
+ return masked_result
48
+
49
+
50
+ def plot_heatmap(masked_result, l1: list[str]):
51
+ n, _ = masked_result.shape
52
+
53
+ # Create the heatmap
54
+ fig, ax = plt.subplots(
55
+ figsize=(12, 10)
56
+ ) # Increased figure size for better visibility
57
+
58
+ # Create a custom colormap
59
+ colors = ["darkred", "lightgray", "dodgerblue"]
60
+ n_bins = 100
61
+ cmap = LinearSegmentedColormap.from_list("custom", colors, N=n_bins)
62
+ cmap.set_bad("white") # Set color for masked values (zeros) to white
63
+
64
+ # Plot the heatmap
65
+ im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)
66
+
67
+ # Add text annotations
68
+ for i in range(n):
69
+ for j in range(n):
70
+ if not np.ma.is_masked(masked_result[i, j]):
71
+ text = ax.text(
72
+ j,
73
+ i,
74
+ f"{masked_result[i, j]:.2f}",
75
+ ha="center",
76
+ va="center",
77
+ color="black",
78
+ )
79
+
80
+ # Set y and x axis labels
81
+ ax.set_yticks(range(n))
82
+ ax.set_yticklabels(l1[:-1])
83
+ ax.set_xticks(range(n))
84
+ ax.set_xticklabels(reversed(l1[1:]))
85
+
86
+ # Move x-axis to the top
87
+ ax.xaxis.tick_top()
88
+ ax.xaxis.set_label_position("top")
89
+
90
+ # Rotate x-axis labels for better readability
91
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
92
+
93
+ # Add colorbar
94
+ cbar = plt.colorbar(im)
95
+ cbar.set_ticks([-1, 0, 1])
96
+ cbar.set_ticklabels(["-1", "0", "1"])
97
+
98
+ # Add title
99
+ plt.title("Correlation Heatmap", pad=20)
100
+
101
+ # Adjust layout and display the plot
102
+ plt.tight_layout()
103
+ return fig
104
+
105
+
106
+ def plot_pca(embeddings, words):
107
+ fig, ax = plt.subplots(figsize=(12, 10))
108
+ pca = PCA(n_components=2)
109
+ embeddings_2d = pca.fit_transform(embeddings)
110
+
111
+ fig, ax = plt.subplots(figsize=(10, 8))
112
+ ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
113
+
114
+ for i, word in enumerate(words):
115
+ ax.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))
116
+
117
+ ax.set_title("PCA of Word Embeddings")
118
+ ax.set_xlabel("First Principal Component")
119
+ ax.set_ylabel("Second Principal Component")
120
+ plt.tight_layout()
121
+ return fig
122
+
123
+
124
+ def word_similarity_heatmap(input_text):
125
+ words = [word.strip() for word in input_text.split(",")]
126
+
127
+ if len(words) < 2:
128
+ return "Please enter at least two words."
129
+
130
+ try:
131
+ embeddings = calculate_embeddings(words)
132
+ similarities = cosine_similarity(embeddings)
133
+ new_array = process_array(similarities)
134
+ heatmap = plot_heatmap(new_array, words)
135
+ pca_plot = plot_pca(embeddings, words)
136
+ return heatmap, pca_plot
137
+ # return heatmap
138
+ except AuthenticationError as e:
139
+ print("OpenAI API key is invalid. Please check your API key.")
140
+ raise e
141
+ except RateLimitError as e:
142
+ print("OpenAI API rate limit exceeded. Please try again later.")
143
+ raise e
144
+ except Exception as e:
145
+ print(f"An error occurred: {str(e)}")
146
+ raise e
147
+
148
+
149
+ iface = gr.Interface(
150
+ fn=word_similarity_heatmap, # _and_pca,
151
+ inputs=gr.Textbox(lines=2, placeholder="Enter words separated by commas"),
152
+ outputs=[gr.Plot(label="Similarity Heatmap"), gr.Plot(label="PCA Plot")],
153
+ title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
154
+ description="Enter a list of words separated by commas. The app will calculate the cosine similarity between their OpenAI embeddings, display a compact heatmap of the upper triangle similarities, and show a PCA plot of the embeddings.",
155
+ )
156
+
157
+ # Launch the app
158
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.4.0
4
+ certifi==2024.7.4
5
+ charset-normalizer==3.3.2
6
+ click==8.1.7
7
+ contourpy==1.2.1
8
+ cycler==0.12.1
9
+ distro==1.9.0
10
+ fastapi==0.112.2
11
+ ffmpy==0.4.0
12
+ filelock==3.15.4
13
+ fonttools==4.53.1
14
+ fsspec==2024.6.1
15
+ gradio==4.42.0
16
+ gradio_client==1.3.0
17
+ h11==0.14.0
18
+ httpcore==1.0.5
19
+ httpx==0.27.0
20
+ huggingface-hub==0.24.6
21
+ idna==3.8
22
+ importlib_resources==6.4.4
23
+ Jinja2==3.1.4
24
+ jiter==0.5.0
25
+ joblib==1.4.2
26
+ kiwisolver==1.4.5
27
+ markdown-it-py==3.0.0
28
+ MarkupSafe==2.1.5
29
+ matplotlib==3.9.2
30
+ mdurl==0.1.2
31
+ numpy==2.1.0
32
+ openai==1.42.0
33
+ orjson==3.10.7
34
+ packaging==24.1
35
+ pandas==2.2.2
36
+ pillow==10.4.0
37
+ pydantic==2.8.2
38
+ pydantic_core==2.20.1
39
+ pydub==0.25.1
40
+ Pygments==2.18.0
41
+ pyparsing==3.1.4
42
+ python-dateutil==2.9.0.post0
43
+ python-dotenv==1.0.1
44
+ python-multipart==0.0.9
45
+ pytz==2024.1
46
+ PyYAML==6.0.2
47
+ requests==2.32.3
48
+ rich==13.8.0
49
+ ruff==0.6.2
50
+ scikit-learn==1.5.1
51
+ scipy==1.14.1
52
+ seaborn==0.13.2
53
+ semantic-version==2.10.0
54
+ shellingham==1.5.4
55
+ six==1.16.0
56
+ sniffio==1.3.1
57
+ starlette==0.38.2
58
+ threadpoolctl==3.5.0
59
+ tomlkit==0.12.0
60
+ tqdm==4.66.5
61
+ typer==0.12.5
62
+ typing_extensions==4.12.2
63
+ tzdata==2024.1
64
+ urllib3==2.2.2
65
+ uvicorn==0.30.6
66
+ websockets==12.0