Upload sd_token_similarity_calculator.ipynb
Browse files- sd_token_similarity_calculator.ipynb +223 -172
sd_token_similarity_calculator.ipynb
CHANGED
|
@@ -118,8 +118,7 @@
|
|
| 118 |
],
|
| 119 |
"metadata": {
|
| 120 |
"id": "Ch9puvwKH1s3",
|
| 121 |
-
"collapsed": true
|
| 122 |
-
"cellView": "form"
|
| 123 |
},
|
| 124 |
"execution_count": null,
|
| 125 |
"outputs": []
|
|
@@ -133,7 +132,7 @@
|
|
| 133 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
| 134 |
"\n",
|
| 135 |
"# @markdown Write name of token to match against\n",
|
| 136 |
-
"token_name = \"
|
| 137 |
"\n",
|
| 138 |
"prompt = token_name\n",
|
| 139 |
"# @markdown (optional) Mix the token with something else\n",
|
|
@@ -298,8 +297,10 @@
|
|
| 298 |
"source": [
|
| 299 |
"# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
|
| 300 |
"#-----#\n",
|
|
|
|
| 301 |
"import shelve\n",
|
| 302 |
"db_vocab = shelve.open(VOCAB_FILENAME)\n",
|
|
|
|
| 303 |
"# @markdown # What do you want to to mimic?\n",
|
| 304 |
"use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
|
| 305 |
"# @markdown --------------------------\n",
|
|
@@ -317,7 +318,7 @@
|
|
| 317 |
" return list(uploaded.keys())\n",
|
| 318 |
"#Get image\n",
|
| 319 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
| 320 |
-
"image_url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
| 321 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
|
| 322 |
"# @markdown --------------------------\n",
|
| 323 |
"from PIL import Image\n",
|
|
@@ -360,13 +361,12 @@
|
|
| 360 |
"#-----#\n",
|
| 361 |
"# @markdown # The output...\n",
|
| 362 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 363 |
-
"must_contain = \"
|
| 364 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 365 |
-
"token_B = must_contain\n",
|
| 366 |
"# @markdown -----\n",
|
| 367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
| 368 |
"start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
| 369 |
-
"# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\"\n",
|
| 370 |
"start_search_at_ID = start_search_at_index\n",
|
| 371 |
"search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
|
| 372 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
|
@@ -378,186 +378,238 @@
|
|
| 378 |
"_enable = False # param {\"type\":\"boolean\"}\n",
|
| 379 |
"prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
|
| 380 |
"#-----#\n",
|
| 381 |
-
"name_B = must_contain\n",
|
| 382 |
"#-----#\n",
|
| 383 |
"START = start_search_at_ID\n",
|
| 384 |
-
"RANGE = min(search_range ,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
"#-----#\n",
|
| 386 |
-
"
|
| 387 |
-
"
|
| 388 |
-
"import re\n",
|
| 389 |
"#-----#\n",
|
| 390 |
-
"
|
| 391 |
-
"
|
| 392 |
-
"
|
| 393 |
-
"
|
| 394 |
-
"
|
| 395 |
-
"
|
| 396 |
-
"
|
| 397 |
" #-----#\n",
|
| 398 |
-
"
|
| 399 |
-
"
|
| 400 |
-
"
|
| 401 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
" continue\n",
|
| 403 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
" if restrictions == \"Prefix only\":\n",
|
|
|
|
|
|
|
|
|
|
| 405 |
" continue\n",
|
| 406 |
-
"
|
| 407 |
-
"
|
| 408 |
-
"
|
| 409 |
-
"
|
| 410 |
-
"
|
| 411 |
-
"
|
| 412 |
-
"
|
| 413 |
-
"
|
| 414 |
-
"
|
| 415 |
-
"
|
| 416 |
-
"
|
| 417 |
-
"
|
| 418 |
-
"
|
| 419 |
-
"
|
| 420 |
-
"
|
| 421 |
-
"
|
| 422 |
-
"
|
| 423 |
-
" sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 424 |
-
" #-----#\n",
|
| 425 |
-
" if(use == '📝text_encoding from prompt'):\n",
|
| 426 |
-
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 427 |
-
" text_features = model.get_text_features(**ids_CB)\n",
|
| 428 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 429 |
-
" sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 430 |
-
" #-----#\n",
|
| 431 |
-
" #-----#\n",
|
| 432 |
-
" if restrictions == \"Prefix only\":\n",
|
| 433 |
" result = sim_CB\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
" result = result.item()\n",
|
| 435 |
" dots[index] = result\n",
|
| 436 |
-
" continue\n",
|
| 437 |
-
" #-----#\n",
|
| 438 |
-
" if(use == '🖼️image_encoding from image'):\n",
|
| 439 |
-
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 440 |
-
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 441 |
-
" text_features = model.get_text_features(**ids_BC)\n",
|
| 442 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 443 |
-
" logit_scale = model.logit_scale.exp()\n",
|
| 444 |
-
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 445 |
-
" sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 446 |
-
" #-----#\n",
|
| 447 |
-
" if(use == '📝text_encoding from prompt'):\n",
|
| 448 |
-
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 449 |
-
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 450 |
-
" text_features = model.get_text_features(**ids_BC)\n",
|
| 451 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 452 |
-
" sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 453 |
-
" #-----#\n",
|
| 454 |
-
" result = sim_CB\n",
|
| 455 |
-
" if(sim_BC > sim_CB):\n",
|
| 456 |
-
" is_BC[index] = 1\n",
|
| 457 |
-
" result = sim_BC\n",
|
| 458 |
-
" #-----#\n",
|
| 459 |
-
" #result = absolute_value(result.item())\n",
|
| 460 |
-
" result = result.item()\n",
|
| 461 |
-
" dots[index] = result\n",
|
| 462 |
-
"#----#\n",
|
| 463 |
-
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 464 |
-
"# @markdown ----------\n",
|
| 465 |
-
"# @markdown # Print options\n",
|
| 466 |
-
"list_size = 100 # @param {type:'number'}\n",
|
| 467 |
-
"print_ID = False # @param {type:\"boolean\"}\n",
|
| 468 |
-
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
| 469 |
-
"print_Name = True # @param {type:\"boolean\"}\n",
|
| 470 |
-
"print_Divider = True # @param {type:\"boolean\"}\n",
|
| 471 |
-
"#----#\n",
|
| 472 |
-
"if (print_Divider):\n",
|
| 473 |
-
" print('//---//')\n",
|
| 474 |
-
"#----#\n",
|
| 475 |
-
"print('')\n",
|
| 476 |
-
"print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
|
| 477 |
-
"print('')\n",
|
| 478 |
-
"#----#\n",
|
| 479 |
-
"aheads = \"{\"\n",
|
| 480 |
-
"trails = \"{\"\n",
|
| 481 |
-
"tmp = \"\"\n",
|
| 482 |
-
"#----#\n",
|
| 483 |
-
"max_sim_ahead = 0\n",
|
| 484 |
-
"max_sim_trail = 0\n",
|
| 485 |
-
"sim = 0\n",
|
| 486 |
-
"max_name_ahead = ''\n",
|
| 487 |
-
"max_name_trail = ''\n",
|
| 488 |
-
"#----#\n",
|
| 489 |
-
"for index in range(min(list_size,RANGE)):\n",
|
| 490 |
-
" id = START + indices[index].item()\n",
|
| 491 |
-
" name = db_vocab[f'{id}']\n",
|
| 492 |
-
" #-----#\n",
|
| 493 |
-
" if (name.find('</w>')<=-1):\n",
|
| 494 |
-
" name = name + '-'\n",
|
| 495 |
-
" else:\n",
|
| 496 |
-
" name = name.replace('</w>', ' ')\n",
|
| 497 |
-
" if(is_BC[index]>0):\n",
|
| 498 |
-
" trails = trails + name + \"|\"\n",
|
| 499 |
-
" else:\n",
|
| 500 |
-
" aheads = aheads + name + \"|\"\n",
|
| 501 |
" #----#\n",
|
| 502 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
" #----#\n",
|
| 504 |
-
" if(
|
| 505 |
-
"
|
| 506 |
-
" max_sim_ahead = sim\n",
|
| 507 |
-
" max_name_ahead = name\n",
|
| 508 |
-
" else:\n",
|
| 509 |
-
" if sim>max_sim_trail:\n",
|
| 510 |
-
" max_sim_trail = sim\n",
|
| 511 |
-
" max_name_trail = name\n",
|
| 512 |
-
"#------#\n",
|
| 513 |
-
"trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 514 |
-
"aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 515 |
-
"max_sim_ahead=max_sim_ahead\n",
|
| 516 |
-
"max_sim_ahead=max_sim_trail\n",
|
| 517 |
-
"#-----#\n",
|
| 518 |
-
"print(f\"place these items ahead of prompt : {aheads}\")\n",
|
| 519 |
-
"print(\"\")\n",
|
| 520 |
-
"print(f\"place these items behind the prompt : {trails}\")\n",
|
| 521 |
-
"print(\"\")\n",
|
| 522 |
-
"print(f\"max_similarity = {max_sim_ahead} % when using '{max_name_ahead + must_contain}' \")\n",
|
| 523 |
-
"print(\"\")\n",
|
| 524 |
-
"print(f\"max_similarity = {max_sim_trail} % when using '{must_contain + max_name_trail}' \")\n",
|
| 525 |
-
"#-----#\n",
|
| 526 |
-
"#STEP 2\n",
|
| 527 |
-
"import random\n",
|
| 528 |
-
"names = {}\n",
|
| 529 |
-
"NUM_PERMUTATIONS = 4\n",
|
| 530 |
-
"#-----#\n",
|
| 531 |
-
"dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
| 532 |
-
"for index in range(NUM_PERMUTATIONS):\n",
|
| 533 |
-
" name = must_start_with\n",
|
| 534 |
-
" if index == 0 : name = name + must_contain\n",
|
| 535 |
-
" if index == 1 : name = name + max_name_ahead + must_contain\n",
|
| 536 |
-
" if index == 2 : name = name + must_contain + max_name_trail\n",
|
| 537 |
-
" if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
|
| 538 |
-
" name = name + must_end_with\n",
|
| 539 |
-
" #----#\n",
|
| 540 |
-
" ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 541 |
" #----#\n",
|
|
|
|
|
|
|
|
|
|
| 542 |
" if(use == '🖼️image_encoding from image'):\n",
|
| 543 |
-
"
|
| 544 |
-
"
|
| 545 |
-
"
|
| 546 |
-
"
|
| 547 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
" #-----#\n",
|
| 549 |
-
"
|
| 550 |
-
"
|
| 551 |
-
"
|
| 552 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
" #-----#\n",
|
| 554 |
-
"
|
| 555 |
-
"
|
| 556 |
-
"
|
| 557 |
-
"
|
| 558 |
-
"
|
| 559 |
-
"
|
| 560 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
" print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
|
| 562 |
" print('------')\n",
|
| 563 |
"#------#\n",
|
|
@@ -565,8 +617,7 @@
|
|
| 565 |
],
|
| 566 |
"metadata": {
|
| 567 |
"collapsed": true,
|
| 568 |
-
"id": "fi0jRruI0-tu"
|
| 569 |
-
"cellView": "form"
|
| 570 |
},
|
| 571 |
"execution_count": null,
|
| 572 |
"outputs": []
|
|
|
|
| 118 |
],
|
| 119 |
"metadata": {
|
| 120 |
"id": "Ch9puvwKH1s3",
|
| 121 |
+
"collapsed": true
|
|
|
|
| 122 |
},
|
| 123 |
"execution_count": null,
|
| 124 |
"outputs": []
|
|
|
|
| 132 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
| 133 |
"\n",
|
| 134 |
"# @markdown Write name of token to match against\n",
|
| 135 |
+
"token_name = \"dogs\" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
|
| 136 |
"\n",
|
| 137 |
"prompt = token_name\n",
|
| 138 |
"# @markdown (optional) Mix the token with something else\n",
|
|
|
|
| 297 |
"source": [
|
| 298 |
"# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
|
| 299 |
"#-----#\n",
|
| 300 |
+
"NUM_TOKENS = 49407\n",
|
| 301 |
"import shelve\n",
|
| 302 |
"db_vocab = shelve.open(VOCAB_FILENAME)\n",
|
| 303 |
+
"print(f'using the tokens found in {VOCAB_FILENAME}.db as the vocab')\n",
|
| 304 |
"# @markdown # What do you want to to mimic?\n",
|
| 305 |
"use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
|
| 306 |
"# @markdown --------------------------\n",
|
|
|
|
| 318 |
" return list(uploaded.keys())\n",
|
| 319 |
"#Get image\n",
|
| 320 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
| 321 |
+
"image_url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
| 322 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
|
| 323 |
"# @markdown --------------------------\n",
|
| 324 |
"from PIL import Image\n",
|
|
|
|
| 361 |
"#-----#\n",
|
| 362 |
"# @markdown # The output...\n",
|
| 363 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 364 |
+
"must_contain = \" pet \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 365 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
|
|
|
| 366 |
"# @markdown -----\n",
|
| 367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
| 368 |
"start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
| 369 |
+
"# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
|
| 370 |
"start_search_at_ID = start_search_at_index\n",
|
| 371 |
"search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
|
| 372 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
|
|
|
| 378 |
"_enable = False # param {\"type\":\"boolean\"}\n",
|
| 379 |
"prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
|
| 380 |
"#-----#\n",
|
|
|
|
| 381 |
"#-----#\n",
|
| 382 |
"START = start_search_at_ID\n",
|
| 383 |
+
"RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
|
| 384 |
+
"#-----#\n",
|
| 385 |
+
"import math, random\n",
|
| 386 |
+
"CHUNK = math.floor(NUM_TOKENS/(RANGE*100))\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"ITERS = 3\n",
|
| 389 |
+
"#-----#\n",
|
| 390 |
+
"#LOOP START\n",
|
| 391 |
+
"#-----#\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"results_sim = torch.zeros(ITERS+1)\n",
|
| 394 |
+
"results_name = {}\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"# Check if original solution is best\n",
|
| 397 |
+
"best_sim = 0\n",
|
| 398 |
+
"name = must_start_with + must_contain + must_end_with\n",
|
| 399 |
+
"ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 400 |
+
"text_features = model.get_text_features(**ids)\n",
|
| 401 |
+
"text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 402 |
+
"#------#\n",
|
| 403 |
+
"if(use == '🖼️image_encoding from image'):\n",
|
| 404 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 405 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 406 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 407 |
"#-----#\n",
|
| 408 |
+
"if(use == '📝text_encoding from prompt'):\n",
|
| 409 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
|
|
|
| 410 |
"#-----#\n",
|
| 411 |
+
"best_sim = sim\n",
|
| 412 |
+
"name_B = must_contain\n",
|
| 413 |
+
"#-----#\n",
|
| 414 |
+
"for iter in range(ITERS):\n",
|
| 415 |
+
" dots = torch.zeros(RANGE)\n",
|
| 416 |
+
" is_trail = torch.zeros(RANGE)\n",
|
| 417 |
+
" import re\n",
|
| 418 |
" #-----#\n",
|
| 419 |
+
"\n",
|
| 420 |
+
" _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
|
| 421 |
+
" results_name[iter] = name_B\n",
|
| 422 |
+
" results_sim[iter] = best_sim\n",
|
| 423 |
+
"\n",
|
| 424 |
+
" for index in range(RANGE):\n",
|
| 425 |
+
" id_C = min(_start + index, NUM_TOKENS)\n",
|
| 426 |
+
" name_C = db_vocab[f'{id_C}']\n",
|
| 427 |
+
" is_Prefix = 0\n",
|
| 428 |
+
" #Skip if non-AZ characters are found\n",
|
| 429 |
+
" #???\n",
|
| 430 |
+
" #-----#\n",
|
| 431 |
+
" # Decide if we should process prefix/suffix tokens\n",
|
| 432 |
+
" if name_C.find('</w>')<=-1:\n",
|
| 433 |
+
" is_Prefix = 1\n",
|
| 434 |
+
" if restrictions != \"Prefix only\":\n",
|
| 435 |
+
" continue\n",
|
| 436 |
+
" else:\n",
|
| 437 |
+
" if restrictions == \"Prefix only\":\n",
|
| 438 |
+
" continue\n",
|
| 439 |
+
" #-----#\n",
|
| 440 |
+
" # Decide if char-size is within range\n",
|
| 441 |
+
" if len(name_C) < min_char_size:\n",
|
| 442 |
" continue\n",
|
| 443 |
+
" if len(name_C) > min_char_size + char_range:\n",
|
| 444 |
+
" continue\n",
|
| 445 |
+
" #-----#\n",
|
| 446 |
+
" name_CB = must_start_with + name_C + name_B + must_end_with\n",
|
| 447 |
+
" if is_Prefix>0:\n",
|
| 448 |
+
" name_CB = must_start_with + ' ' + name_C + '-' + name_B + ' ' + must_end_with\n",
|
| 449 |
+
" #-----#\n",
|
| 450 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
| 451 |
+
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 452 |
+
" text_features = model.get_text_features(**ids_CB)\n",
|
| 453 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 454 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 455 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 456 |
+
" sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 457 |
+
" #-----#\n",
|
| 458 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
| 459 |
+
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 460 |
+
" text_features = model.get_text_features(**ids_CB)\n",
|
| 461 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 462 |
+
" sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 463 |
+
" #-----#\n",
|
| 464 |
+
" #-----#\n",
|
| 465 |
" if restrictions == \"Prefix only\":\n",
|
| 466 |
+
" result = sim_CB\n",
|
| 467 |
+
" result = result.item()\n",
|
| 468 |
+
" dots[index] = result\n",
|
| 469 |
" continue\n",
|
| 470 |
+
" #-----#\n",
|
| 471 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
| 472 |
+
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 473 |
+
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 474 |
+
" text_features = model.get_text_features(**ids_BC)\n",
|
| 475 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 476 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 477 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 478 |
+
" sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 479 |
+
" #-----#\n",
|
| 480 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
| 481 |
+
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 482 |
+
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 483 |
+
" text_features = model.get_text_features(**ids_BC)\n",
|
| 484 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 485 |
+
" sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 486 |
+
" #-----#\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
" result = sim_CB\n",
|
| 488 |
+
" if(sim_BC > sim_CB):\n",
|
| 489 |
+
" is_trail[index] = 1\n",
|
| 490 |
+
" result = sim_BC\n",
|
| 491 |
+
" #-----#\n",
|
| 492 |
+
" #result = absolute_value(result.item())\n",
|
| 493 |
" result = result.item()\n",
|
| 494 |
" dots[index] = result\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
" #----#\n",
|
| 496 |
+
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 497 |
+
" # @markdown ----------\n",
|
| 498 |
+
" # @markdown # Print options\n",
|
| 499 |
+
" list_size = 100 # @param {type:'number'}\n",
|
| 500 |
+
" print_ID = False # @param {type:\"boolean\"}\n",
|
| 501 |
+
" print_Similarity = True # @param {type:\"boolean\"}\n",
|
| 502 |
+
" print_Name = True # @param {type:\"boolean\"}\n",
|
| 503 |
+
" print_Divider = True # @param {type:\"boolean\"}\n",
|
| 504 |
" #----#\n",
|
| 505 |
+
" if (print_Divider):\n",
|
| 506 |
+
" print('//---//')\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
" #----#\n",
|
| 508 |
+
" print('')\n",
|
| 509 |
+
"\n",
|
| 510 |
+
" used_reference = f'the text_encoding for {prompt_A}'\n",
|
| 511 |
" if(use == '🖼️image_encoding from image'):\n",
|
| 512 |
+
" used_reference = 'the image input'\n",
|
| 513 |
+
" print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match {used_reference}: ')\n",
|
| 514 |
+
" print('')\n",
|
| 515 |
+
" #----#\n",
|
| 516 |
+
" aheads = \"{\"\n",
|
| 517 |
+
" trails = \"{\"\n",
|
| 518 |
+
" tmp = \"\"\n",
|
| 519 |
+
" #----#\n",
|
| 520 |
+
" max_sim_ahead = 0\n",
|
| 521 |
+
" max_sim_trail = 0\n",
|
| 522 |
+
" sim = 0\n",
|
| 523 |
+
" max_name_ahead = ''\n",
|
| 524 |
+
" max_name_trail = ''\n",
|
| 525 |
+
" #----#\n",
|
| 526 |
+
" for index in range(min(list_size,RANGE)):\n",
|
| 527 |
+
" id = START + indices[index].item()\n",
|
| 528 |
+
" name = db_vocab[f'{id}']\n",
|
| 529 |
+
" #-----#\n",
|
| 530 |
+
" if (name.find('</w>')<=-1):\n",
|
| 531 |
+
" name = name + '-'\n",
|
| 532 |
+
" if(is_trail[index]>0):\n",
|
| 533 |
+
" trails = trails + name + \"|\"\n",
|
| 534 |
+
" else:\n",
|
| 535 |
+
" aheads = aheads + name + \"|\"\n",
|
| 536 |
+
" #----#\n",
|
| 537 |
+
" sim = sorted[index].item()\n",
|
| 538 |
+
" #----#\n",
|
| 539 |
+
" if(is_trail[index]>0):\n",
|
| 540 |
+
" if sim>max_sim_trail:\n",
|
| 541 |
+
" max_sim_trail = sim\n",
|
| 542 |
+
" max_name_trail = name\n",
|
| 543 |
+
" max_name_trail = max_name_trail.strip()\n",
|
| 544 |
+
"\n",
|
| 545 |
+
" else:\n",
|
| 546 |
+
" if sim>max_sim_ahead:\n",
|
| 547 |
+
" max_sim_ahead = sim\n",
|
| 548 |
+
" max_name_ahead = name\n",
|
| 549 |
+
" #------#\n",
|
| 550 |
+
" trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 551 |
+
" aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 552 |
" #-----#\n",
|
| 553 |
+
" print(f\"place these items ahead of prompt : {aheads}\")\n",
|
| 554 |
+
" print(\"\")\n",
|
| 555 |
+
" print(f\"place these items behind the prompt : {trails}\")\n",
|
| 556 |
+
" print(\"\")\n",
|
| 557 |
+
"\n",
|
| 558 |
+
" tmp = must_start_with + ' ' + max_name_ahead + name_B + ' ' + must_end_with\n",
|
| 559 |
+
" tmp = tmp.strip()\n",
|
| 560 |
+
" print(f\"max_similarity_ahead = {round(max_sim_ahead,2)} % when using '{tmp}' \")\n",
|
| 561 |
+
" print(\"\")\n",
|
| 562 |
+
" tmp = must_start_with + ' ' + name_B + max_name_trail + ' ' + must_end_with\n",
|
| 563 |
+
" tmp = tmp.strip()\n",
|
| 564 |
+
" print(f\"max_similarity_trail = {round(max_sim_trail,2)} % when using '{tmp}' \")\n",
|
| 565 |
" #-----#\n",
|
| 566 |
+
" #STEP 2\n",
|
| 567 |
+
" import random\n",
|
| 568 |
+
" names = {}\n",
|
| 569 |
+
" NUM_PERMUTATIONS = 4\n",
|
| 570 |
+
" #-----#\n",
|
| 571 |
+
" dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
| 572 |
+
" for index in range(NUM_PERMUTATIONS):\n",
|
| 573 |
+
" name_inner = ''\n",
|
| 574 |
+
" if index == 0 : name_inner = name_B\n",
|
| 575 |
+
" if index == 1 : name_inner = max_name_ahead\n",
|
| 576 |
+
" if index == 2 : name_inner = name_B + max_name_trail\n",
|
| 577 |
+
" if index == 3 : name_inner = max_name_ahead + name_B + max_name_trail\n",
|
| 578 |
+
" name = must_start_with + name_inner + must_end_with\n",
|
| 579 |
+
" #----#\n",
|
| 580 |
+
" ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 581 |
+
" #----#\n",
|
| 582 |
+
" sim = 0\n",
|
| 583 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
| 584 |
+
" text_features = model.get_text_features(**ids)\n",
|
| 585 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 586 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 587 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 588 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 589 |
+
" #-----#\n",
|
| 590 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
| 591 |
+
" text_features = model.get_text_features(**ids)\n",
|
| 592 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 593 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 594 |
+
" #-----#\n",
|
| 595 |
+
" dots[index] = sim\n",
|
| 596 |
+
" names[index] = name_inner\n",
|
| 597 |
+
" #------#\n",
|
| 598 |
+
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 599 |
+
" #------#\n",
|
| 600 |
+
" best_sim = dots[indices[0].item()]\n",
|
| 601 |
+
" name_B = names[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
|
| 602 |
+
"#--------#\n",
|
| 603 |
+
"#store the final value\n",
|
| 604 |
+
"results_name[iter] = name_B\n",
|
| 605 |
+
"results_sim[iter] = best_sim\n",
|
| 606 |
+
"\n",
|
| 607 |
+
"sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"print('')\n",
|
| 610 |
+
"for index in range(ITERS+1):\n",
|
| 611 |
+
" name_inner = results_name[indices[index].item()]\n",
|
| 612 |
+
" print(must_start_with + name_inner + must_end_with)\n",
|
| 613 |
" print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
|
| 614 |
" print('------')\n",
|
| 615 |
"#------#\n",
|
|
|
|
| 617 |
],
|
| 618 |
"metadata": {
|
| 619 |
"collapsed": true,
|
| 620 |
+
"id": "fi0jRruI0-tu"
|
|
|
|
| 621 |
},
|
| 622 |
"execution_count": null,
|
| 623 |
"outputs": []
|