codeShare commited on
Commit
8833bbc
·
verified ·
1 Parent(s): 9fb003c

Upload sd_token_similarity_calculator.ipynb

Browse files
Files changed (1) hide show
  1. sd_token_similarity_calculator.ipynb +286 -14
sd_token_similarity_calculator.ipynb CHANGED
@@ -17,7 +17,7 @@
17
  {
18
  "cell_type": "markdown",
19
  "source": [
20
- "This Notebook is a Stable-diffusion tool which allows you to find similiar tokens from the SD 1.5 vocab.json that you can use for text-to-image generation"
21
  ],
22
  "metadata": {
23
  "id": "L7JTcbOdBPfh"
@@ -101,13 +101,15 @@
101
  {
102
  "cell_type": "code",
103
  "source": [
104
- "\n",
105
  "from transformers import AutoTokenizer\n",
106
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
107
  "prompt= \"banana\" # @param {type:'string'}\n",
108
  "tokenizer_output = tokenizer(text = prompt)\n",
109
  "input_ids = tokenizer_output['input_ids']\n",
110
- "print(input_ids)"
 
 
 
111
  ],
112
  "metadata": {
113
  "id": "RPdkYzT2_X85"
@@ -115,16 +117,62 @@
115
  "execution_count": null,
116
  "outputs": []
117
  },
 
 
 
 
 
 
 
 
 
118
  {
119
  "cell_type": "code",
120
  "source": [
121
- "#Produce a list id IDs that are most similiar to the prompt ID at positiion 1\n",
 
 
122
  "\n",
123
- "id_A = input_ids[1]\n",
124
- "A = token[id_A]\n",
125
- "_A = LA.vector_norm(A, ord=2)\n",
126
- "dots = torch.zeros(NUM_TOKENS)\n",
127
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  "for index in range(NUM_TOKENS):\n",
129
  " id_B = index\n",
130
  " B = token[id_B]\n",
@@ -135,8 +183,12 @@
135
  "\n",
136
  "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
137
  "#----#\n",
138
- "print(f'Calculated all cosine-similarities between the token {vocab[id_A]} with ID = {id_A} the rest of the {NUM_TOKENS} tokens as a 1x{sorted.shape[0]} tensor')\n",
139
- "print(f'Calculated indices as a 1x{indices.shape[0]} tensor')"
 
 
 
 
140
  ],
141
  "metadata": {
142
  "id": "juxsvco9B0iV"
@@ -144,6 +196,15 @@
144
  "execution_count": null,
145
  "outputs": []
146
  },
 
 
 
 
 
 
 
 
 
147
  {
148
  "cell_type": "code",
149
  "source": [
@@ -152,7 +213,7 @@
152
  "print_ID = False # @param {type:\"boolean\"}\n",
153
  "print_Similarity = True # @param {type:\"boolean\"}\n",
154
  "print_Name = True # @param {type:\"boolean\"}\n",
155
- "print_Divider = False # @param {type:\"boolean\"}\n",
156
  "\n",
157
  "for index in range(list_size):\n",
158
  " id = indices[index].item()\n",
@@ -166,10 +227,221 @@
166
  " print('--------')"
167
  ],
168
  "metadata": {
169
- "id": "YIEmLAzbHeuo"
 
 
 
 
170
  },
171
- "execution_count": null,
172
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  },
174
  {
175
  "cell_type": "markdown",
 
17
  {
18
  "cell_type": "markdown",
19
  "source": [
20
+ "This Notebook is a Stable-diffusion tool which allows you to find similiar tokens from the SD 1.5 vocab.json that you can use for text-to-image generation."
21
  ],
22
  "metadata": {
23
  "id": "L7JTcbOdBPfh"
 
101
  {
102
  "cell_type": "code",
103
  "source": [
 
104
  "from transformers import AutoTokenizer\n",
105
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
106
  "prompt= \"banana\" # @param {type:'string'}\n",
107
  "tokenizer_output = tokenizer(text = prompt)\n",
108
  "input_ids = tokenizer_output['input_ids']\n",
109
+ "print(input_ids)\n",
110
+ "id_A = input_ids[1]\n",
111
+ "A = token[id_A]\n",
112
+ "_A = LA.vector_norm(A, ord=2)"
113
  ],
114
  "metadata": {
115
  "id": "RPdkYzT2_X85"
 
117
  "execution_count": null,
118
  "outputs": []
119
  },
120
+ {
121
+ "cell_type": "markdown",
122
+ "source": [
123
+ "OPTIONAL : Add/subtract + normalize above result with another token"
124
+ ],
125
+ "metadata": {
126
+ "id": "JKnz0aLFVGXc"
127
+ }
128
+ },
129
  {
130
  "cell_type": "code",
131
  "source": [
132
+ "mix_with = \"\" # @param {type:'string'}\n",
133
+ "mix_method = 'None' # @param [\"None\" , \"Average\", \"Subtract\"] {allow-input: true}\n",
134
+ "w = 0.5 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
135
  "\n",
 
 
 
 
136
  "\n",
137
+ "\n",
138
+ "tokenizer_output = tokenizer(text = mix_with)\n",
139
+ "input_ids = tokenizer_output['input_ids']\n",
140
+ "id_C = input_ids[1]\n",
141
+ "C = token[id_C]\n",
142
+ "_C = LA.vector_norm(C, ord=2)\n",
143
+ "\n",
144
+ "if (mix_method == \"Average\"):\n",
145
+ " A = w*A + (1-w)*C\n",
146
+ " _A = LA.vector_norm(A, ord=2)\n",
147
+ "\n",
148
+ "if (mix_method == \"Subtract\"):\n",
149
+ " tmp = w*A - (1-w)*C\n",
150
+ " _tmp = LA.vector_norm(tmp, ord=2)\n",
151
+ " A = tmp*((w*_A + (1-w)*_C)/_tmp)\n",
152
+ " _A = LA.vector_norm(A, ord=2)\n",
153
+ "\n",
154
+ "\n"
155
+ ],
156
+ "metadata": {
157
+ "id": "oXbNSRSKPgRr"
158
+ },
159
+ "execution_count": 6,
160
+ "outputs": []
161
+ },
162
+ {
163
+ "cell_type": "markdown",
164
+ "source": [
165
+ "Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result"
166
+ ],
167
+ "metadata": {
168
+ "id": "3uBSZ1vWVCew"
169
+ }
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "source": [
174
+ "\n",
175
+ "dots = torch.zeros(NUM_TOKENS)\n",
176
  "for index in range(NUM_TOKENS):\n",
177
  " id_B = index\n",
178
  " B = token[id_B]\n",
 
183
  "\n",
184
  "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
185
  "#----#\n",
186
+ "if (mix_method == \"Average\"):\n",
187
+ " print(f'Calculated all cosine-similarities between the average of token {vocab[id_A]} and {vocab[id_C]} with ID = {id_A} and mixed ID = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
188
+ "if (mix_method == \"Subtract\"):\n",
189
+ " print(f'Calculated all cosine-similarities between the subtract of token {vocab[id_A]} and {vocab[id_C]} with ID = {id_A} and mixed ID = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
190
+ "if (mix_method == \"None\"):\n",
191
+ " print(f'Calculated all cosine-similarities between the token {vocab[id_A]} with ID = {id_A} the rest of the {NUM_TOKENS} tokens as a 1x{sorted.shape[0]} tensor')"
192
  ],
193
  "metadata": {
194
  "id": "juxsvco9B0iV"
 
196
  "execution_count": null,
197
  "outputs": []
198
  },
199
+ {
200
+ "cell_type": "markdown",
201
+ "source": [
202
+ "Print the sorted list from above result"
203
+ ],
204
+ "metadata": {
205
+ "id": "y-Ig3glrVQC3"
206
+ }
207
+ },
208
  {
209
  "cell_type": "code",
210
  "source": [
 
213
  "print_ID = False # @param {type:\"boolean\"}\n",
214
  "print_Similarity = True # @param {type:\"boolean\"}\n",
215
  "print_Name = True # @param {type:\"boolean\"}\n",
216
+ "print_Divider = True # @param {type:\"boolean\"}\n",
217
  "\n",
218
  "for index in range(list_size):\n",
219
  " id = indices[index].item()\n",
 
227
  " print('--------')"
228
  ],
229
  "metadata": {
230
+ "id": "YIEmLAzbHeuo",
231
+ "outputId": "843fbd7c-b208-49e0-9793-69bb36622c27",
232
+ "colab": {
233
+ "base_uri": "https://localhost:8080/"
234
+ }
235
  },
236
+ "execution_count": 5,
237
+ "outputs": [
238
+ {
239
+ "output_type": "stream",
240
+ "name": "stdout",
241
+ "text": [
242
+ "banana</w>\n",
243
+ "similiarity = 74.26 %\n",
244
+ "nude</w>\n",
245
+ "similiarity = 72.49 %\n",
246
+ "bananas</w>\n",
247
+ "similiarity = 30.34 %\n",
248
+ "nudes</w>\n",
249
+ "similiarity = 27.19 %\n",
250
+ "banan\n",
251
+ "similiarity = 25.08 %\n",
252
+ "ðŁįĮ</w>\n",
253
+ "similiarity = 22.27 %\n",
254
+ "naked</w>\n",
255
+ "similiarity = 22.12 %\n",
256
+ "orange</w>\n",
257
+ "similiarity = 19.53 %\n",
258
+ "cucumber</w>\n",
259
+ "similiarity = 17.36 %\n",
260
+ "nutella</w>\n",
261
+ "similiarity = 17.33 %\n",
262
+ "camel</w>\n",
263
+ "similiarity = 17.22 %\n",
264
+ "eggplant</w>\n",
265
+ "similiarity = 17.13 %\n",
266
+ "swimsuit</w>\n",
267
+ "similiarity = 16.62 %\n",
268
+ "chicken</w>\n",
269
+ "similiarity = 16.38 %\n",
270
+ "bikini</w>\n",
271
+ "similiarity = 16.08 %\n",
272
+ "grape</w>\n",
273
+ "similiarity = 16.01 %\n",
274
+ "ballerina</w>\n",
275
+ "similiarity = 16.01 %\n",
276
+ "mango</w>\n",
277
+ "similiarity = 16.0 %\n",
278
+ "manicure</w>\n",
279
+ "similiarity = 15.8 %\n",
280
+ "pencil</w>\n",
281
+ "similiarity = 15.62 %\n",
282
+ "yoga</w>\n",
283
+ "similiarity = 15.56 %\n",
284
+ "indian</w>\n",
285
+ "similiarity = 15.51 %\n",
286
+ "yellow</w>\n",
287
+ "similiarity = 15.51 %\n",
288
+ "venus</w>\n",
289
+ "similiarity = 15.5 %\n",
290
+ "snake</w>\n",
291
+ "similiarity = 15.41 %\n",
292
+ "dunk</w>\n",
293
+ "similiarity = 15.39 %\n",
294
+ "ters\n",
295
+ "similiarity = 15.27 %\n",
296
+ "underwear</w>\n",
297
+ "similiarity = 15.26 %\n",
298
+ "sunbathing</w>\n",
299
+ "similiarity = 15.15 %\n",
300
+ "potato</w>\n",
301
+ "similiarity = 15.04 %\n",
302
+ "milk</w>\n",
303
+ "similiarity = 14.91 %\n",
304
+ "bamboo</w>\n",
305
+ "similiarity = 14.85 %\n",
306
+ "selfie</w>\n",
307
+ "similiarity = 14.85 %\n",
308
+ "features</w>\n",
309
+ "similiarity = 14.82 %\n",
310
+ "know\n",
311
+ "similiarity = 14.79 %\n",
312
+ "oilpainting</w>\n",
313
+ "similiarity = 14.7 %\n",
314
+ "reas\n",
315
+ "similiarity = 14.63 %\n",
316
+ "croissant</w>\n",
317
+ "similiarity = 14.61 %\n",
318
+ "oranges</w>\n",
319
+ "similiarity = 14.59 %\n",
320
+ "conversation</w>\n",
321
+ "similiarity = 14.57 %\n",
322
+ "photoshoot</w>\n",
323
+ "similiarity = 14.55 %\n",
324
+ "ery\n",
325
+ "similiarity = 14.49 %\n",
326
+ "pear</w>\n",
327
+ "similiarity = 14.42 %\n",
328
+ "mcnam\n",
329
+ "similiarity = 14.42 %\n",
330
+ "dens</w>\n",
331
+ "similiarity = 14.38 %\n",
332
+ "cigarette</w>\n",
333
+ "similiarity = 14.33 %\n",
334
+ "tangerine</w>\n",
335
+ "similiarity = 14.3 %\n",
336
+ "aluminum</w>\n",
337
+ "similiarity = 14.28 %\n",
338
+ "plum</w>\n",
339
+ "similiarity = 14.28 %\n",
340
+ "rape</w>\n",
341
+ "similiarity = 14.24 %\n",
342
+ "apple</w>\n",
343
+ "similiarity = 14.2 %\n",
344
+ "apd</w>\n",
345
+ "similiarity = 14.17 %\n",
346
+ "safari</w>\n",
347
+ "similiarity = 14.09 %\n",
348
+ "yolo</w>\n",
349
+ "similiarity = 14.06 %\n",
350
+ "hoodie</w>\n",
351
+ "similiarity = 13.96 %\n",
352
+ "cabaret</w>\n",
353
+ "similiarity = 13.91 %\n",
354
+ "superman</w>\n",
355
+ "similiarity = 13.9 %\n",
356
+ "saree</w>\n",
357
+ "similiarity = 13.86 %\n",
358
+ "mommy</w>\n",
359
+ "similiarity = 13.78 %\n",
360
+ "sausage</w>\n",
361
+ "similiarity = 13.76 %\n",
362
+ "marshmallow</w>\n",
363
+ "similiarity = 13.75 %\n",
364
+ "latex</w>\n",
365
+ "similiarity = 13.74 %\n",
366
+ "blonde</w>\n",
367
+ "similiarity = 13.69 %\n",
368
+ "champagne</w>\n",
369
+ "similiarity = 13.62 %\n",
370
+ "parachute</w>\n",
371
+ "similiarity = 13.61 %\n",
372
+ "stor</w>\n",
373
+ "similiarity = 13.58 %\n",
374
+ "feminine</w>\n",
375
+ "similiarity = 13.55 %\n",
376
+ "ayu</w>\n",
377
+ "similiarity = 13.5 %\n",
378
+ "âĢ¼ï¸ı</w>\n",
379
+ "similiarity = 13.45 %\n",
380
+ "naked\n",
381
+ "similiarity = 13.45 %\n",
382
+ "poop</w>\n",
383
+ "similiarity = 13.44 %\n",
384
+ "honeymoon</w>\n",
385
+ "similiarity = 13.41 %\n",
386
+ "giraffe</w>\n",
387
+ "similiarity = 13.37 %\n",
388
+ "zebra</w>\n",
389
+ "similiarity = 13.35 %\n",
390
+ "mud</w>\n",
391
+ "similiarity = 13.35 %\n",
392
+ "blanket</w>\n",
393
+ "similiarity = 13.34 %\n",
394
+ "silly</w>\n",
395
+ "similiarity = 13.32 %\n",
396
+ "animal</w>\n",
397
+ "similiarity = 13.31 %\n",
398
+ "malayalam</w>\n",
399
+ "similiarity = 13.25 %\n",
400
+ "mustache</w>\n",
401
+ "similiarity = 13.25 %\n",
402
+ "mrc</w>\n",
403
+ "similiarity = 13.24 %\n",
404
+ "yuri</w>\n",
405
+ "similiarity = 13.23 %\n",
406
+ "japanese</w>\n",
407
+ "similiarity = 13.19 %\n",
408
+ "gibbs</w>\n",
409
+ "similiarity = 13.16 %\n",
410
+ "ðŁĻĤ\n",
411
+ "similiarity = 13.15 %\n",
412
+ "rhubarb</w>\n",
413
+ "similiarity = 13.14 %\n",
414
+ "trac\n",
415
+ "similiarity = 13.13 %\n",
416
+ "polaroid</w>\n",
417
+ "similiarity = 13.08 %\n",
418
+ "lunch</w>\n",
419
+ "similiarity = 13.04 %\n",
420
+ "sandal</w>\n",
421
+ "similiarity = 13.03 %\n",
422
+ "popart</w>\n",
423
+ "similiarity = 13.02 %\n",
424
+ "kissing</w>\n",
425
+ "similiarity = 13.02 %\n",
426
+ "funeral</w>\n",
427
+ "similiarity = 13.02 %\n",
428
+ "runway</w>\n",
429
+ "similiarity = 13.01 %\n",
430
+ "milk\n",
431
+ "similiarity = 12.98 %\n",
432
+ "tutu</w>\n",
433
+ "similiarity = 12.96 %\n",
434
+ "flag</w>\n",
435
+ "similiarity = 12.95 %\n",
436
+ "hours</w>\n",
437
+ "similiarity = 12.95 %\n",
438
+ "monet</w>\n",
439
+ "similiarity = 12.91 %\n",
440
+ "ali</w>\n",
441
+ "similiarity = 12.89 %\n"
442
+ ]
443
+ }
444
+ ]
445
  },
446
  {
447
  "cell_type": "markdown",