Add test model (CC+WIT+COCO) 7 epochs
Browse files- README.md +0 -3
- config.json +156 -0
- evaluation/imagenet_validation_script.ipynb +0 -1003
- flax_model.msgpack +3 -0
- hybrid_clip/README.md +0 -172
- hybrid_clip/configuration_hybrid_clip.py +0 -108
- hybrid_clip/modeling_hybrid_clip.py +0 -420
- hybrid_clip/requirements.txt +0 -8
- hybrid_clip/run_hybrid_clip.py +0 -562
README.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# clip-italian
|
2 |
-
|
3 |
-
# TODO
|
|
|
|
|
|
|
|
config.json
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"HybridCLIP"
|
4 |
+
],
|
5 |
+
"initializer_factor": 1.0,
|
6 |
+
"model_type": "hybrid-clip",
|
7 |
+
"projection_dim": 512,
|
8 |
+
"seed": 42,
|
9 |
+
"text_config": {
|
10 |
+
"_name_or_path": "",
|
11 |
+
"add_cross_attention": false,
|
12 |
+
"architectures": [
|
13 |
+
"BertForMaskedLM"
|
14 |
+
],
|
15 |
+
"attention_probs_dropout_prob": 0.1,
|
16 |
+
"bad_words_ids": null,
|
17 |
+
"bos_token_id": null,
|
18 |
+
"chunk_size_feed_forward": 0,
|
19 |
+
"decoder_start_token_id": null,
|
20 |
+
"diversity_penalty": 0.0,
|
21 |
+
"do_sample": false,
|
22 |
+
"early_stopping": false,
|
23 |
+
"encoder_no_repeat_ngram_size": 0,
|
24 |
+
"eos_token_id": null,
|
25 |
+
"finetuning_task": null,
|
26 |
+
"forced_bos_token_id": null,
|
27 |
+
"forced_eos_token_id": null,
|
28 |
+
"gradient_checkpointing": false,
|
29 |
+
"hidden_act": "gelu",
|
30 |
+
"hidden_dropout_prob": 0.1,
|
31 |
+
"hidden_size": 768,
|
32 |
+
"id2label": {
|
33 |
+
"0": "LABEL_0",
|
34 |
+
"1": "LABEL_1"
|
35 |
+
},
|
36 |
+
"initializer_range": 0.02,
|
37 |
+
"intermediate_size": 3072,
|
38 |
+
"is_decoder": false,
|
39 |
+
"is_encoder_decoder": false,
|
40 |
+
"label2id": {
|
41 |
+
"LABEL_0": 0,
|
42 |
+
"LABEL_1": 1
|
43 |
+
},
|
44 |
+
"layer_norm_eps": 1e-12,
|
45 |
+
"length_penalty": 1.0,
|
46 |
+
"max_length": 20,
|
47 |
+
"max_position_embeddings": 512,
|
48 |
+
"min_length": 0,
|
49 |
+
"model_type": "bert",
|
50 |
+
"no_repeat_ngram_size": 0,
|
51 |
+
"num_attention_heads": 12,
|
52 |
+
"num_beam_groups": 1,
|
53 |
+
"num_beams": 1,
|
54 |
+
"num_hidden_layers": 12,
|
55 |
+
"num_return_sequences": 1,
|
56 |
+
"output_attentions": false,
|
57 |
+
"output_hidden_states": false,
|
58 |
+
"output_scores": false,
|
59 |
+
"pad_token_id": 0,
|
60 |
+
"position_embedding_type": "absolute",
|
61 |
+
"prefix": null,
|
62 |
+
"problem_type": null,
|
63 |
+
"pruned_heads": {},
|
64 |
+
"remove_invalid_values": false,
|
65 |
+
"repetition_penalty": 1.0,
|
66 |
+
"return_dict": true,
|
67 |
+
"return_dict_in_generate": false,
|
68 |
+
"sep_token_id": null,
|
69 |
+
"task_specific_params": null,
|
70 |
+
"temperature": 1.0,
|
71 |
+
"tie_encoder_decoder": false,
|
72 |
+
"tie_word_embeddings": true,
|
73 |
+
"tokenizer_class": null,
|
74 |
+
"top_k": 50,
|
75 |
+
"top_p": 1.0,
|
76 |
+
"torch_dtype": null,
|
77 |
+
"torchscript": false,
|
78 |
+
"transformers_version": "4.9.0.dev0",
|
79 |
+
"type_vocab_size": 2,
|
80 |
+
"use_bfloat16": false,
|
81 |
+
"use_cache": true,
|
82 |
+
"vocab_size": 32102
|
83 |
+
},
|
84 |
+
"transformers_version": null,
|
85 |
+
"vision_config": {
|
86 |
+
"_name_or_path": "",
|
87 |
+
"add_cross_attention": false,
|
88 |
+
"architectures": null,
|
89 |
+
"attention_dropout": 0.0,
|
90 |
+
"bad_words_ids": null,
|
91 |
+
"bos_token_id": null,
|
92 |
+
"chunk_size_feed_forward": 0,
|
93 |
+
"decoder_start_token_id": null,
|
94 |
+
"diversity_penalty": 0.0,
|
95 |
+
"do_sample": false,
|
96 |
+
"dropout": 0.0,
|
97 |
+
"early_stopping": false,
|
98 |
+
"encoder_no_repeat_ngram_size": 0,
|
99 |
+
"eos_token_id": null,
|
100 |
+
"finetuning_task": null,
|
101 |
+
"forced_bos_token_id": null,
|
102 |
+
"forced_eos_token_id": null,
|
103 |
+
"gradient_checkpointing": false,
|
104 |
+
"hidden_act": "quick_gelu",
|
105 |
+
"hidden_size": 768,
|
106 |
+
"id2label": {
|
107 |
+
"0": "LABEL_0",
|
108 |
+
"1": "LABEL_1"
|
109 |
+
},
|
110 |
+
"image_size": 224,
|
111 |
+
"initializer_factor": 1.0,
|
112 |
+
"initializer_range": 0.02,
|
113 |
+
"intermediate_size": 3072,
|
114 |
+
"is_decoder": false,
|
115 |
+
"is_encoder_decoder": false,
|
116 |
+
"label2id": {
|
117 |
+
"LABEL_0": 0,
|
118 |
+
"LABEL_1": 1
|
119 |
+
},
|
120 |
+
"layer_norm_eps": 1e-05,
|
121 |
+
"length_penalty": 1.0,
|
122 |
+
"max_length": 20,
|
123 |
+
"min_length": 0,
|
124 |
+
"model_type": "clip_vision_model",
|
125 |
+
"no_repeat_ngram_size": 0,
|
126 |
+
"num_attention_heads": 12,
|
127 |
+
"num_beam_groups": 1,
|
128 |
+
"num_beams": 1,
|
129 |
+
"num_hidden_layers": 12,
|
130 |
+
"num_return_sequences": 1,
|
131 |
+
"output_attentions": false,
|
132 |
+
"output_hidden_states": false,
|
133 |
+
"output_scores": false,
|
134 |
+
"pad_token_id": null,
|
135 |
+
"patch_size": 32,
|
136 |
+
"prefix": null,
|
137 |
+
"problem_type": null,
|
138 |
+
"pruned_heads": {},
|
139 |
+
"remove_invalid_values": false,
|
140 |
+
"repetition_penalty": 1.0,
|
141 |
+
"return_dict": true,
|
142 |
+
"return_dict_in_generate": false,
|
143 |
+
"sep_token_id": null,
|
144 |
+
"task_specific_params": null,
|
145 |
+
"temperature": 1.0,
|
146 |
+
"tie_encoder_decoder": false,
|
147 |
+
"tie_word_embeddings": true,
|
148 |
+
"tokenizer_class": null,
|
149 |
+
"top_k": 50,
|
150 |
+
"top_p": 1.0,
|
151 |
+
"torch_dtype": null,
|
152 |
+
"torchscript": false,
|
153 |
+
"transformers_version": "4.9.0.dev0",
|
154 |
+
"use_bfloat16": false
|
155 |
+
}
|
156 |
+
}
|
evaluation/imagenet_validation_script.ipynb
DELETED
@@ -1,1003 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# Imagenet Evaluation Script\n",
|
8 |
-
"modified from [the evluation script by OpenAI](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Prompt_Engineering_for_ImageNet.ipynb)."
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "code",
|
13 |
-
"execution_count": 3,
|
14 |
-
"metadata": {},
|
15 |
-
"outputs": [],
|
16 |
-
"source": [
|
17 |
-
"import json\n",
|
18 |
-
"from modeling_hybrid_clip import FlaxHybridCLIP\n",
|
19 |
-
"from configuration_hybrid_clip import HybridCLIPConfig\n",
|
20 |
-
"\n",
|
21 |
-
"import jax\n",
|
22 |
-
"from jax import numpy as jnp\n",
|
23 |
-
"\n",
|
24 |
-
"import os \n",
|
25 |
-
"os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n",
|
26 |
-
"\n",
|
27 |
-
"import transformers\n",
|
28 |
-
"from transformers import AutoTokenizer\n",
|
29 |
-
"\n",
|
30 |
-
"import numpy as np\n",
|
31 |
-
"import torch\n",
|
32 |
-
"import torchvision\n",
|
33 |
-
"from torchvision import transforms\n",
|
34 |
-
"from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ColorJitter, RandomHorizontalFlip, RandomRotation, ToTensor, Lambda\n",
|
35 |
-
"from torchvision.transforms.functional import InterpolationMode\n",
|
36 |
-
"from tqdm.notebook import tqdm"
|
37 |
-
]
|
38 |
-
},
|
39 |
-
{
|
40 |
-
"cell_type": "code",
|
41 |
-
"execution_count": 4,
|
42 |
-
"metadata": {},
|
43 |
-
"outputs": [],
|
44 |
-
"source": [
|
45 |
-
"# Global Variables\n",
|
46 |
-
"\n",
|
47 |
-
"LANGUAGE = 'en'\n",
|
48 |
-
"IMAGENET_ROOT = \"/home/raphaelp/imagenet_root/\"\n",
|
49 |
-
"\n",
|
50 |
-
"CONFIG_FILE = '/home/raphaelp/clip-base/config.json'\n",
|
51 |
-
"MODEL_FILE = '/home/raphaelp/clip-base/checkpoint/flax_model.msgpack'\n",
|
52 |
-
"\n",
|
53 |
-
"\n",
|
54 |
-
"if LANGUAGE == 'en':\n",
|
55 |
-
" TOKENIZER_NAME = \"roberta-base\" \n",
|
56 |
-
"elif LANGUAGE == 'it':\n",
|
57 |
-
" TOKENIZER_NAME = \"dbmdz/bert-base-italian-cased\""
|
58 |
-
]
|
59 |
-
},
|
60 |
-
{
|
61 |
-
"cell_type": "markdown",
|
62 |
-
"metadata": {
|
63 |
-
"id": "eFxgLV5HAEEw"
|
64 |
-
},
|
65 |
-
"source": [
|
66 |
-
"# Loading the model"
|
67 |
-
]
|
68 |
-
},
|
69 |
-
{
|
70 |
-
"cell_type": "code",
|
71 |
-
"execution_count": 10,
|
72 |
-
"metadata": {},
|
73 |
-
"outputs": [],
|
74 |
-
"source": [
|
75 |
-
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
76 |
-
" TOKENIZER_NAME, cache_dir=None, use_fast=True\n",
|
77 |
-
")"
|
78 |
-
]
|
79 |
-
},
|
80 |
-
{
|
81 |
-
"cell_type": "code",
|
82 |
-
"execution_count": 5,
|
83 |
-
"metadata": {
|
84 |
-
"tags": []
|
85 |
-
},
|
86 |
-
"outputs": [
|
87 |
-
{
|
88 |
-
"name": "stderr",
|
89 |
-
"output_type": "stream",
|
90 |
-
"text": [
|
91 |
-
"INFO:absl:Starting the local TPU driver.\n",
|
92 |
-
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
93 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n",
|
94 |
-
"2021-07-07 14:04:30.725543: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
|
95 |
-
]
|
96 |
-
}
|
97 |
-
],
|
98 |
-
"source": [
|
99 |
-
"with open(CONFIG_FILE, 'r') as f:\n",
|
100 |
-
" config_dict = json.load(f)\n",
|
101 |
-
"config_dict['vision_config']['model_type'] = 'clip'\n",
|
102 |
-
"config = HybridCLIPConfig(text_config_dict=config_dict['text_config'], vision_config_dict=config_dict['vision_config'])\n",
|
103 |
-
"model = FlaxHybridCLIP.from_pretrained(MODEL_FILE, config=config)"
|
104 |
-
]
|
105 |
-
},
|
106 |
-
{
|
107 |
-
"cell_type": "code",
|
108 |
-
"execution_count": 6,
|
109 |
-
"metadata": {
|
110 |
-
"colab": {
|
111 |
-
"base_uri": "https://localhost:8080/"
|
112 |
-
},
|
113 |
-
"id": "IBRVTY9lbGm8",
|
114 |
-
"outputId": "58641dc2-919d-40ae-b71a-7b7b47830f77"
|
115 |
-
},
|
116 |
-
"outputs": [
|
117 |
-
{
|
118 |
-
"name": "stdout",
|
119 |
-
"output_type": "stream",
|
120 |
-
"text": [
|
121 |
-
"Input resolution: 224\n",
|
122 |
-
"Context length: 20\n",
|
123 |
-
"Vocab size: 50265\n"
|
124 |
-
]
|
125 |
-
}
|
126 |
-
],
|
127 |
-
"source": [
|
128 |
-
"print(\"Input resolution:\", config.vision_config.image_size)\n",
|
129 |
-
"print(\"Context length:\", config.text_config.max_length)\n",
|
130 |
-
"print(\"Vocab size:\", config.text_config.vocab_size)"
|
131 |
-
]
|
132 |
-
},
|
133 |
-
{
|
134 |
-
"cell_type": "markdown",
|
135 |
-
"metadata": {
|
136 |
-
"id": "LhO3OtOmF8M4"
|
137 |
-
},
|
138 |
-
"source": [
|
139 |
-
"# Preparing ImageNet labels and prompts\n",
|
140 |
-
"\n",
|
141 |
-
"The following cell contains the 1,000 labels for the ImageNet dataset, followed by the text templates we'll use as \"prompt engineering\"."
|
142 |
-
]
|
143 |
-
},
|
144 |
-
{
|
145 |
-
"cell_type": "code",
|
146 |
-
"execution_count": 7,
|
147 |
-
"metadata": {
|
148 |
-
"id": "R2HbOZrqa0jF"
|
149 |
-
},
|
150 |
-
"outputs": [],
|
151 |
-
"source": [
|
152 |
-
"if LANGUAGE == 'en':\n",
|
153 |
-
" imagenet_classes = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n",
|
154 |
-
"elif LANGUAGE == 'it':\n",
|
155 |
-
" raise NotImplementedError"
|
156 |
-
]
|
157 |
-
},
|
158 |
-
{
|
159 |
-
"cell_type": "code",
|
160 |
-
"execution_count": 8,
|
161 |
-
"metadata": {
|
162 |
-
"colab": {
|
163 |
-
"base_uri": "https://localhost:8080/"
|
164 |
-
},
|
165 |
-
"id": "toGtcd-Ji_MD",
|
166 |
-
"outputId": "46bcc85f-3968-4836-f3c6-e48848e944c4"
|
167 |
-
},
|
168 |
-
"outputs": [
|
169 |
-
{
|
170 |
-
"name": "stdout",
|
171 |
-
"output_type": "stream",
|
172 |
-
"text": [
|
173 |
-
"1000 classes, 80 templates\n"
|
174 |
-
]
|
175 |
-
}
|
176 |
-
],
|
177 |
-
"source": [
|
178 |
-
"if LANGUAGE == 'en':\n",
|
179 |
-
" imagenet_templates = [\n",
|
180 |
-
" 'a bad photo of a {}.',\n",
|
181 |
-
" 'a photo of many {}.',\n",
|
182 |
-
" 'a sculpture of a {}.',\n",
|
183 |
-
" 'a photo of the hard to see {}.',\n",
|
184 |
-
" 'a low resolution photo of the {}.',\n",
|
185 |
-
" 'a rendering of a {}.',\n",
|
186 |
-
" 'graffiti of a {}.',\n",
|
187 |
-
" 'a bad photo of the {}.',\n",
|
188 |
-
" 'a cropped photo of the {}.',\n",
|
189 |
-
" 'a tattoo of a {}.',\n",
|
190 |
-
" 'the embroidered {}.',\n",
|
191 |
-
" 'a photo of a hard to see {}.',\n",
|
192 |
-
" 'a bright photo of a {}.',\n",
|
193 |
-
" 'a photo of a clean {}.',\n",
|
194 |
-
" 'a photo of a dirty {}.',\n",
|
195 |
-
" 'a dark photo of the {}.',\n",
|
196 |
-
" 'a drawing of a {}.',\n",
|
197 |
-
" 'a photo of my {}.',\n",
|
198 |
-
" 'the plastic {}.',\n",
|
199 |
-
" 'a photo of the cool {}.',\n",
|
200 |
-
" 'a close-up photo of a {}.',\n",
|
201 |
-
" 'a black and white photo of the {}.',\n",
|
202 |
-
" 'a painting of the {}.',\n",
|
203 |
-
" 'a painting of a {}.',\n",
|
204 |
-
" 'a pixelated photo of the {}.',\n",
|
205 |
-
" 'a sculpture of the {}.',\n",
|
206 |
-
" 'a bright photo of the {}.',\n",
|
207 |
-
" 'a cropped photo of a {}.',\n",
|
208 |
-
" 'a plastic {}.',\n",
|
209 |
-
" 'a photo of the dirty {}.',\n",
|
210 |
-
" 'a jpeg corrupted photo of a {}.',\n",
|
211 |
-
" 'a blurry photo of the {}.',\n",
|
212 |
-
" 'a photo of the {}.',\n",
|
213 |
-
" 'a good photo of the {}.',\n",
|
214 |
-
" 'a rendering of the {}.',\n",
|
215 |
-
" 'a {} in a video game.',\n",
|
216 |
-
" 'a photo of one {}.',\n",
|
217 |
-
" 'a doodle of a {}.',\n",
|
218 |
-
" 'a close-up photo of the {}.',\n",
|
219 |
-
" 'a photo of a {}.',\n",
|
220 |
-
" 'the origami {}.',\n",
|
221 |
-
" 'the {} in a video game.',\n",
|
222 |
-
" 'a sketch of a {}.',\n",
|
223 |
-
" 'a doodle of the {}.',\n",
|
224 |
-
" 'a origami {}.',\n",
|
225 |
-
" 'a low resolution photo of a {}.',\n",
|
226 |
-
" 'the toy {}.',\n",
|
227 |
-
" 'a rendition of the {}.',\n",
|
228 |
-
" 'a photo of the clean {}.',\n",
|
229 |
-
" 'a photo of a large {}.',\n",
|
230 |
-
" 'a rendition of a {}.',\n",
|
231 |
-
" 'a photo of a nice {}.',\n",
|
232 |
-
" 'a photo of a weird {}.',\n",
|
233 |
-
" 'a blurry photo of a {}.',\n",
|
234 |
-
" 'a cartoon {}.',\n",
|
235 |
-
" 'art of a {}.',\n",
|
236 |
-
" 'a sketch of the {}.',\n",
|
237 |
-
" 'a embroidered {}.',\n",
|
238 |
-
" 'a pixelated photo of a {}.',\n",
|
239 |
-
" 'itap of the {}.',\n",
|
240 |
-
" 'a jpeg corrupted photo of the {}.',\n",
|
241 |
-
" 'a good photo of a {}.',\n",
|
242 |
-
" 'a plushie {}.',\n",
|
243 |
-
" 'a photo of the nice {}.',\n",
|
244 |
-
" 'a photo of the small {}.',\n",
|
245 |
-
" 'a photo of the weird {}.',\n",
|
246 |
-
" 'the cartoon {}.',\n",
|
247 |
-
" 'art of the {}.',\n",
|
248 |
-
" 'a drawing of the {}.',\n",
|
249 |
-
" 'a photo of the large {}.',\n",
|
250 |
-
" 'a black and white photo of a {}.',\n",
|
251 |
-
" 'the plushie {}.',\n",
|
252 |
-
" 'a dark photo of a {}.',\n",
|
253 |
-
" 'itap of a {}.',\n",
|
254 |
-
" 'graffiti of the {}.',\n",
|
255 |
-
" 'a toy {}.',\n",
|
256 |
-
" 'itap of my {}.',\n",
|
257 |
-
" 'a photo of a cool {}.',\n",
|
258 |
-
" 'a photo of a small {}.',\n",
|
259 |
-
" 'a tattoo of the {}.',\n",
|
260 |
-
" ]\n",
|
261 |
-
"elif LANGUAGE == 'it':\n",
|
262 |
-
" raise NotImplementedError\n",
|
263 |
-
"\n",
|
264 |
-
"print(f\"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates\")"
|
265 |
-
]
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "markdown",
|
269 |
-
"metadata": {},
|
270 |
-
"source": [
|
271 |
-
"# Set up Validation Set"
|
272 |
-
]
|
273 |
-
},
|
274 |
-
{
|
275 |
-
"cell_type": "code",
|
276 |
-
"execution_count": 9,
|
277 |
-
"metadata": {
|
278 |
-
"colab": {
|
279 |
-
"base_uri": "https://localhost:8080/"
|
280 |
-
},
|
281 |
-
"id": "cboKZocQlSYX",
|
282 |
-
"outputId": "58e644d4-6e23-43b5-964e-1e9e8540d22e"
|
283 |
-
},
|
284 |
-
"outputs": [],
|
285 |
-
"source": [
|
286 |
-
"val_preprocess = transforms.Compose([\n",
|
287 |
-
" Resize([config.vision_config.image_size], interpolation=InterpolationMode.BICUBIC),\n",
|
288 |
-
" CenterCrop(config.vision_config.image_size),\n",
|
289 |
-
" ToTensor(),\n",
|
290 |
-
" Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
|
291 |
-
"])"
|
292 |
-
]
|
293 |
-
},
|
294 |
-
{
|
295 |
-
"cell_type": "code",
|
296 |
-
"execution_count": 11,
|
297 |
-
"metadata": {
|
298 |
-
"colab": {
|
299 |
-
"base_uri": "https://localhost:8080/"
|
300 |
-
},
|
301 |
-
"id": "moHR4UlHKsDc",
|
302 |
-
"outputId": "178f6d0d-9a34-4cbc-c9c1-e7ce09927980"
|
303 |
-
},
|
304 |
-
"outputs": [],
|
305 |
-
"source": [
|
306 |
-
"images = torchvision.datasets.ImageNet(IMAGENET_ROOT, split='val', transform=val_preprocess)\n",
|
307 |
-
"loader = torch.utils.data.DataLoader(\n",
|
308 |
-
" images,\n",
|
309 |
-
" batch_size=64,\n",
|
310 |
-
" shuffle=False,\n",
|
311 |
-
" num_workers=32,\n",
|
312 |
-
" persistent_workers=True,\n",
|
313 |
-
" drop_last=False\n",
|
314 |
-
")"
|
315 |
-
]
|
316 |
-
},
|
317 |
-
{
|
318 |
-
"cell_type": "markdown",
|
319 |
-
"metadata": {
|
320 |
-
"id": "fz6D-F-Wbrtp"
|
321 |
-
},
|
322 |
-
"source": [
|
323 |
-
"# Creating zero-shot classifier weights"
|
324 |
-
]
|
325 |
-
},
|
326 |
-
{
|
327 |
-
"cell_type": "code",
|
328 |
-
"execution_count": 12,
|
329 |
-
"metadata": {
|
330 |
-
"colab": {
|
331 |
-
"base_uri": "https://localhost:8080/",
|
332 |
-
"height": 66,
|
333 |
-
"referenced_widgets": [
|
334 |
-
"4e3a3f83649f45f8bef3434980634664",
|
335 |
-
"f066bdb766664c788ba1e9de8d311e22",
|
336 |
-
"4e7a7427d28a4ae684e0be4548eb9944",
|
337 |
-
"cc9dc019c1334a46b2558ffa6c0dd6e6",
|
338 |
-
"285c877d4f644f3a8a58c4eb5948101c",
|
339 |
-
"075d6545e02e419ca565589eb5ffc318",
|
340 |
-
"53f9106c80e84d5b8c3ec96162d1db98",
|
341 |
-
"19c57d99e7c44cbda508ce558fde435d"
|
342 |
-
]
|
343 |
-
},
|
344 |
-
"id": "sRqDoz1Gbsii",
|
345 |
-
"outputId": "5ab6c001-8a5e-42c9-ab46-4477a693229c"
|
346 |
-
},
|
347 |
-
"outputs": [
|
348 |
-
{
|
349 |
-
"data": {
|
350 |
-
"application/vnd.jupyter.widget-view+json": {
|
351 |
-
"model_id": "64fc81ee79fb45808e9fd8040151940e",
|
352 |
-
"version_major": 2,
|
353 |
-
"version_minor": 0
|
354 |
-
},
|
355 |
-
"text/plain": [
|
356 |
-
"HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))"
|
357 |
-
]
|
358 |
-
},
|
359 |
-
"metadata": {},
|
360 |
-
"output_type": "display_data"
|
361 |
-
},
|
362 |
-
{
|
363 |
-
"name": "stdout",
|
364 |
-
"output_type": "stream",
|
365 |
-
"text": [
|
366 |
-
"\n"
|
367 |
-
]
|
368 |
-
}
|
369 |
-
],
|
370 |
-
"source": [
|
371 |
-
"def zeroshot_classifier(classnames, templates):\n",
|
372 |
-
" zeroshot_weights = []\n",
|
373 |
-
" for classname in tqdm(classnames):\n",
|
374 |
-
" texts = [template.format(classname) for template in templates] #format with class\n",
|
375 |
-
" inputs = tokenizer(texts, max_length=72, padding=\"max_length\", return_tensors=\"np\")\n",
|
376 |
-
" class_embeddings = model.get_text_features(inputs['input_ids'], inputs['attention_mask'])#embed with text encoder\n",
|
377 |
-
" class_embeddings /= jnp.linalg.norm(class_embeddings, axis=-1, keepdims=True)\n",
|
378 |
-
" class_embedding = jnp.mean(class_embeddings, axis=0)\n",
|
379 |
-
" class_embedding /= jnp.linalg.norm(class_embeddings)\n",
|
380 |
-
" zeroshot_weights.append(class_embedding)\n",
|
381 |
-
" zeroshot_weights = jnp.stack(zeroshot_weights, axis=1)\n",
|
382 |
-
" return zeroshot_weights\n",
|
383 |
-
"\n",
|
384 |
-
"\n",
|
385 |
-
"zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)"
|
386 |
-
]
|
387 |
-
},
|
388 |
-
{
|
389 |
-
"cell_type": "markdown",
|
390 |
-
"metadata": {
|
391 |
-
"id": "1fZo7hG8iJP5"
|
392 |
-
},
|
393 |
-
"source": [
|
394 |
-
"# Zero-shot prediction"
|
395 |
-
]
|
396 |
-
},
|
397 |
-
{
|
398 |
-
"cell_type": "code",
|
399 |
-
"execution_count": 13,
|
400 |
-
"metadata": {
|
401 |
-
"id": "j4kPSZoShQxN"
|
402 |
-
},
|
403 |
-
"outputs": [],
|
404 |
-
"source": [
|
405 |
-
"def accuracy(output, target, topk=(1,)):\n",
|
406 |
-
" # pred = output.topk(max(topk), 1, True, True)[1].t()\n",
|
407 |
-
" pred = np.argsort(output, axis=1)[:,-max(topk):]\n",
|
408 |
-
" correct = jnp.equal(pred, jnp.expand_dims(target, axis=-1))\n",
|
409 |
-
" return [float(jnp.sum(jnp.reshape(correct[:k], -1), axis=0, keepdims=True)) for k in topk]"
|
410 |
-
]
|
411 |
-
},
|
412 |
-
{
|
413 |
-
"cell_type": "code",
|
414 |
-
"execution_count": 16,
|
415 |
-
"metadata": {
|
416 |
-
"colab": {
|
417 |
-
"base_uri": "https://localhost:8080/",
|
418 |
-
"height": 100,
|
419 |
-
"referenced_widgets": [
|
420 |
-
"fbb2b937b22049f5987f39f48c652a86",
|
421 |
-
"0a1b6b76984349ccb36ca2fc4a4a0208",
|
422 |
-
"c136afb47aa14ac2832093ee415c6f3e",
|
423 |
-
"467a151e73744eccb199fe72aa352e5b",
|
424 |
-
"f6d637c3fc3c46928d023441227130e5",
|
425 |
-
"029e6eadacb8480193aab52ff073be8f",
|
426 |
-
"30178355f76742898d37966b3875ef0a",
|
427 |
-
"2e62544c03d64d6d92b94fcfaca2fc90"
|
428 |
-
]
|
429 |
-
},
|
430 |
-
"id": "wKJ7YsdlkDXo",
|
431 |
-
"outputId": "90e084fd-86bc-4a52-a06e-61bff7aa86e0"
|
432 |
-
},
|
433 |
-
"outputs": [
|
434 |
-
{
|
435 |
-
"data": {
|
436 |
-
"application/vnd.jupyter.widget-view+json": {
|
437 |
-
"model_id": "f5db0a6445f04d1bb5ae08dee278a215",
|
438 |
-
"version_major": 2,
|
439 |
-
"version_minor": 0
|
440 |
-
},
|
441 |
-
"text/plain": [
|
442 |
-
"HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))"
|
443 |
-
]
|
444 |
-
},
|
445 |
-
"metadata": {},
|
446 |
-
"output_type": "display_data"
|
447 |
-
},
|
448 |
-
{
|
449 |
-
"name": "stdout",
|
450 |
-
"output_type": "stream",
|
451 |
-
"text": [
|
452 |
-
"\n",
|
453 |
-
"{'top1': 0.958, 'top5': 4.688, 'top10': 9.202, 'top100': 58.29}\n"
|
454 |
-
]
|
455 |
-
}
|
456 |
-
],
|
457 |
-
"source": [
|
458 |
-
"top_ns = [1, 5, 10, 100]\n",
|
459 |
-
"acc_counters = [0. for _ in top_ns]\n",
|
460 |
-
"n = 0.\n",
|
461 |
-
"\n",
|
462 |
-
"for i, (images, target) in enumerate(tqdm(loader)):\n",
|
463 |
-
" images = images.permute(0, 2, 3, 1).numpy()\n",
|
464 |
-
" target = target.numpy()\n",
|
465 |
-
" # predict\n",
|
466 |
-
" image_features = model.get_image_features(images,)\n",
|
467 |
-
" image_features /= jnp.linalg.norm(image_features, axis=-1, keepdims=True)\n",
|
468 |
-
" logits = 100. * image_features @ zeroshot_weights\n",
|
469 |
-
"\n",
|
470 |
-
" # measure accuracy\n",
|
471 |
-
" accs = accuracy(logits, target, topk=top_ns)\n",
|
472 |
-
" for j in range(len(top_ns)):\n",
|
473 |
-
" acc_counters[j] += accs[j]\n",
|
474 |
-
" n += images.shape[0]\n",
|
475 |
-
"\n",
|
476 |
-
"tops = {f'top{top_ns[i]}': acc_counters[i] / n * 100 for i in range(len(top_ns))}\n",
|
477 |
-
"\n",
|
478 |
-
"print(tops)"
|
479 |
-
]
|
480 |
-
}
|
481 |
-
],
|
482 |
-
"metadata": {
|
483 |
-
"accelerator": "GPU",
|
484 |
-
"colab": {
|
485 |
-
"collapsed_sections": [],
|
486 |
-
"name": "Prompt Engineering for ImageNet.ipynb",
|
487 |
-
"provenance": []
|
488 |
-
},
|
489 |
-
"kernelspec": {
|
490 |
-
"display_name": "Python 3 (ipykernel)",
|
491 |
-
"language": "python",
|
492 |
-
"name": "python3"
|
493 |
-
},
|
494 |
-
"language_info": {
|
495 |
-
"codemirror_mode": {
|
496 |
-
"name": "ipython",
|
497 |
-
"version": 3
|
498 |
-
},
|
499 |
-
"file_extension": ".py",
|
500 |
-
"mimetype": "text/x-python",
|
501 |
-
"name": "python",
|
502 |
-
"nbconvert_exporter": "python",
|
503 |
-
"pygments_lexer": "ipython3",
|
504 |
-
"version": "3.8.10"
|
505 |
-
},
|
506 |
-
"widgets": {
|
507 |
-
"application/vnd.jupyter.widget-state+json": {
|
508 |
-
"029e6eadacb8480193aab52ff073be8f": {
|
509 |
-
"model_module": "@jupyter-widgets/base",
|
510 |
-
"model_name": "LayoutModel",
|
511 |
-
"state": {
|
512 |
-
"_model_module": "@jupyter-widgets/base",
|
513 |
-
"_model_module_version": "1.2.0",
|
514 |
-
"_model_name": "LayoutModel",
|
515 |
-
"_view_count": null,
|
516 |
-
"_view_module": "@jupyter-widgets/base",
|
517 |
-
"_view_module_version": "1.2.0",
|
518 |
-
"_view_name": "LayoutView",
|
519 |
-
"align_content": null,
|
520 |
-
"align_items": null,
|
521 |
-
"align_self": null,
|
522 |
-
"border": null,
|
523 |
-
"bottom": null,
|
524 |
-
"display": null,
|
525 |
-
"flex": null,
|
526 |
-
"flex_flow": null,
|
527 |
-
"grid_area": null,
|
528 |
-
"grid_auto_columns": null,
|
529 |
-
"grid_auto_flow": null,
|
530 |
-
"grid_auto_rows": null,
|
531 |
-
"grid_column": null,
|
532 |
-
"grid_gap": null,
|
533 |
-
"grid_row": null,
|
534 |
-
"grid_template_areas": null,
|
535 |
-
"grid_template_columns": null,
|
536 |
-
"grid_template_rows": null,
|
537 |
-
"height": null,
|
538 |
-
"justify_content": null,
|
539 |
-
"justify_items": null,
|
540 |
-
"left": null,
|
541 |
-
"margin": null,
|
542 |
-
"max_height": null,
|
543 |
-
"max_width": null,
|
544 |
-
"min_height": null,
|
545 |
-
"min_width": null,
|
546 |
-
"object_fit": null,
|
547 |
-
"object_position": null,
|
548 |
-
"order": null,
|
549 |
-
"overflow": null,
|
550 |
-
"overflow_x": null,
|
551 |
-
"overflow_y": null,
|
552 |
-
"padding": null,
|
553 |
-
"right": null,
|
554 |
-
"top": null,
|
555 |
-
"visibility": null,
|
556 |
-
"width": null
|
557 |
-
}
|
558 |
-
},
|
559 |
-
"075d6545e02e419ca565589eb5ffc318": {
|
560 |
-
"model_module": "@jupyter-widgets/base",
|
561 |
-
"model_name": "LayoutModel",
|
562 |
-
"state": {
|
563 |
-
"_model_module": "@jupyter-widgets/base",
|
564 |
-
"_model_module_version": "1.2.0",
|
565 |
-
"_model_name": "LayoutModel",
|
566 |
-
"_view_count": null,
|
567 |
-
"_view_module": "@jupyter-widgets/base",
|
568 |
-
"_view_module_version": "1.2.0",
|
569 |
-
"_view_name": "LayoutView",
|
570 |
-
"align_content": null,
|
571 |
-
"align_items": null,
|
572 |
-
"align_self": null,
|
573 |
-
"border": null,
|
574 |
-
"bottom": null,
|
575 |
-
"display": null,
|
576 |
-
"flex": null,
|
577 |
-
"flex_flow": null,
|
578 |
-
"grid_area": null,
|
579 |
-
"grid_auto_columns": null,
|
580 |
-
"grid_auto_flow": null,
|
581 |
-
"grid_auto_rows": null,
|
582 |
-
"grid_column": null,
|
583 |
-
"grid_gap": null,
|
584 |
-
"grid_row": null,
|
585 |
-
"grid_template_areas": null,
|
586 |
-
"grid_template_columns": null,
|
587 |
-
"grid_template_rows": null,
|
588 |
-
"height": null,
|
589 |
-
"justify_content": null,
|
590 |
-
"justify_items": null,
|
591 |
-
"left": null,
|
592 |
-
"margin": null,
|
593 |
-
"max_height": null,
|
594 |
-
"max_width": null,
|
595 |
-
"min_height": null,
|
596 |
-
"min_width": null,
|
597 |
-
"object_fit": null,
|
598 |
-
"object_position": null,
|
599 |
-
"order": null,
|
600 |
-
"overflow": null,
|
601 |
-
"overflow_x": null,
|
602 |
-
"overflow_y": null,
|
603 |
-
"padding": null,
|
604 |
-
"right": null,
|
605 |
-
"top": null,
|
606 |
-
"visibility": null,
|
607 |
-
"width": null
|
608 |
-
}
|
609 |
-
},
|
610 |
-
"0a1b6b76984349ccb36ca2fc4a4a0208": {
|
611 |
-
"model_module": "@jupyter-widgets/base",
|
612 |
-
"model_name": "LayoutModel",
|
613 |
-
"state": {
|
614 |
-
"_model_module": "@jupyter-widgets/base",
|
615 |
-
"_model_module_version": "1.2.0",
|
616 |
-
"_model_name": "LayoutModel",
|
617 |
-
"_view_count": null,
|
618 |
-
"_view_module": "@jupyter-widgets/base",
|
619 |
-
"_view_module_version": "1.2.0",
|
620 |
-
"_view_name": "LayoutView",
|
621 |
-
"align_content": null,
|
622 |
-
"align_items": null,
|
623 |
-
"align_self": null,
|
624 |
-
"border": null,
|
625 |
-
"bottom": null,
|
626 |
-
"display": null,
|
627 |
-
"flex": null,
|
628 |
-
"flex_flow": null,
|
629 |
-
"grid_area": null,
|
630 |
-
"grid_auto_columns": null,
|
631 |
-
"grid_auto_flow": null,
|
632 |
-
"grid_auto_rows": null,
|
633 |
-
"grid_column": null,
|
634 |
-
"grid_gap": null,
|
635 |
-
"grid_row": null,
|
636 |
-
"grid_template_areas": null,
|
637 |
-
"grid_template_columns": null,
|
638 |
-
"grid_template_rows": null,
|
639 |
-
"height": null,
|
640 |
-
"justify_content": null,
|
641 |
-
"justify_items": null,
|
642 |
-
"left": null,
|
643 |
-
"margin": null,
|
644 |
-
"max_height": null,
|
645 |
-
"max_width": null,
|
646 |
-
"min_height": null,
|
647 |
-
"min_width": null,
|
648 |
-
"object_fit": null,
|
649 |
-
"object_position": null,
|
650 |
-
"order": null,
|
651 |
-
"overflow": null,
|
652 |
-
"overflow_x": null,
|
653 |
-
"overflow_y": null,
|
654 |
-
"padding": null,
|
655 |
-
"right": null,
|
656 |
-
"top": null,
|
657 |
-
"visibility": null,
|
658 |
-
"width": null
|
659 |
-
}
|
660 |
-
},
|
661 |
-
"19c57d99e7c44cbda508ce558fde435d": {
|
662 |
-
"model_module": "@jupyter-widgets/base",
|
663 |
-
"model_name": "LayoutModel",
|
664 |
-
"state": {
|
665 |
-
"_model_module": "@jupyter-widgets/base",
|
666 |
-
"_model_module_version": "1.2.0",
|
667 |
-
"_model_name": "LayoutModel",
|
668 |
-
"_view_count": null,
|
669 |
-
"_view_module": "@jupyter-widgets/base",
|
670 |
-
"_view_module_version": "1.2.0",
|
671 |
-
"_view_name": "LayoutView",
|
672 |
-
"align_content": null,
|
673 |
-
"align_items": null,
|
674 |
-
"align_self": null,
|
675 |
-
"border": null,
|
676 |
-
"bottom": null,
|
677 |
-
"display": null,
|
678 |
-
"flex": null,
|
679 |
-
"flex_flow": null,
|
680 |
-
"grid_area": null,
|
681 |
-
"grid_auto_columns": null,
|
682 |
-
"grid_auto_flow": null,
|
683 |
-
"grid_auto_rows": null,
|
684 |
-
"grid_column": null,
|
685 |
-
"grid_gap": null,
|
686 |
-
"grid_row": null,
|
687 |
-
"grid_template_areas": null,
|
688 |
-
"grid_template_columns": null,
|
689 |
-
"grid_template_rows": null,
|
690 |
-
"height": null,
|
691 |
-
"justify_content": null,
|
692 |
-
"justify_items": null,
|
693 |
-
"left": null,
|
694 |
-
"margin": null,
|
695 |
-
"max_height": null,
|
696 |
-
"max_width": null,
|
697 |
-
"min_height": null,
|
698 |
-
"min_width": null,
|
699 |
-
"object_fit": null,
|
700 |
-
"object_position": null,
|
701 |
-
"order": null,
|
702 |
-
"overflow": null,
|
703 |
-
"overflow_x": null,
|
704 |
-
"overflow_y": null,
|
705 |
-
"padding": null,
|
706 |
-
"right": null,
|
707 |
-
"top": null,
|
708 |
-
"visibility": null,
|
709 |
-
"width": null
|
710 |
-
}
|
711 |
-
},
|
712 |
-
"285c877d4f644f3a8a58c4eb5948101c": {
|
713 |
-
"model_module": "@jupyter-widgets/controls",
|
714 |
-
"model_name": "ProgressStyleModel",
|
715 |
-
"state": {
|
716 |
-
"_model_module": "@jupyter-widgets/controls",
|
717 |
-
"_model_module_version": "1.5.0",
|
718 |
-
"_model_name": "ProgressStyleModel",
|
719 |
-
"_view_count": null,
|
720 |
-
"_view_module": "@jupyter-widgets/base",
|
721 |
-
"_view_module_version": "1.2.0",
|
722 |
-
"_view_name": "StyleView",
|
723 |
-
"bar_color": null,
|
724 |
-
"description_width": "initial"
|
725 |
-
}
|
726 |
-
},
|
727 |
-
"2e62544c03d64d6d92b94fcfaca2fc90": {
|
728 |
-
"model_module": "@jupyter-widgets/base",
|
729 |
-
"model_name": "LayoutModel",
|
730 |
-
"state": {
|
731 |
-
"_model_module": "@jupyter-widgets/base",
|
732 |
-
"_model_module_version": "1.2.0",
|
733 |
-
"_model_name": "LayoutModel",
|
734 |
-
"_view_count": null,
|
735 |
-
"_view_module": "@jupyter-widgets/base",
|
736 |
-
"_view_module_version": "1.2.0",
|
737 |
-
"_view_name": "LayoutView",
|
738 |
-
"align_content": null,
|
739 |
-
"align_items": null,
|
740 |
-
"align_self": null,
|
741 |
-
"border": null,
|
742 |
-
"bottom": null,
|
743 |
-
"display": null,
|
744 |
-
"flex": null,
|
745 |
-
"flex_flow": null,
|
746 |
-
"grid_area": null,
|
747 |
-
"grid_auto_columns": null,
|
748 |
-
"grid_auto_flow": null,
|
749 |
-
"grid_auto_rows": null,
|
750 |
-
"grid_column": null,
|
751 |
-
"grid_gap": null,
|
752 |
-
"grid_row": null,
|
753 |
-
"grid_template_areas": null,
|
754 |
-
"grid_template_columns": null,
|
755 |
-
"grid_template_rows": null,
|
756 |
-
"height": null,
|
757 |
-
"justify_content": null,
|
758 |
-
"justify_items": null,
|
759 |
-
"left": null,
|
760 |
-
"margin": null,
|
761 |
-
"max_height": null,
|
762 |
-
"max_width": null,
|
763 |
-
"min_height": null,
|
764 |
-
"min_width": null,
|
765 |
-
"object_fit": null,
|
766 |
-
"object_position": null,
|
767 |
-
"order": null,
|
768 |
-
"overflow": null,
|
769 |
-
"overflow_x": null,
|
770 |
-
"overflow_y": null,
|
771 |
-
"padding": null,
|
772 |
-
"right": null,
|
773 |
-
"top": null,
|
774 |
-
"visibility": null,
|
775 |
-
"width": null
|
776 |
-
}
|
777 |
-
},
|
778 |
-
"30178355f76742898d37966b3875ef0a": {
|
779 |
-
"model_module": "@jupyter-widgets/controls",
|
780 |
-
"model_name": "DescriptionStyleModel",
|
781 |
-
"state": {
|
782 |
-
"_model_module": "@jupyter-widgets/controls",
|
783 |
-
"_model_module_version": "1.5.0",
|
784 |
-
"_model_name": "DescriptionStyleModel",
|
785 |
-
"_view_count": null,
|
786 |
-
"_view_module": "@jupyter-widgets/base",
|
787 |
-
"_view_module_version": "1.2.0",
|
788 |
-
"_view_name": "StyleView",
|
789 |
-
"description_width": ""
|
790 |
-
}
|
791 |
-
},
|
792 |
-
"467a151e73744eccb199fe72aa352e5b": {
|
793 |
-
"model_module": "@jupyter-widgets/controls",
|
794 |
-
"model_name": "HTMLModel",
|
795 |
-
"state": {
|
796 |
-
"_dom_classes": [],
|
797 |
-
"_model_module": "@jupyter-widgets/controls",
|
798 |
-
"_model_module_version": "1.5.0",
|
799 |
-
"_model_name": "HTMLModel",
|
800 |
-
"_view_count": null,
|
801 |
-
"_view_module": "@jupyter-widgets/controls",
|
802 |
-
"_view_module_version": "1.5.0",
|
803 |
-
"_view_name": "HTMLView",
|
804 |
-
"description": "",
|
805 |
-
"description_tooltip": null,
|
806 |
-
"layout": "IPY_MODEL_2e62544c03d64d6d92b94fcfaca2fc90",
|
807 |
-
"placeholder": "",
|
808 |
-
"style": "IPY_MODEL_30178355f76742898d37966b3875ef0a",
|
809 |
-
"value": " 313/313 [01:26<00:00, 3.62it/s]"
|
810 |
-
}
|
811 |
-
},
|
812 |
-
"4e3a3f83649f45f8bef3434980634664": {
|
813 |
-
"model_module": "@jupyter-widgets/controls",
|
814 |
-
"model_name": "HBoxModel",
|
815 |
-
"state": {
|
816 |
-
"_dom_classes": [],
|
817 |
-
"_model_module": "@jupyter-widgets/controls",
|
818 |
-
"_model_module_version": "1.5.0",
|
819 |
-
"_model_name": "HBoxModel",
|
820 |
-
"_view_count": null,
|
821 |
-
"_view_module": "@jupyter-widgets/controls",
|
822 |
-
"_view_module_version": "1.5.0",
|
823 |
-
"_view_name": "HBoxView",
|
824 |
-
"box_style": "",
|
825 |
-
"children": [
|
826 |
-
"IPY_MODEL_4e7a7427d28a4ae684e0be4548eb9944",
|
827 |
-
"IPY_MODEL_cc9dc019c1334a46b2558ffa6c0dd6e6"
|
828 |
-
],
|
829 |
-
"layout": "IPY_MODEL_f066bdb766664c788ba1e9de8d311e22"
|
830 |
-
}
|
831 |
-
},
|
832 |
-
"4e7a7427d28a4ae684e0be4548eb9944": {
|
833 |
-
"model_module": "@jupyter-widgets/controls",
|
834 |
-
"model_name": "FloatProgressModel",
|
835 |
-
"state": {
|
836 |
-
"_dom_classes": [],
|
837 |
-
"_model_module": "@jupyter-widgets/controls",
|
838 |
-
"_model_module_version": "1.5.0",
|
839 |
-
"_model_name": "FloatProgressModel",
|
840 |
-
"_view_count": null,
|
841 |
-
"_view_module": "@jupyter-widgets/controls",
|
842 |
-
"_view_module_version": "1.5.0",
|
843 |
-
"_view_name": "ProgressView",
|
844 |
-
"bar_style": "success",
|
845 |
-
"description": "100%",
|
846 |
-
"description_tooltip": null,
|
847 |
-
"layout": "IPY_MODEL_075d6545e02e419ca565589eb5ffc318",
|
848 |
-
"max": 1000,
|
849 |
-
"min": 0,
|
850 |
-
"orientation": "horizontal",
|
851 |
-
"style": "IPY_MODEL_285c877d4f644f3a8a58c4eb5948101c",
|
852 |
-
"value": 1000
|
853 |
-
}
|
854 |
-
},
|
855 |
-
"53f9106c80e84d5b8c3ec96162d1db98": {
|
856 |
-
"model_module": "@jupyter-widgets/controls",
|
857 |
-
"model_name": "DescriptionStyleModel",
|
858 |
-
"state": {
|
859 |
-
"_model_module": "@jupyter-widgets/controls",
|
860 |
-
"_model_module_version": "1.5.0",
|
861 |
-
"_model_name": "DescriptionStyleModel",
|
862 |
-
"_view_count": null,
|
863 |
-
"_view_module": "@jupyter-widgets/base",
|
864 |
-
"_view_module_version": "1.2.0",
|
865 |
-
"_view_name": "StyleView",
|
866 |
-
"description_width": ""
|
867 |
-
}
|
868 |
-
},
|
869 |
-
"c136afb47aa14ac2832093ee415c6f3e": {
|
870 |
-
"model_module": "@jupyter-widgets/controls",
|
871 |
-
"model_name": "FloatProgressModel",
|
872 |
-
"state": {
|
873 |
-
"_dom_classes": [],
|
874 |
-
"_model_module": "@jupyter-widgets/controls",
|
875 |
-
"_model_module_version": "1.5.0",
|
876 |
-
"_model_name": "FloatProgressModel",
|
877 |
-
"_view_count": null,
|
878 |
-
"_view_module": "@jupyter-widgets/controls",
|
879 |
-
"_view_module_version": "1.5.0",
|
880 |
-
"_view_name": "ProgressView",
|
881 |
-
"bar_style": "success",
|
882 |
-
"description": "100%",
|
883 |
-
"description_tooltip": null,
|
884 |
-
"layout": "IPY_MODEL_029e6eadacb8480193aab52ff073be8f",
|
885 |
-
"max": 313,
|
886 |
-
"min": 0,
|
887 |
-
"orientation": "horizontal",
|
888 |
-
"style": "IPY_MODEL_f6d637c3fc3c46928d023441227130e5",
|
889 |
-
"value": 313
|
890 |
-
}
|
891 |
-
},
|
892 |
-
"cc9dc019c1334a46b2558ffa6c0dd6e6": {
|
893 |
-
"model_module": "@jupyter-widgets/controls",
|
894 |
-
"model_name": "HTMLModel",
|
895 |
-
"state": {
|
896 |
-
"_dom_classes": [],
|
897 |
-
"_model_module": "@jupyter-widgets/controls",
|
898 |
-
"_model_module_version": "1.5.0",
|
899 |
-
"_model_name": "HTMLModel",
|
900 |
-
"_view_count": null,
|
901 |
-
"_view_module": "@jupyter-widgets/controls",
|
902 |
-
"_view_module_version": "1.5.0",
|
903 |
-
"_view_name": "HTMLView",
|
904 |
-
"description": "",
|
905 |
-
"description_tooltip": null,
|
906 |
-
"layout": "IPY_MODEL_19c57d99e7c44cbda508ce558fde435d",
|
907 |
-
"placeholder": "",
|
908 |
-
"style": "IPY_MODEL_53f9106c80e84d5b8c3ec96162d1db98",
|
909 |
-
"value": " 1000/1000 [01:09<00:00, 14.35it/s]"
|
910 |
-
}
|
911 |
-
},
|
912 |
-
"f066bdb766664c788ba1e9de8d311e22": {
|
913 |
-
"model_module": "@jupyter-widgets/base",
|
914 |
-
"model_name": "LayoutModel",
|
915 |
-
"state": {
|
916 |
-
"_model_module": "@jupyter-widgets/base",
|
917 |
-
"_model_module_version": "1.2.0",
|
918 |
-
"_model_name": "LayoutModel",
|
919 |
-
"_view_count": null,
|
920 |
-
"_view_module": "@jupyter-widgets/base",
|
921 |
-
"_view_module_version": "1.2.0",
|
922 |
-
"_view_name": "LayoutView",
|
923 |
-
"align_content": null,
|
924 |
-
"align_items": null,
|
925 |
-
"align_self": null,
|
926 |
-
"border": null,
|
927 |
-
"bottom": null,
|
928 |
-
"display": null,
|
929 |
-
"flex": null,
|
930 |
-
"flex_flow": null,
|
931 |
-
"grid_area": null,
|
932 |
-
"grid_auto_columns": null,
|
933 |
-
"grid_auto_flow": null,
|
934 |
-
"grid_auto_rows": null,
|
935 |
-
"grid_column": null,
|
936 |
-
"grid_gap": null,
|
937 |
-
"grid_row": null,
|
938 |
-
"grid_template_areas": null,
|
939 |
-
"grid_template_columns": null,
|
940 |
-
"grid_template_rows": null,
|
941 |
-
"height": null,
|
942 |
-
"justify_content": null,
|
943 |
-
"justify_items": null,
|
944 |
-
"left": null,
|
945 |
-
"margin": null,
|
946 |
-
"max_height": null,
|
947 |
-
"max_width": null,
|
948 |
-
"min_height": null,
|
949 |
-
"min_width": null,
|
950 |
-
"object_fit": null,
|
951 |
-
"object_position": null,
|
952 |
-
"order": null,
|
953 |
-
"overflow": null,
|
954 |
-
"overflow_x": null,
|
955 |
-
"overflow_y": null,
|
956 |
-
"padding": null,
|
957 |
-
"right": null,
|
958 |
-
"top": null,
|
959 |
-
"visibility": null,
|
960 |
-
"width": null
|
961 |
-
}
|
962 |
-
},
|
963 |
-
"f6d637c3fc3c46928d023441227130e5": {
|
964 |
-
"model_module": "@jupyter-widgets/controls",
|
965 |
-
"model_name": "ProgressStyleModel",
|
966 |
-
"state": {
|
967 |
-
"_model_module": "@jupyter-widgets/controls",
|
968 |
-
"_model_module_version": "1.5.0",
|
969 |
-
"_model_name": "ProgressStyleModel",
|
970 |
-
"_view_count": null,
|
971 |
-
"_view_module": "@jupyter-widgets/base",
|
972 |
-
"_view_module_version": "1.2.0",
|
973 |
-
"_view_name": "StyleView",
|
974 |
-
"bar_color": null,
|
975 |
-
"description_width": "initial"
|
976 |
-
}
|
977 |
-
},
|
978 |
-
"fbb2b937b22049f5987f39f48c652a86": {
|
979 |
-
"model_module": "@jupyter-widgets/controls",
|
980 |
-
"model_name": "HBoxModel",
|
981 |
-
"state": {
|
982 |
-
"_dom_classes": [],
|
983 |
-
"_model_module": "@jupyter-widgets/controls",
|
984 |
-
"_model_module_version": "1.5.0",
|
985 |
-
"_model_name": "HBoxModel",
|
986 |
-
"_view_count": null,
|
987 |
-
"_view_module": "@jupyter-widgets/controls",
|
988 |
-
"_view_module_version": "1.5.0",
|
989 |
-
"_view_name": "HBoxView",
|
990 |
-
"box_style": "",
|
991 |
-
"children": [
|
992 |
-
"IPY_MODEL_c136afb47aa14ac2832093ee415c6f3e",
|
993 |
-
"IPY_MODEL_467a151e73744eccb199fe72aa352e5b"
|
994 |
-
],
|
995 |
-
"layout": "IPY_MODEL_0a1b6b76984349ccb36ca2fc4a4a0208"
|
996 |
-
}
|
997 |
-
}
|
998 |
-
}
|
999 |
-
}
|
1000 |
-
},
|
1001 |
-
"nbformat": 4,
|
1002 |
-
"nbformat_minor": 4
|
1003 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:32679e84a5c7c8148def9a503553b3dddab3529928c3d069b3146f7257e577fe
|
3 |
+
size 795766616
|
hybrid_clip/README.md
DELETED
@@ -1,172 +0,0 @@
|
|
1 |
-
<!---
|
2 |
-
Copyright 2021 The HuggingFace Team. All rights reserved.
|
3 |
-
|
4 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
you may not use this file except in compliance with the License.
|
6 |
-
You may obtain a copy of the License at
|
7 |
-
|
8 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
|
10 |
-
Unless required by applicable law or agreed to in writing, software
|
11 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
See the License for the specific language governing permissions and
|
14 |
-
limitations under the License.
|
15 |
-
-->
|
16 |
-
|
17 |
-
# Vision-Text dual encoder model training examples
|
18 |
-
|
19 |
-
> Note: This example is experimental and might not give the best possible results
|
20 |
-
|
21 |
-
The following example showcases how to train a CLIP like vision-text dual encoder model
|
22 |
-
using a pre-trained vision and text encoder using the JAX/Flax backend.
|
23 |
-
|
24 |
-
Such a model can be used for natural language image search and potentially zero-shot image classification.
|
25 |
-
The model is inspired by the [CLIP](https://openai.com/blog/clip/) approach, introduced by Alec Radford et al.
|
26 |
-
The idea is to train a vision encoder and a text encoder jointly to project the representation of images and their
|
27 |
-
captions into the same embedding space, such that the caption embeddings are located near the embeddings
|
28 |
-
of the images they describe.
|
29 |
-
|
30 |
-
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
31 |
-
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
32 |
-
way which enables simple and efficient model parallelism.
|
33 |
-
|
34 |
-
In this example we will use the vision model from [CLIP](https://huggingface.co/models?filter=clip)
|
35 |
-
as the image encoder and [`roberta-base`](https://huggingface.co/roberta-base) as the text encoder.
|
36 |
-
Note that one can also use the [ViT](https://huggingface.co/models?filter=vit) model as image encoder and any other BERT or ROBERTa model as text encoder.
|
37 |
-
To train the model on languages other than English one should choose a text encoder trained on the desired
|
38 |
-
language and a image-text dataset in that language. One such dataset is [WIT](https://github.com/google-research-datasets/wit).
|
39 |
-
|
40 |
-
Let's start by creating a model repository to save the trained model and logs.
|
41 |
-
Here we call the model `"clip-roberta-base"`, but you can change the model name as you like.
|
42 |
-
|
43 |
-
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
44 |
-
you are logged in) or via the command line:
|
45 |
-
|
46 |
-
```
|
47 |
-
huggingface-cli repo create clip-roberta-base
|
48 |
-
```
|
49 |
-
Next we clone the model repository to add the tokenizer and model files.
|
50 |
-
```
|
51 |
-
git clone https://huggingface.co/<your-username>/clip-roberta-base
|
52 |
-
```
|
53 |
-
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
54 |
-
track them. You can run the following command inside your model repo to do so.
|
55 |
-
|
56 |
-
```
|
57 |
-
cd clip-roberta-base
|
58 |
-
git lfs track "*tfevents*"
|
59 |
-
```
|
60 |
-
|
61 |
-
Great, we have set up our model repository. During training, we will automatically
|
62 |
-
push the training logs and model weights to the repo.
|
63 |
-
|
64 |
-
Next, let's add a symbolic link to the `run_hybrid_clip.py`.
|
65 |
-
|
66 |
-
```bash
|
67 |
-
export MODEL_DIR="./clip-roberta-base
|
68 |
-
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
|
69 |
-
```
|
70 |
-
|
71 |
-
## How to use the `FlaxHybridCLIP` model:
|
72 |
-
|
73 |
-
The `FlaxHybridCLIP` class let's you load any text and vision encoder model to create a dual encoder.
|
74 |
-
Here is an example of how to load the model using pre-trained text and vision models.
|
75 |
-
|
76 |
-
```python
|
77 |
-
from modeling_hybrid_clip import FlaxHybridCLIP
|
78 |
-
|
79 |
-
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
|
80 |
-
|
81 |
-
# save the model
|
82 |
-
model.save_pretrained("bert-clip")
|
83 |
-
|
84 |
-
# load the saved model
|
85 |
-
model = FlaxHybridCLIP.from_pretrained("bert-clip")
|
86 |
-
```
|
87 |
-
|
88 |
-
If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the model
|
89 |
-
PyTorch checkpoints convert them to flax and load the model.
|
90 |
-
|
91 |
-
```python
|
92 |
-
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
|
93 |
-
```
|
94 |
-
|
95 |
-
This loads both the text and vision encoders using pre-trained weights, the projection layers are randomly
|
96 |
-
initialized except for CLIP's vision model. If you use CLIP to initialize the vision model then the vision projection weights are also
|
97 |
-
loaded using the pre-trained weights.
|
98 |
-
|
99 |
-
## Prepare the dataset
|
100 |
-
|
101 |
-
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
|
102 |
-
|
103 |
-
### Download and extract the data.
|
104 |
-
|
105 |
-
It consists of two compressed folders: one with images, and the other—with associated image captions. Note that the compressed images folder is 13GB in size.
|
106 |
-
|
107 |
-
```bash
|
108 |
-
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
|
109 |
-
wget http://images.cocodataset.org/zips/train2014.zip
|
110 |
-
|
111 |
-
unzip annotations_trainval2014.zip
|
112 |
-
unzip train2014.zip
|
113 |
-
|
114 |
-
mkdir coco_dataset
|
115 |
-
mv train2014 coco_dataset/
|
116 |
-
mv annotations coco_dataset/
|
117 |
-
```
|
118 |
-
|
119 |
-
### Prepare dataset files and split the dataset.
|
120 |
-
|
121 |
-
```python
|
122 |
-
import json
|
123 |
-
import collections
|
124 |
-
|
125 |
-
images_dir = "coco_dataset/train2014"
|
126 |
-
annotation_file = "coco_dataset/annotations/captions_train2014.json"
|
127 |
-
with open(annotation_file, "r") as f:
|
128 |
-
annotations = json.load(f)["annotations"]
|
129 |
-
|
130 |
-
image_path_to_caption = collections.defaultdict(list)
|
131 |
-
for element in annotations:
|
132 |
-
caption = f"{element['caption'].lower().rstrip('.')}"
|
133 |
-
image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
|
134 |
-
image_path_to_caption[image_path].append(caption)
|
135 |
-
|
136 |
-
lines = []
|
137 |
-
for image_path, captions in image_path_to_caption.items():
|
138 |
-
lines.append(json.dumps({"image_path": image_path, "captions": captions}))
|
139 |
-
|
140 |
-
train_lines = lines[:-8000]
|
141 |
-
valid_line = lines[-8000:]
|
142 |
-
with open("coco_dataset/train_dataset.json", "w") as f:
|
143 |
-
f.write("\n".join(train_lines))
|
144 |
-
|
145 |
-
with open("coco_dataset/valid_dataset.json", "w") as f:
|
146 |
-
f.write("\n".join(valid_line))
|
147 |
-
```
|
148 |
-
|
149 |
-
> Note: The data loading and processing part of this script can still be improved for maximum performance. In particular one should decode the images beforehand and use those instead decoding them each time. If the dataset is small or if you have huge disk space the you could also pre-process all the dataset beforehand and then use it.
|
150 |
-
|
151 |
-
## Train the model
|
152 |
-
Next we can run the example script to train the model:
|
153 |
-
|
154 |
-
```bash
|
155 |
-
python run_hybrid_clip.py \
|
156 |
-
--output_dir ${MODEL_DIR} \
|
157 |
-
--text_model_name_or_path="roberta-base" \
|
158 |
-
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
159 |
-
--tokenizer_name="roberta-base" \
|
160 |
-
--train_file="coco_dataset/train_dataset.json" \
|
161 |
-
--validation_file="coco_dataset/validation_dataset.json" \
|
162 |
-
--do_train --do_eval \
|
163 |
-
--num_train_epochs="40" --max_seq_length 96 \
|
164 |
-
--per_device_train_batch_size="64" \
|
165 |
-
--per_device_eval_batch_size="64" \
|
166 |
-
--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
|
167 |
-
--overwrite_output_dir \
|
168 |
-
--preprocessing_num_workers 32 \
|
169 |
-
--push_to_hub
|
170 |
-
```
|
171 |
-
|
172 |
-
This should finish in ~1h50 mins with min validation loss 2.43. Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/RUNPYd1yRgSD5kZSb9hDig/#scalars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hybrid_clip/configuration_hybrid_clip.py
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
|
3 |
-
from transformers.configuration_utils import PretrainedConfig
|
4 |
-
from transformers.utils import logging
|
5 |
-
|
6 |
-
|
7 |
-
logger = logging.get_logger(__name__)
|
8 |
-
|
9 |
-
|
10 |
-
class HybridCLIPConfig(PretrainedConfig):
|
11 |
-
r"""
|
12 |
-
:class:`HybridCLIPConfig` is the configuration class to store the configuration of a
|
13 |
-
:class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
|
14 |
-
defining the text model and vision model configs.
|
15 |
-
|
16 |
-
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
17 |
-
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
18 |
-
|
19 |
-
Args:
|
20 |
-
text_config_dict (:obj:`dict`):
|
21 |
-
Dictionary of configuration options that defines text model config.
|
22 |
-
vision_config_dict (:obj:`dict`):
|
23 |
-
Dictionary of configuration options that defines vison model config.
|
24 |
-
projection_dim (:obj:`int`, `optional`, defaults to 512):
|
25 |
-
Dimentionality of text and vision projection layers.
|
26 |
-
kwargs (`optional`):
|
27 |
-
Dictionary of keyword arguments.
|
28 |
-
|
29 |
-
Examples::
|
30 |
-
|
31 |
-
>>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
|
32 |
-
|
33 |
-
>>> # Initializing a BERT and CLIP configuration
|
34 |
-
>>> config_text = BertConfig()
|
35 |
-
>>> config_vision = CLIPConfig()
|
36 |
-
|
37 |
-
>>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
|
38 |
-
|
39 |
-
>>> # Initializing a BERT and CLIPVision model
|
40 |
-
>>> model = EncoderDecoderModel(config=config)
|
41 |
-
|
42 |
-
>>> # Accessing the model configuration
|
43 |
-
>>> config_text = model.config.text_config
|
44 |
-
>>> config_vision = model.config.vision_config
|
45 |
-
|
46 |
-
>>> # Saving the model, including its configuration
|
47 |
-
>>> model.save_pretrained('my-model')
|
48 |
-
|
49 |
-
>>> # loading model and config from pretrained folder
|
50 |
-
>>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
|
51 |
-
>>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
|
52 |
-
"""
|
53 |
-
|
54 |
-
model_type = "hybrid-clip"
|
55 |
-
is_composition = True
|
56 |
-
|
57 |
-
def __init__(self, projection_dim=512, **kwargs):
|
58 |
-
super().__init__(**kwargs)
|
59 |
-
|
60 |
-
if "text_config" not in kwargs:
|
61 |
-
raise ValueError("`text_config` can not be `None`.")
|
62 |
-
|
63 |
-
if "vision_config" not in kwargs:
|
64 |
-
raise ValueError("`vision_config` can not be `None`.")
|
65 |
-
|
66 |
-
text_config = kwargs.pop("text_config")
|
67 |
-
vision_config = kwargs.pop("vision_config")
|
68 |
-
|
69 |
-
text_model_type = text_config.pop("model_type")
|
70 |
-
vision_model_type = vision_config.pop("model_type")
|
71 |
-
|
72 |
-
from transformers import AutoConfig
|
73 |
-
|
74 |
-
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
75 |
-
|
76 |
-
if vision_model_type == "clip":
|
77 |
-
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
78 |
-
else:
|
79 |
-
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
80 |
-
|
81 |
-
self.projection_dim = projection_dim
|
82 |
-
self.initializer_factor = 1.0
|
83 |
-
|
84 |
-
@classmethod
|
85 |
-
def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
|
86 |
-
r"""
|
87 |
-
Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
|
88 |
-
vision model configuration.
|
89 |
-
|
90 |
-
Returns:
|
91 |
-
:class:`HybridCLIPConfig`: An instance of a configuration object
|
92 |
-
"""
|
93 |
-
|
94 |
-
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
95 |
-
|
96 |
-
def to_dict(self):
|
97 |
-
"""
|
98 |
-
Serializes this instance to a Python dictionary. Override the default
|
99 |
-
:meth:`~transformers.PretrainedConfig.to_dict`.
|
100 |
-
|
101 |
-
Returns:
|
102 |
-
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
103 |
-
"""
|
104 |
-
output = copy.deepcopy(self.__dict__)
|
105 |
-
output["text_config"] = self.text_config.to_dict()
|
106 |
-
output["vision_config"] = self.vision_config.to_dict()
|
107 |
-
output["model_type"] = self.__class__.model_type
|
108 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hybrid_clip/modeling_hybrid_clip.py
DELETED
@@ -1,420 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
from typing import Optional, Tuple
|
17 |
-
|
18 |
-
import flax.linen as nn
|
19 |
-
import jax
|
20 |
-
import jax.numpy as jnp
|
21 |
-
from configuration_hybrid_clip import HybridCLIPConfig
|
22 |
-
from flax.core.frozen_dict import FrozenDict
|
23 |
-
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
|
24 |
-
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
25 |
-
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
|
26 |
-
from transformers.utils import logging
|
27 |
-
|
28 |
-
|
29 |
-
logger = logging.get_logger(__name__)
|
30 |
-
|
31 |
-
|
32 |
-
class FlaxHybridCLIPModule(nn.Module):
|
33 |
-
config: HybridCLIPConfig
|
34 |
-
dtype: jnp.dtype = jnp.float32
|
35 |
-
|
36 |
-
def setup(self):
|
37 |
-
text_config = self.config.text_config
|
38 |
-
vision_config = self.config.vision_config
|
39 |
-
|
40 |
-
self.projection_dim = self.config.projection_dim
|
41 |
-
self.text_embed_dim = text_config.hidden_size
|
42 |
-
self.vision_embed_dim = vision_config.hidden_size
|
43 |
-
|
44 |
-
text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
|
45 |
-
vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
|
46 |
-
|
47 |
-
self.text_model = text_module(text_config, dtype=self.dtype)
|
48 |
-
self.vision_model = vision_module(vision_config, dtype=self.dtype)
|
49 |
-
|
50 |
-
self.visual_projection = nn.Dense(
|
51 |
-
self.projection_dim,
|
52 |
-
dtype=self.dtype,
|
53 |
-
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
54 |
-
use_bias=False,
|
55 |
-
)
|
56 |
-
self.text_projection = nn.Dense(
|
57 |
-
self.projection_dim,
|
58 |
-
dtype=self.dtype,
|
59 |
-
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
60 |
-
use_bias=False,
|
61 |
-
)
|
62 |
-
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
|
63 |
-
|
64 |
-
def __call__(
|
65 |
-
self,
|
66 |
-
input_ids=None,
|
67 |
-
pixel_values=None,
|
68 |
-
attention_mask=None,
|
69 |
-
position_ids=None,
|
70 |
-
token_type_ids=None,
|
71 |
-
deterministic: bool = True,
|
72 |
-
output_attentions=None,
|
73 |
-
output_hidden_states=None,
|
74 |
-
return_dict=None,
|
75 |
-
):
|
76 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
77 |
-
|
78 |
-
vision_outputs = self.vision_model(
|
79 |
-
pixel_values=pixel_values,
|
80 |
-
deterministic=deterministic,
|
81 |
-
output_attentions=output_attentions,
|
82 |
-
output_hidden_states=output_hidden_states,
|
83 |
-
return_dict=return_dict,
|
84 |
-
)
|
85 |
-
|
86 |
-
text_outputs = self.text_model(
|
87 |
-
input_ids=input_ids,
|
88 |
-
attention_mask=attention_mask,
|
89 |
-
token_type_ids=token_type_ids,
|
90 |
-
position_ids=position_ids,
|
91 |
-
deterministic=deterministic,
|
92 |
-
output_attentions=output_attentions,
|
93 |
-
output_hidden_states=output_hidden_states,
|
94 |
-
return_dict=return_dict,
|
95 |
-
)
|
96 |
-
|
97 |
-
image_embeds = vision_outputs[1]
|
98 |
-
image_embeds = self.visual_projection(image_embeds)
|
99 |
-
|
100 |
-
text_embeds = text_outputs[1]
|
101 |
-
text_embeds = self.text_projection(text_embeds)
|
102 |
-
|
103 |
-
# normalized features
|
104 |
-
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
|
105 |
-
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
|
106 |
-
|
107 |
-
# cosine similarity as logits
|
108 |
-
logit_scale = jnp.exp(self.logit_scale)
|
109 |
-
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
|
110 |
-
logits_per_image = logits_per_text.T
|
111 |
-
|
112 |
-
if not return_dict:
|
113 |
-
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
114 |
-
|
115 |
-
return FlaxCLIPOutput(
|
116 |
-
logits_per_image=logits_per_image,
|
117 |
-
logits_per_text=logits_per_text,
|
118 |
-
text_embeds=text_embeds,
|
119 |
-
image_embeds=image_embeds,
|
120 |
-
text_model_output=text_outputs,
|
121 |
-
vision_model_output=vision_outputs,
|
122 |
-
)
|
123 |
-
|
124 |
-
|
125 |
-
class FlaxHybridCLIP(FlaxPreTrainedModel):
|
126 |
-
config_class = HybridCLIPConfig
|
127 |
-
module_class = FlaxHybridCLIPModule
|
128 |
-
|
129 |
-
def __init__(
|
130 |
-
self,
|
131 |
-
config: HybridCLIPConfig,
|
132 |
-
input_shape: Optional[Tuple] = None,
|
133 |
-
seed: int = 0,
|
134 |
-
dtype: jnp.dtype = jnp.float32,
|
135 |
-
**kwargs
|
136 |
-
):
|
137 |
-
if input_shape is None:
|
138 |
-
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
139 |
-
|
140 |
-
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
141 |
-
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
142 |
-
|
143 |
-
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
144 |
-
# init input tensor
|
145 |
-
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
146 |
-
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
147 |
-
token_type_ids = jnp.ones_like(input_ids)
|
148 |
-
attention_mask = jnp.ones_like(input_ids)
|
149 |
-
|
150 |
-
pixel_values = jax.random.normal(rng, input_shape[1])
|
151 |
-
|
152 |
-
params_rng, dropout_rng = jax.random.split(rng)
|
153 |
-
rngs = {"params": params_rng, "dropout": dropout_rng}
|
154 |
-
|
155 |
-
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
|
156 |
-
|
157 |
-
def __call__(
|
158 |
-
self,
|
159 |
-
input_ids,
|
160 |
-
pixel_values,
|
161 |
-
attention_mask=None,
|
162 |
-
position_ids=None,
|
163 |
-
token_type_ids=None,
|
164 |
-
params: dict = None,
|
165 |
-
dropout_rng: jax.random.PRNGKey = None,
|
166 |
-
train: bool = False,
|
167 |
-
output_attentions: Optional[bool] = None,
|
168 |
-
output_hidden_states: Optional[bool] = None,
|
169 |
-
return_dict: Optional[bool] = None,
|
170 |
-
):
|
171 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
172 |
-
output_hidden_states = (
|
173 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
174 |
-
)
|
175 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
176 |
-
|
177 |
-
if position_ids is None:
|
178 |
-
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
179 |
-
|
180 |
-
if token_type_ids is None:
|
181 |
-
token_type_ids = jnp.zeros_like(input_ids)
|
182 |
-
|
183 |
-
if attention_mask is None:
|
184 |
-
attention_mask = jnp.ones_like(input_ids)
|
185 |
-
|
186 |
-
# Handle any PRNG if needed
|
187 |
-
rngs = {}
|
188 |
-
if dropout_rng is not None:
|
189 |
-
rngs["dropout"] = dropout_rng
|
190 |
-
|
191 |
-
return self.module.apply(
|
192 |
-
{"params": params or self.params},
|
193 |
-
jnp.array(input_ids, dtype="i4"),
|
194 |
-
jnp.array(pixel_values, dtype=jnp.float32),
|
195 |
-
jnp.array(attention_mask, dtype="i4"),
|
196 |
-
jnp.array(position_ids, dtype="i4"),
|
197 |
-
jnp.array(token_type_ids, dtype="i4"),
|
198 |
-
not train,
|
199 |
-
output_attentions,
|
200 |
-
output_hidden_states,
|
201 |
-
return_dict,
|
202 |
-
rngs=rngs,
|
203 |
-
)
|
204 |
-
|
205 |
-
def get_text_features(
|
206 |
-
self,
|
207 |
-
input_ids,
|
208 |
-
attention_mask=None,
|
209 |
-
position_ids=None,
|
210 |
-
token_type_ids=None,
|
211 |
-
dropout_rng: jax.random.PRNGKey = None,
|
212 |
-
train=False,
|
213 |
-
):
|
214 |
-
r"""
|
215 |
-
Args:
|
216 |
-
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
217 |
-
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
218 |
-
provide it.
|
219 |
-
|
220 |
-
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
221 |
-
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
222 |
-
for details.
|
223 |
-
|
224 |
-
`What are input IDs? <../glossary.html#input-ids>`__
|
225 |
-
|
226 |
-
Returns:
|
227 |
-
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
|
228 |
-
obtained by applying the projection layer to the pooled output of text model.
|
229 |
-
"""
|
230 |
-
if position_ids is None:
|
231 |
-
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
232 |
-
|
233 |
-
if token_type_ids is None:
|
234 |
-
token_type_ids = jnp.zeros_like(input_ids)
|
235 |
-
|
236 |
-
if attention_mask is None:
|
237 |
-
attention_mask = jnp.ones_like(input_ids)
|
238 |
-
|
239 |
-
# Handle any PRNG if needed
|
240 |
-
rngs = {}
|
241 |
-
if dropout_rng is not None:
|
242 |
-
rngs["dropout"] = dropout_rng
|
243 |
-
|
244 |
-
def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
|
245 |
-
text_outputs = module.text_model(
|
246 |
-
input_ids=input_ids,
|
247 |
-
attention_mask=attention_mask,
|
248 |
-
position_ids=position_ids,
|
249 |
-
token_type_ids=token_type_ids,
|
250 |
-
deterministic=deterministic,
|
251 |
-
)
|
252 |
-
pooled_output = text_outputs[1]
|
253 |
-
text_features = module.text_projection(pooled_output)
|
254 |
-
return text_features
|
255 |
-
|
256 |
-
return self.module.apply(
|
257 |
-
{"params": self.params},
|
258 |
-
jnp.array(input_ids, dtype="i4"),
|
259 |
-
jnp.array(attention_mask, dtype="i4"),
|
260 |
-
jnp.array(position_ids, dtype="i4"),
|
261 |
-
jnp.array(token_type_ids, dtype="i4"),
|
262 |
-
not train,
|
263 |
-
method=_get_features,
|
264 |
-
rngs=rngs,
|
265 |
-
)
|
266 |
-
|
267 |
-
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
|
268 |
-
r"""
|
269 |
-
Args:
|
270 |
-
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
271 |
-
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
|
272 |
-
using :class:`~transformers.ImageFeatureExtractionMixin`. See
|
273 |
-
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
274 |
-
|
275 |
-
Returns:
|
276 |
-
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
|
277 |
-
obtained by applying the projection layer to the pooled output of vision model.
|
278 |
-
"""
|
279 |
-
|
280 |
-
# Handle any PRNG if needed
|
281 |
-
rngs = {}
|
282 |
-
if dropout_rng is not None:
|
283 |
-
rngs["dropout"] = dropout_rng
|
284 |
-
|
285 |
-
def _get_features(module, pixel_values, deterministic):
|
286 |
-
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
|
287 |
-
pooled_output = vision_outputs[1] # pooled_output
|
288 |
-
image_features = module.visual_projection(pooled_output)
|
289 |
-
return image_features
|
290 |
-
|
291 |
-
return self.module.apply(
|
292 |
-
{"params": self.params},
|
293 |
-
jnp.array(pixel_values, dtype=jnp.float32),
|
294 |
-
not train,
|
295 |
-
method=_get_features,
|
296 |
-
rngs=rngs,
|
297 |
-
)
|
298 |
-
|
299 |
-
@classmethod
|
300 |
-
def from_text_vision_pretrained(
|
301 |
-
cls,
|
302 |
-
text_model_name_or_path: str = None,
|
303 |
-
vision_model_name_or_path: str = None,
|
304 |
-
*model_args,
|
305 |
-
**kwargs,
|
306 |
-
) -> FlaxPreTrainedModel:
|
307 |
-
"""
|
308 |
-
Params:
|
309 |
-
text_model_name_or_path (:obj: `str`, `optional`):
|
310 |
-
Information necessary to initiate the text model. Can be either:
|
311 |
-
|
312 |
-
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
313 |
-
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
314 |
-
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
315 |
-
- A path to a `directory` containing model weights saved using
|
316 |
-
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
317 |
-
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
318 |
-
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
319 |
-
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
320 |
-
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
321 |
-
|
322 |
-
vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
323 |
-
Information necessary to initiate the vision model. Can be either:
|
324 |
-
|
325 |
-
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
326 |
-
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
327 |
-
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
328 |
-
- A path to a `directory` containing model weights saved using
|
329 |
-
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
330 |
-
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
331 |
-
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
332 |
-
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
333 |
-
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
334 |
-
|
335 |
-
model_args (remaining positional arguments, `optional`):
|
336 |
-
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
337 |
-
|
338 |
-
kwargs (remaining dictionary of keyword arguments, `optional`):
|
339 |
-
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
340 |
-
:obj:`output_attentions=True`).
|
341 |
-
|
342 |
-
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
343 |
-
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
344 |
-
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
345 |
-
|
346 |
-
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
347 |
-
|
348 |
-
Example::
|
349 |
-
|
350 |
-
>>> from transformers import FlaxHybridCLIP
|
351 |
-
>>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
|
352 |
-
>>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
|
353 |
-
>>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
|
354 |
-
>>> # saving model after fine-tuning
|
355 |
-
>>> model.save_pretrained("./bert-clip")
|
356 |
-
>>> # load fine-tuned model
|
357 |
-
>>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
|
358 |
-
"""
|
359 |
-
|
360 |
-
kwargs_text = {
|
361 |
-
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
362 |
-
}
|
363 |
-
|
364 |
-
kwargs_vision = {
|
365 |
-
argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
|
366 |
-
}
|
367 |
-
|
368 |
-
# remove text, vision kwargs from kwargs
|
369 |
-
for key in kwargs_text.keys():
|
370 |
-
del kwargs["text_" + key]
|
371 |
-
for key in kwargs_vision.keys():
|
372 |
-
del kwargs["vision_" + key]
|
373 |
-
|
374 |
-
# Load and initialize the text and vision model
|
375 |
-
text_model = kwargs_text.pop("model", None)
|
376 |
-
if text_model is None:
|
377 |
-
assert (
|
378 |
-
text_model_name_or_path is not None
|
379 |
-
), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
|
380 |
-
from transformers import FlaxAutoModel
|
381 |
-
|
382 |
-
if "config" not in kwargs_text:
|
383 |
-
from transformers import AutoConfig
|
384 |
-
|
385 |
-
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
386 |
-
kwargs_text["config"] = text_config
|
387 |
-
|
388 |
-
text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
389 |
-
|
390 |
-
vision_model = kwargs_vision.pop("model", None)
|
391 |
-
if vision_model is None:
|
392 |
-
assert (
|
393 |
-
vision_model_name_or_path is not None
|
394 |
-
), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
|
395 |
-
from transformers import FlaxAutoModel
|
396 |
-
|
397 |
-
if "config" not in kwargs_vision:
|
398 |
-
from transformers import AutoConfig
|
399 |
-
|
400 |
-
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
|
401 |
-
kwargs_vision["config"] = vision_config
|
402 |
-
|
403 |
-
vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
|
404 |
-
|
405 |
-
# instantiate config with corresponding kwargs
|
406 |
-
dtype = kwargs.pop("dtype", jnp.float32)
|
407 |
-
config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
|
408 |
-
|
409 |
-
# init model
|
410 |
-
model = cls(config, *model_args, dtype=dtype, **kwargs)
|
411 |
-
|
412 |
-
if vision_config.model_type == "clip":
|
413 |
-
model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
|
414 |
-
model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
|
415 |
-
else:
|
416 |
-
model.params["vision_model"] = vision_model.params
|
417 |
-
|
418 |
-
model.params["text_model"] = text_model.params
|
419 |
-
|
420 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hybrid_clip/requirements.txt
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
jax>=0.2.8
|
2 |
-
jaxlib>=0.1.59
|
3 |
-
flax>=0.3.4
|
4 |
-
optax>=0.0.8
|
5 |
-
-f https://download.pytorch.org/whl/torch_stable.html
|
6 |
-
torch==1.9.0+cpu
|
7 |
-
-f https://download.pytorch.org/whl/torch_stable.html
|
8 |
-
torchvision==0.10.0+cpu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hybrid_clip/run_hybrid_clip.py
DELETED
@@ -1,562 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# coding=utf-8
|
3 |
-
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
"""
|
17 |
-
Training a CLIP like dual encoder models using text and vision encoders in the library.
|
18 |
-
|
19 |
-
The script can be used to train CLIP like models for languages other than english by using
|
20 |
-
a text encoder pre-trained in the desired language. Currently this script support the following vision
|
21 |
-
and text models:
|
22 |
-
Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
|
23 |
-
Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
|
24 |
-
"""
|
25 |
-
|
26 |
-
import json
|
27 |
-
import logging
|
28 |
-
import os
|
29 |
-
import sys
|
30 |
-
import time
|
31 |
-
from dataclasses import dataclass, field
|
32 |
-
from pathlib import Path
|
33 |
-
from typing import Callable, Optional
|
34 |
-
|
35 |
-
import torch
|
36 |
-
from torchvision.datasets import VisionDataset
|
37 |
-
from torchvision.io import ImageReadMode, read_image
|
38 |
-
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
39 |
-
from torchvision.transforms.functional import InterpolationMode
|
40 |
-
from tqdm import tqdm
|
41 |
-
|
42 |
-
import jax
|
43 |
-
import jax.numpy as jnp
|
44 |
-
import optax
|
45 |
-
import transformers
|
46 |
-
from flax import jax_utils
|
47 |
-
from flax.jax_utils import unreplicate
|
48 |
-
from flax.training import train_state
|
49 |
-
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
50 |
-
from modeling_hybrid_clip import FlaxHybridCLIP
|
51 |
-
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
|
52 |
-
|
53 |
-
|
54 |
-
logger = logging.getLogger(__name__)
|
55 |
-
|
56 |
-
# Cache the result
|
57 |
-
has_tensorboard = is_tensorboard_available()
|
58 |
-
if has_tensorboard:
|
59 |
-
try:
|
60 |
-
from flax.metrics.tensorboard import SummaryWriter
|
61 |
-
except ImportError as ie:
|
62 |
-
has_tensorboard = False
|
63 |
-
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
64 |
-
|
65 |
-
else:
|
66 |
-
print(
|
67 |
-
"Unable to display metrics through TensorBoard because the package is not installed: "
|
68 |
-
"Please run pip install tensorboard to enable."
|
69 |
-
)
|
70 |
-
|
71 |
-
|
72 |
-
@dataclass
|
73 |
-
class ModelArguments:
|
74 |
-
"""
|
75 |
-
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
76 |
-
"""
|
77 |
-
|
78 |
-
text_model_name_or_path: str = field(
|
79 |
-
metadata={
|
80 |
-
"help": "The text model checkpoint for weights initialization."
|
81 |
-
"Don't set if you want to train a model from scratch."
|
82 |
-
},
|
83 |
-
)
|
84 |
-
vision_model_name_or_path: str = field(
|
85 |
-
metadata={
|
86 |
-
"help": "The vision model checkpoint for weights initialization."
|
87 |
-
"Don't set if you want to train a model from scratch."
|
88 |
-
},
|
89 |
-
)
|
90 |
-
from_pt: bool = field(
|
91 |
-
default=True,
|
92 |
-
metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
|
93 |
-
)
|
94 |
-
config_name: Optional[str] = field(
|
95 |
-
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
96 |
-
)
|
97 |
-
tokenizer_name: Optional[str] = field(
|
98 |
-
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
99 |
-
)
|
100 |
-
cache_dir: Optional[str] = field(
|
101 |
-
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
102 |
-
)
|
103 |
-
use_fast_tokenizer: bool = field(
|
104 |
-
default=True,
|
105 |
-
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
106 |
-
)
|
107 |
-
dtype: Optional[str] = field(
|
108 |
-
default="float32",
|
109 |
-
metadata={
|
110 |
-
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
111 |
-
},
|
112 |
-
)
|
113 |
-
|
114 |
-
|
115 |
-
@dataclass
|
116 |
-
class DataTrainingArguments:
|
117 |
-
"""
|
118 |
-
Arguments pertaining to what data we are going to input our model for training and eval.
|
119 |
-
"""
|
120 |
-
|
121 |
-
data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
|
122 |
-
train_file: Optional[str] = field(
|
123 |
-
default=None, metadata={"help": "The input training data file (a jsonlines file)."}
|
124 |
-
)
|
125 |
-
validation_file: Optional[str] = field(
|
126 |
-
default=None,
|
127 |
-
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
|
128 |
-
)
|
129 |
-
max_seq_length: Optional[int] = field(
|
130 |
-
default=72,
|
131 |
-
metadata={
|
132 |
-
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
133 |
-
"than this will be truncated, sequences shorter will be padded."
|
134 |
-
},
|
135 |
-
)
|
136 |
-
max_train_samples: Optional[int] = field(
|
137 |
-
default=None,
|
138 |
-
metadata={
|
139 |
-
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
140 |
-
"value if set."
|
141 |
-
},
|
142 |
-
)
|
143 |
-
max_eval_samples: Optional[int] = field(
|
144 |
-
default=None,
|
145 |
-
metadata={
|
146 |
-
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
147 |
-
"value if set."
|
148 |
-
},
|
149 |
-
)
|
150 |
-
overwrite_cache: bool = field(
|
151 |
-
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
152 |
-
)
|
153 |
-
overwrite_cache: bool = field(
|
154 |
-
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
155 |
-
)
|
156 |
-
preprocessing_num_workers: Optional[int] = field(
|
157 |
-
default=None,
|
158 |
-
metadata={"help": "The number of processes to use for the preprocessing."},
|
159 |
-
)
|
160 |
-
|
161 |
-
def __post_init__(self):
|
162 |
-
if self.train_file is None and self.validation_file is None:
|
163 |
-
raise ValueError("Need either a dataset name or a training/validation file.")
|
164 |
-
else:
|
165 |
-
if self.train_file is not None:
|
166 |
-
extension = self.train_file.split(".")[-1]
|
167 |
-
assert extension == "json", "`train_file` should be a json file."
|
168 |
-
if self.validation_file is not None:
|
169 |
-
extension = self.validation_file.split(".")[-1]
|
170 |
-
assert extension == "json", "`validation_file` should be a json file."
|
171 |
-
|
172 |
-
|
173 |
-
# We use torchvision for faster image pre-processing.
|
174 |
-
# We need to ensure faster processing speed as it can become a bottleneck on TPU
|
175 |
-
class Transform(torch.nn.Module):
|
176 |
-
def __init__(self, image_size):
|
177 |
-
super().__init__()
|
178 |
-
self.transforms = torch.nn.Sequential(
|
179 |
-
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
|
180 |
-
CenterCrop(image_size),
|
181 |
-
ConvertImageDtype(torch.float),
|
182 |
-
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
183 |
-
)
|
184 |
-
|
185 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
186 |
-
with torch.no_grad():
|
187 |
-
x = self.transforms(x)
|
188 |
-
return x
|
189 |
-
|
190 |
-
|
191 |
-
class ImageTextDataset(VisionDataset):
|
192 |
-
"""
|
193 |
-
Dtaset for loading image-text data for tasks like CLIP training, Image Captioning.
|
194 |
-
|
195 |
-
Args:
|
196 |
-
root: (string): The root path where the dataset is stored
|
197 |
-
file_path: (string): Path to the file containing the image_paths and associated captions.
|
198 |
-
The expected format is jsonlines where each line is a json object containing to keys.
|
199 |
-
`image_path`: The path to the image.
|
200 |
-
`captions`: An `array` of captions.
|
201 |
-
transform (callable, optional): A function/transform that takes in an PIL image
|
202 |
-
and returns a transformed version. E.g, ``transforms.ToTensor``
|
203 |
-
target_transform (callable, optional): A function/transform that takes in the
|
204 |
-
target and transforms it.
|
205 |
-
transforms (callable, optional): A function/transform that takes input sample and its target as entry
|
206 |
-
and returns a transformed version.
|
207 |
-
"""
|
208 |
-
|
209 |
-
def __init__(
|
210 |
-
self,
|
211 |
-
root: str,
|
212 |
-
file_path: str,
|
213 |
-
captions_per_image=2,
|
214 |
-
transform: Optional[Callable] = None,
|
215 |
-
target_transform: Optional[Callable] = None,
|
216 |
-
transforms: Optional[Callable] = None,
|
217 |
-
):
|
218 |
-
super().__init__(root, transforms, transform, target_transform)
|
219 |
-
|
220 |
-
with open(file_path, "r") as f:
|
221 |
-
examples = [json.loads(line) for line in f.readlines()]
|
222 |
-
|
223 |
-
self.captions = []
|
224 |
-
self.image_paths = []
|
225 |
-
|
226 |
-
for example in examples:
|
227 |
-
self.captions.extend(example["captions"][:captions_per_image])
|
228 |
-
self.image_paths.extend([example["image_path"]] * captions_per_image)
|
229 |
-
|
230 |
-
def _load_image(self, idx: int):
|
231 |
-
path = self.image_paths[idx]
|
232 |
-
return read_image(path, mode=ImageReadMode.RGB)
|
233 |
-
|
234 |
-
def _load_target(self, idx):
|
235 |
-
return self.captions[idx]
|
236 |
-
|
237 |
-
def __getitem__(self, index: int):
|
238 |
-
image = self._load_image(index)
|
239 |
-
target = self._load_target(index)
|
240 |
-
|
241 |
-
if self.transforms is not None:
|
242 |
-
image, target = self.transforms(image, target)
|
243 |
-
|
244 |
-
return image, target
|
245 |
-
|
246 |
-
def __len__(self) -> int:
|
247 |
-
return len(self.captions)
|
248 |
-
|
249 |
-
|
250 |
-
class TrainState(train_state.TrainState):
|
251 |
-
dropout_rng: jnp.ndarray
|
252 |
-
|
253 |
-
def replicate(self):
|
254 |
-
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
255 |
-
|
256 |
-
|
257 |
-
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
258 |
-
summary_writer.scalar("train_time", train_time, step)
|
259 |
-
|
260 |
-
train_metrics = get_metrics(train_metrics)
|
261 |
-
for key, vals in train_metrics.items():
|
262 |
-
tag = f"train_{key}"
|
263 |
-
for i, val in enumerate(vals):
|
264 |
-
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
265 |
-
|
266 |
-
for metric_name, value in eval_metrics.items():
|
267 |
-
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
268 |
-
|
269 |
-
|
270 |
-
def create_learning_rate_fn(
|
271 |
-
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
272 |
-
) -> Callable[[int], jnp.array]:
|
273 |
-
"""Returns a linear warmup, linear_decay learning rate function."""
|
274 |
-
steps_per_epoch = train_ds_size // train_batch_size
|
275 |
-
num_train_steps = steps_per_epoch * num_train_epochs
|
276 |
-
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
277 |
-
decay_fn = optax.linear_schedule(
|
278 |
-
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
279 |
-
)
|
280 |
-
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
281 |
-
return schedule_fn
|
282 |
-
|
283 |
-
|
284 |
-
def main():
|
285 |
-
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
286 |
-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
287 |
-
# If we pass only one argument to the script and it's the path to a json file,
|
288 |
-
# let's parse it to get our arguments.
|
289 |
-
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
290 |
-
else:
|
291 |
-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
292 |
-
|
293 |
-
if (
|
294 |
-
os.path.exists(training_args.output_dir)
|
295 |
-
and os.listdir(training_args.output_dir)
|
296 |
-
and training_args.do_train
|
297 |
-
and not training_args.overwrite_output_dir
|
298 |
-
):
|
299 |
-
raise ValueError(
|
300 |
-
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
301 |
-
"Use --overwrite_output_dir to overcome."
|
302 |
-
)
|
303 |
-
|
304 |
-
# Make one log on every process with the configuration for debugging.
|
305 |
-
logging.basicConfig(
|
306 |
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
307 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
308 |
-
level=logging.INFO,
|
309 |
-
)
|
310 |
-
# Setup logging, we only want one process per machine to log things on the screen.
|
311 |
-
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
312 |
-
if jax.process_index() == 0:
|
313 |
-
transformers.utils.logging.set_verbosity_info()
|
314 |
-
else:
|
315 |
-
transformers.utils.logging.set_verbosity_error()
|
316 |
-
|
317 |
-
# Set the verbosity to info of the Transformers logger (on main process only):
|
318 |
-
logger.info(f"Training/evaluation parameters {training_args}")
|
319 |
-
|
320 |
-
if model_args.tokenizer_name:
|
321 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
322 |
-
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
323 |
-
)
|
324 |
-
elif model_args.text_model_name_or_path:
|
325 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
326 |
-
model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
327 |
-
)
|
328 |
-
else:
|
329 |
-
raise ValueError(
|
330 |
-
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
331 |
-
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
332 |
-
)
|
333 |
-
|
334 |
-
model = FlaxHybridCLIP.from_text_vision_pretrained(
|
335 |
-
model_args.text_model_name_or_path,
|
336 |
-
model_args.vision_model_name_or_path,
|
337 |
-
seed=training_args.seed,
|
338 |
-
dtype=getattr(jnp, model_args.dtype),
|
339 |
-
text_from_pt=model_args.from_pt,
|
340 |
-
vision_from_pt=model_args.from_pt,
|
341 |
-
)
|
342 |
-
config = model.config
|
343 |
-
# set seed for torch dataloaders
|
344 |
-
set_seed(training_args.seed)
|
345 |
-
|
346 |
-
# Initialize torchvision transforms and jit them for faster processing
|
347 |
-
preprocess = Transform(config.vision_config.image_size)
|
348 |
-
preprocess = torch.jit.script(preprocess)
|
349 |
-
|
350 |
-
# Initialize the image-text dataset
|
351 |
-
train_dataset = ImageTextDataset(
|
352 |
-
data_args.data_dir,
|
353 |
-
data_args.train_file,
|
354 |
-
captions_per_image=2,
|
355 |
-
transform=preprocess,
|
356 |
-
)
|
357 |
-
|
358 |
-
eval_dataset = ImageTextDataset(
|
359 |
-
data_args.data_dir,
|
360 |
-
data_args.validation_file,
|
361 |
-
captions_per_image=1,
|
362 |
-
transform=preprocess,
|
363 |
-
)
|
364 |
-
|
365 |
-
# Store some constant
|
366 |
-
num_epochs = int(training_args.num_train_epochs)
|
367 |
-
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
368 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
369 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
370 |
-
total_train_steps = steps_per_epoch * num_epochs
|
371 |
-
|
372 |
-
# Use collate function to tokenizer the text and convert the processed images to numpy
|
373 |
-
def collate_fn(examples):
|
374 |
-
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
|
375 |
-
captions = [example[1] for example in examples]
|
376 |
-
inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
|
377 |
-
|
378 |
-
batch = {
|
379 |
-
"pixel_values": pixel_values,
|
380 |
-
"input_ids": inputs["input_ids"],
|
381 |
-
"attention_mask": inputs["attention_mask"],
|
382 |
-
}
|
383 |
-
|
384 |
-
return batch
|
385 |
-
|
386 |
-
# Create data loaders
|
387 |
-
train_loader = torch.utils.data.DataLoader(
|
388 |
-
train_dataset,
|
389 |
-
batch_size=train_batch_size,
|
390 |
-
shuffle=True,
|
391 |
-
num_workers=data_args.preprocessing_num_workers,
|
392 |
-
persistent_workers=True,
|
393 |
-
drop_last=True,
|
394 |
-
collate_fn=collate_fn,
|
395 |
-
)
|
396 |
-
|
397 |
-
eval_loader = torch.utils.data.DataLoader(
|
398 |
-
eval_dataset,
|
399 |
-
batch_size=eval_batch_size,
|
400 |
-
shuffle=False,
|
401 |
-
num_workers=data_args.preprocessing_num_workers,
|
402 |
-
persistent_workers=True,
|
403 |
-
drop_last=True,
|
404 |
-
collate_fn=collate_fn,
|
405 |
-
)
|
406 |
-
|
407 |
-
# Enable tensorboard only on the master node
|
408 |
-
if has_tensorboard and jax.process_index() == 0:
|
409 |
-
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
410 |
-
|
411 |
-
# Initialize our training
|
412 |
-
rng = jax.random.PRNGKey(training_args.seed)
|
413 |
-
rng, dropout_rng = jax.random.split(rng)
|
414 |
-
|
415 |
-
# Create learning rate schedule
|
416 |
-
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
417 |
-
len(train_dataset),
|
418 |
-
train_batch_size,
|
419 |
-
training_args.num_train_epochs,
|
420 |
-
training_args.warmup_steps,
|
421 |
-
training_args.learning_rate,
|
422 |
-
)
|
423 |
-
|
424 |
-
# create adam optimizer
|
425 |
-
adamw = optax.adamw(
|
426 |
-
learning_rate=linear_decay_lr_schedule_fn,
|
427 |
-
b1=training_args.adam_beta1,
|
428 |
-
b2=training_args.adam_beta2,
|
429 |
-
eps=training_args.adam_epsilon,
|
430 |
-
weight_decay=training_args.weight_decay,
|
431 |
-
)
|
432 |
-
|
433 |
-
# Setup train state
|
434 |
-
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
435 |
-
|
436 |
-
def cross_entropy(logits, axis):
|
437 |
-
logprobs = jax.nn.log_softmax(logits, axis=axis)
|
438 |
-
nll = jnp.diag(logprobs)
|
439 |
-
ce = -jnp.mean(nll)
|
440 |
-
return ce
|
441 |
-
|
442 |
-
def clip_loss(similarity):
|
443 |
-
loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
|
444 |
-
return loss
|
445 |
-
|
446 |
-
# Define gradient update step fn
|
447 |
-
def train_step(state, batch):
|
448 |
-
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
449 |
-
|
450 |
-
def compute_loss(params):
|
451 |
-
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
452 |
-
loss = clip_loss(logits)
|
453 |
-
return loss
|
454 |
-
|
455 |
-
grad_fn = jax.value_and_grad(compute_loss)
|
456 |
-
loss, grad = grad_fn(state.params)
|
457 |
-
grad = jax.lax.pmean(grad, "batch")
|
458 |
-
|
459 |
-
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
460 |
-
|
461 |
-
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
462 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
463 |
-
|
464 |
-
return new_state, metrics
|
465 |
-
|
466 |
-
# Define eval fn
|
467 |
-
def eval_step(params, batch):
|
468 |
-
logits = model(**batch, params=params, train=False)[0]
|
469 |
-
loss = clip_loss(logits)
|
470 |
-
|
471 |
-
# summarize metrics
|
472 |
-
metrics = {"loss": loss}
|
473 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
474 |
-
return metrics
|
475 |
-
|
476 |
-
# Create parallel version of the train and eval step
|
477 |
-
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
478 |
-
p_eval_step = jax.pmap(eval_step, "batch")
|
479 |
-
|
480 |
-
# Replicate the train state on each device
|
481 |
-
state = state.replicate()
|
482 |
-
|
483 |
-
logger.info("***** Running training *****")
|
484 |
-
logger.info(f" Num examples = {len(train_dataset)}")
|
485 |
-
logger.info(f" Num Epochs = {num_epochs}")
|
486 |
-
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
487 |
-
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
488 |
-
logger.info(f" Total optimization steps = {total_train_steps}")
|
489 |
-
|
490 |
-
train_time = 0
|
491 |
-
# Create sampling rng
|
492 |
-
rng, input_rng = jax.random.split(rng)
|
493 |
-
|
494 |
-
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
495 |
-
for epoch in epochs:
|
496 |
-
# ======================== Training ================================
|
497 |
-
train_start = time.time()
|
498 |
-
|
499 |
-
# Create sampling rng
|
500 |
-
rng, input_rng = jax.random.split(rng)
|
501 |
-
train_metrics = []
|
502 |
-
|
503 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
504 |
-
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
|
505 |
-
# train
|
506 |
-
for batch in train_loader:
|
507 |
-
batch = shard(batch)
|
508 |
-
state, train_metric = p_train_step(state, batch)
|
509 |
-
train_metrics.append(train_metric)
|
510 |
-
|
511 |
-
train_step_progress_bar.update(1)
|
512 |
-
|
513 |
-
train_time += time.time() - train_start
|
514 |
-
|
515 |
-
train_metric = unreplicate(train_metric)
|
516 |
-
|
517 |
-
train_step_progress_bar.close()
|
518 |
-
epochs.write(
|
519 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
520 |
-
)
|
521 |
-
|
522 |
-
# ======================== Evaluating ==============================
|
523 |
-
eval_metrics = []
|
524 |
-
eval_steps = len(eval_dataset) // eval_batch_size
|
525 |
-
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
|
526 |
-
for batch in eval_loader:
|
527 |
-
# Model forward
|
528 |
-
batch = shard(batch)
|
529 |
-
metrics = p_eval_step(state.params, batch)
|
530 |
-
eval_metrics.append(metrics)
|
531 |
-
|
532 |
-
eval_step_progress_bar.update(1)
|
533 |
-
|
534 |
-
# normalize eval metrics
|
535 |
-
eval_metrics = get_metrics(eval_metrics)
|
536 |
-
|
537 |
-
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
538 |
-
|
539 |
-
# Print metrics and update progress bar
|
540 |
-
eval_step_progress_bar.close()
|
541 |
-
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
542 |
-
epochs.write(desc)
|
543 |
-
epochs.desc = desc
|
544 |
-
|
545 |
-
# Save metrics
|
546 |
-
if has_tensorboard and jax.process_index() == 0:
|
547 |
-
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
548 |
-
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
549 |
-
|
550 |
-
# save checkpoint after each epoch and push checkpoint to the hub
|
551 |
-
if jax.process_index() == 0:
|
552 |
-
params = jax.device_get(unreplicate(state.params))
|
553 |
-
model.save_pretrained(
|
554 |
-
training_args.output_dir,
|
555 |
-
params=params,
|
556 |
-
push_to_hub=training_args.push_to_hub,
|
557 |
-
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
558 |
-
)
|
559 |
-
|
560 |
-
|
561 |
-
if __name__ == "__main__":
|
562 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|