g8a9 committed on
Commit d4dbff1
1 Parent(s): 4d52970

Add test model (CC+WIT+COCO) 7 epochs

README.md DELETED
@@ -1,3 +0,0 @@
- # clip-italian
-
- # TODO

config.json ADDED
@@ -0,0 +1,156 @@
+ {
+ "architectures": [
+ "HybridCLIP"
+ ],
+ "initializer_factor": 1.0,
+ "model_type": "hybrid-clip",
+ "projection_dim": 512,
+ "seed": 42,
+ "text_config": {
+ "_name_or_path": "",
+ "add_cross_attention": false,
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "bad_words_ids": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_eps": 1e-12,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 512,
+ "min_length": 0,
+ "model_type": "bert",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 12,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_hidden_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.9.0.dev0",
+ "type_vocab_size": 2,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 32102
+ },
+ "transformers_version": null,
+ "vision_config": {
+ "_name_or_path": "",
+ "add_cross_attention": false,
+ "architectures": null,
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.0,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_eps": 1e-05,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "clip_vision_model",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 12,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_hidden_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": null,
+ "patch_size": 32,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.9.0.dev0",
+ "use_bfloat16": false
+ }
+ }
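
For reference, the configuration added above can be rebuilt into a `HybridCLIPConfig` object the same way the evaluation notebook deleted in this commit does it. The sketch below is an illustration, not part of the commit: it assumes the `configuration_hybrid_clip` module from the `hybrid_clip` example code is importable and that `config.json` is the file added here.

```python
# Minimal sketch (not part of this commit): turn the added config.json into a
# HybridCLIPConfig, mirroring the loading cell of the deleted evaluation notebook.
# Assumes configuration_hybrid_clip (from the hybrid_clip example code) is on the path.
import json

from configuration_hybrid_clip import HybridCLIPConfig

with open("config.json", "r") as f:
    config_dict = json.load(f)

# The notebook overrides the vision model_type ("clip_vision_model" -> "clip")
# before constructing the config; the same workaround is reproduced here.
config_dict["vision_config"]["model_type"] = "clip"

config = HybridCLIPConfig(
    text_config_dict=config_dict["text_config"],
    vision_config_dict=config_dict["vision_config"],
)

print(config.text_config.vocab_size)    # 32102, per the text_config above
print(config.vision_config.image_size)  # 224, per the vision_config above
```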
evaluation/imagenet_validation_script.ipynb DELETED
@@ -1,1003 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Imagenet Evaluation Script\n",
8
- "modified from [the evluation script by OpenAI](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Prompt_Engineering_for_ImageNet.ipynb)."
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 3,
14
- "metadata": {},
15
- "outputs": [],
16
- "source": [
17
- "import json\n",
18
- "from modeling_hybrid_clip import FlaxHybridCLIP\n",
19
- "from configuration_hybrid_clip import HybridCLIPConfig\n",
20
- "\n",
21
- "import jax\n",
22
- "from jax import numpy as jnp\n",
23
- "\n",
24
- "import os \n",
25
- "os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n",
26
- "\n",
27
- "import transformers\n",
28
- "from transformers import AutoTokenizer\n",
29
- "\n",
30
- "import numpy as np\n",
31
- "import torch\n",
32
- "import torchvision\n",
33
- "from torchvision import transforms\n",
34
- "from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ColorJitter, RandomHorizontalFlip, RandomRotation, ToTensor, Lambda\n",
35
- "from torchvision.transforms.functional import InterpolationMode\n",
36
- "from tqdm.notebook import tqdm"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": 4,
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "# Global Variables\n",
46
- "\n",
47
- "LANGUAGE = 'en'\n",
48
- "IMAGENET_ROOT = \"/home/raphaelp/imagenet_root/\"\n",
49
- "\n",
50
- "CONFIG_FILE = '/home/raphaelp/clip-base/config.json'\n",
51
- "MODEL_FILE = '/home/raphaelp/clip-base/checkpoint/flax_model.msgpack'\n",
52
- "\n",
53
- "\n",
54
- "if LANGUAGE == 'en':\n",
55
- " TOKENIZER_NAME = \"roberta-base\" \n",
56
- "elif LANGUAGE == 'it':\n",
57
- " TOKENIZER_NAME = \"dbmdz/bert-base-italian-cased\""
58
- ]
59
- },
60
- {
61
- "cell_type": "markdown",
62
- "metadata": {
63
- "id": "eFxgLV5HAEEw"
64
- },
65
- "source": [
66
- "# Loading the model"
67
- ]
68
- },
69
- {
70
- "cell_type": "code",
71
- "execution_count": 10,
72
- "metadata": {},
73
- "outputs": [],
74
- "source": [
75
- "tokenizer = AutoTokenizer.from_pretrained(\n",
76
- " TOKENIZER_NAME, cache_dir=None, use_fast=True\n",
77
- ")"
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 5,
83
- "metadata": {
84
- "tags": []
85
- },
86
- "outputs": [
87
- {
88
- "name": "stderr",
89
- "output_type": "stream",
90
- "text": [
91
- "INFO:absl:Starting the local TPU driver.\n",
92
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
93
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n",
94
- "2021-07-07 14:04:30.725543: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
95
- ]
96
- }
97
- ],
98
- "source": [
99
- "with open(CONFIG_FILE, 'r') as f:\n",
100
- " config_dict = json.load(f)\n",
101
- "config_dict['vision_config']['model_type'] = 'clip'\n",
102
- "config = HybridCLIPConfig(text_config_dict=config_dict['text_config'], vision_config_dict=config_dict['vision_config'])\n",
103
- "model = FlaxHybridCLIP.from_pretrained(MODEL_FILE, config=config)"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 6,
109
- "metadata": {
110
- "colab": {
111
- "base_uri": "https://localhost:8080/"
112
- },
113
- "id": "IBRVTY9lbGm8",
114
- "outputId": "58641dc2-919d-40ae-b71a-7b7b47830f77"
115
- },
116
- "outputs": [
117
- {
118
- "name": "stdout",
119
- "output_type": "stream",
120
- "text": [
121
- "Input resolution: 224\n",
122
- "Context length: 20\n",
123
- "Vocab size: 50265\n"
124
- ]
125
- }
126
- ],
127
- "source": [
128
- "print(\"Input resolution:\", config.vision_config.image_size)\n",
129
- "print(\"Context length:\", config.text_config.max_length)\n",
130
- "print(\"Vocab size:\", config.text_config.vocab_size)"
131
- ]
132
- },
133
- {
134
- "cell_type": "markdown",
135
- "metadata": {
136
- "id": "LhO3OtOmF8M4"
137
- },
138
- "source": [
139
- "# Preparing ImageNet labels and prompts\n",
140
- "\n",
141
- "The following cell contains the 1,000 labels for the ImageNet dataset, followed by the text templates we'll use as \"prompt engineering\"."
142
- ]
143
- },
144
- {
145
- "cell_type": "code",
146
- "execution_count": 7,
147
- "metadata": {
148
- "id": "R2HbOZrqa0jF"
149
- },
150
- "outputs": [],
151
- "source": [
152
- "if LANGUAGE == 'en':\n",
153
- " imagenet_classes = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", 
\"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", 
\"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", 
\"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", 
\"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n",
154
- "elif LANGUAGE == 'it':\n",
155
- " raise NotImplementedError"
156
- ]
157
- },
158
- {
159
- "cell_type": "code",
160
- "execution_count": 8,
161
- "metadata": {
162
- "colab": {
163
- "base_uri": "https://localhost:8080/"
164
- },
165
- "id": "toGtcd-Ji_MD",
166
- "outputId": "46bcc85f-3968-4836-f3c6-e48848e944c4"
167
- },
168
- "outputs": [
169
- {
170
- "name": "stdout",
171
- "output_type": "stream",
172
- "text": [
173
- "1000 classes, 80 templates\n"
174
- ]
175
- }
176
- ],
177
- "source": [
178
- "if LANGUAGE == 'en':\n",
179
- " imagenet_templates = [\n",
180
- " 'a bad photo of a {}.',\n",
181
- " 'a photo of many {}.',\n",
182
- " 'a sculpture of a {}.',\n",
183
- " 'a photo of the hard to see {}.',\n",
184
- " 'a low resolution photo of the {}.',\n",
185
- " 'a rendering of a {}.',\n",
186
- " 'graffiti of a {}.',\n",
187
- " 'a bad photo of the {}.',\n",
188
- " 'a cropped photo of the {}.',\n",
189
- " 'a tattoo of a {}.',\n",
190
- " 'the embroidered {}.',\n",
191
- " 'a photo of a hard to see {}.',\n",
192
- " 'a bright photo of a {}.',\n",
193
- " 'a photo of a clean {}.',\n",
194
- " 'a photo of a dirty {}.',\n",
195
- " 'a dark photo of the {}.',\n",
196
- " 'a drawing of a {}.',\n",
197
- " 'a photo of my {}.',\n",
198
- " 'the plastic {}.',\n",
199
- " 'a photo of the cool {}.',\n",
200
- " 'a close-up photo of a {}.',\n",
201
- " 'a black and white photo of the {}.',\n",
202
- " 'a painting of the {}.',\n",
203
- " 'a painting of a {}.',\n",
204
- " 'a pixelated photo of the {}.',\n",
205
- " 'a sculpture of the {}.',\n",
206
- " 'a bright photo of the {}.',\n",
207
- " 'a cropped photo of a {}.',\n",
208
- " 'a plastic {}.',\n",
209
- " 'a photo of the dirty {}.',\n",
210
- " 'a jpeg corrupted photo of a {}.',\n",
211
- " 'a blurry photo of the {}.',\n",
212
- " 'a photo of the {}.',\n",
213
- " 'a good photo of the {}.',\n",
214
- " 'a rendering of the {}.',\n",
215
- " 'a {} in a video game.',\n",
216
- " 'a photo of one {}.',\n",
217
- " 'a doodle of a {}.',\n",
218
- " 'a close-up photo of the {}.',\n",
219
- " 'a photo of a {}.',\n",
220
- " 'the origami {}.',\n",
221
- " 'the {} in a video game.',\n",
222
- " 'a sketch of a {}.',\n",
223
- " 'a doodle of the {}.',\n",
224
- " 'a origami {}.',\n",
225
- " 'a low resolution photo of a {}.',\n",
226
- " 'the toy {}.',\n",
227
- " 'a rendition of the {}.',\n",
228
- " 'a photo of the clean {}.',\n",
229
- " 'a photo of a large {}.',\n",
230
- " 'a rendition of a {}.',\n",
231
- " 'a photo of a nice {}.',\n",
232
- " 'a photo of a weird {}.',\n",
233
- " 'a blurry photo of a {}.',\n",
234
- " 'a cartoon {}.',\n",
235
- " 'art of a {}.',\n",
236
- " 'a sketch of the {}.',\n",
237
- " 'a embroidered {}.',\n",
238
- " 'a pixelated photo of a {}.',\n",
239
- " 'itap of the {}.',\n",
240
- " 'a jpeg corrupted photo of the {}.',\n",
241
- " 'a good photo of a {}.',\n",
242
- " 'a plushie {}.',\n",
243
- " 'a photo of the nice {}.',\n",
244
- " 'a photo of the small {}.',\n",
245
- " 'a photo of the weird {}.',\n",
246
- " 'the cartoon {}.',\n",
247
- " 'art of the {}.',\n",
248
- " 'a drawing of the {}.',\n",
249
- " 'a photo of the large {}.',\n",
250
- " 'a black and white photo of a {}.',\n",
251
- " 'the plushie {}.',\n",
252
- " 'a dark photo of a {}.',\n",
253
- " 'itap of a {}.',\n",
254
- " 'graffiti of the {}.',\n",
255
- " 'a toy {}.',\n",
256
- " 'itap of my {}.',\n",
257
- " 'a photo of a cool {}.',\n",
258
- " 'a photo of a small {}.',\n",
259
- " 'a tattoo of the {}.',\n",
260
- " ]\n",
261
- "elif LANGUAGE == 'it':\n",
262
- " raise NotImplementedError\n",
263
- "\n",
264
- "print(f\"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates\")"
265
- ]
266
- },
267
- {
268
- "cell_type": "markdown",
269
- "metadata": {},
270
- "source": [
271
- "# Set up Validation Set"
272
- ]
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": 9,
277
- "metadata": {
278
- "colab": {
279
- "base_uri": "https://localhost:8080/"
280
- },
281
- "id": "cboKZocQlSYX",
282
- "outputId": "58e644d4-6e23-43b5-964e-1e9e8540d22e"
283
- },
284
- "outputs": [],
285
- "source": [
286
- "val_preprocess = transforms.Compose([\n",
287
- " Resize([config.vision_config.image_size], interpolation=InterpolationMode.BICUBIC),\n",
288
- " CenterCrop(config.vision_config.image_size),\n",
289
- " ToTensor(),\n",
290
- " Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
291
- "])"
292
- ]
293
- },
294
- {
295
- "cell_type": "code",
296
- "execution_count": 11,
297
- "metadata": {
298
- "colab": {
299
- "base_uri": "https://localhost:8080/"
300
- },
301
- "id": "moHR4UlHKsDc",
302
- "outputId": "178f6d0d-9a34-4cbc-c9c1-e7ce09927980"
303
- },
304
- "outputs": [],
305
- "source": [
306
- "images = torchvision.datasets.ImageNet(IMAGENET_ROOT, split='val', transform=val_preprocess)\n",
307
- "loader = torch.utils.data.DataLoader(\n",
308
- " images,\n",
309
- " batch_size=64,\n",
310
- " shuffle=False,\n",
311
- " num_workers=32,\n",
312
- " persistent_workers=True,\n",
313
- " drop_last=False\n",
314
- ")"
315
- ]
316
- },
317
- {
318
- "cell_type": "markdown",
319
- "metadata": {
320
- "id": "fz6D-F-Wbrtp"
321
- },
322
- "source": [
323
- "# Creating zero-shot classifier weights"
324
- ]
325
- },
326
- {
327
- "cell_type": "code",
328
- "execution_count": 12,
329
- "metadata": {
330
- "colab": {
331
- "base_uri": "https://localhost:8080/",
332
- "height": 66,
333
- "referenced_widgets": [
334
- "4e3a3f83649f45f8bef3434980634664",
335
- "f066bdb766664c788ba1e9de8d311e22",
336
- "4e7a7427d28a4ae684e0be4548eb9944",
337
- "cc9dc019c1334a46b2558ffa6c0dd6e6",
338
- "285c877d4f644f3a8a58c4eb5948101c",
339
- "075d6545e02e419ca565589eb5ffc318",
340
- "53f9106c80e84d5b8c3ec96162d1db98",
341
- "19c57d99e7c44cbda508ce558fde435d"
342
- ]
343
- },
344
- "id": "sRqDoz1Gbsii",
345
- "outputId": "5ab6c001-8a5e-42c9-ab46-4477a693229c"
346
- },
347
- "outputs": [
348
- {
349
- "data": {
350
- "application/vnd.jupyter.widget-view+json": {
351
- "model_id": "64fc81ee79fb45808e9fd8040151940e",
352
- "version_major": 2,
353
- "version_minor": 0
354
- },
355
- "text/plain": [
356
- "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))"
357
- ]
358
- },
359
- "metadata": {},
360
- "output_type": "display_data"
361
- },
362
- {
363
- "name": "stdout",
364
- "output_type": "stream",
365
- "text": [
366
- "\n"
367
- ]
368
- }
369
- ],
370
- "source": [
371
- "def zeroshot_classifier(classnames, templates):\n",
372
- " zeroshot_weights = []\n",
373
- " for classname in tqdm(classnames):\n",
374
- " texts = [template.format(classname) for template in templates] #format with class\n",
375
- " inputs = tokenizer(texts, max_length=72, padding=\"max_length\", return_tensors=\"np\")\n",
376
- " class_embeddings = model.get_text_features(inputs['input_ids'], inputs['attention_mask'])#embed with text encoder\n",
377
- " class_embeddings /= jnp.linalg.norm(class_embeddings, axis=-1, keepdims=True)\n",
378
- " class_embedding = jnp.mean(class_embeddings, axis=0)\n",
379
- " class_embedding /= jnp.linalg.norm(class_embeddings)\n",
380
- " zeroshot_weights.append(class_embedding)\n",
381
- " zeroshot_weights = jnp.stack(zeroshot_weights, axis=1)\n",
382
- " return zeroshot_weights\n",
383
- "\n",
384
- "\n",
385
- "zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)"
386
- ]
387
- },
388
- {
389
- "cell_type": "markdown",
390
- "metadata": {
391
- "id": "1fZo7hG8iJP5"
392
- },
393
- "source": [
394
- "# Zero-shot prediction"
395
- ]
396
- },
397
- {
398
- "cell_type": "code",
399
- "execution_count": 13,
400
- "metadata": {
401
- "id": "j4kPSZoShQxN"
402
- },
403
- "outputs": [],
404
- "source": [
405
- "def accuracy(output, target, topk=(1,)):\n",
406
- " # pred = output.topk(max(topk), 1, True, True)[1].t()\n",
407
- " pred = np.argsort(output, axis=1)[:,-max(topk):]\n",
408
- " correct = jnp.equal(pred, jnp.expand_dims(target, axis=-1))\n",
409
- " return [float(jnp.sum(jnp.reshape(correct[:k], -1), axis=0, keepdims=True)) for k in topk]"
410
- ]
411
- },
412
- {
413
- "cell_type": "code",
414
- "execution_count": 16,
415
- "metadata": {
416
- "colab": {
417
- "base_uri": "https://localhost:8080/",
418
- "height": 100,
419
- "referenced_widgets": [
420
- "fbb2b937b22049f5987f39f48c652a86",
421
- "0a1b6b76984349ccb36ca2fc4a4a0208",
422
- "c136afb47aa14ac2832093ee415c6f3e",
423
- "467a151e73744eccb199fe72aa352e5b",
424
- "f6d637c3fc3c46928d023441227130e5",
425
- "029e6eadacb8480193aab52ff073be8f",
426
- "30178355f76742898d37966b3875ef0a",
427
- "2e62544c03d64d6d92b94fcfaca2fc90"
428
- ]
429
- },
430
- "id": "wKJ7YsdlkDXo",
431
- "outputId": "90e084fd-86bc-4a52-a06e-61bff7aa86e0"
432
- },
433
- "outputs": [
434
- {
435
- "data": {
436
- "application/vnd.jupyter.widget-view+json": {
437
- "model_id": "f5db0a6445f04d1bb5ae08dee278a215",
438
- "version_major": 2,
439
- "version_minor": 0
440
- },
441
- "text/plain": [
442
- "HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))"
443
- ]
444
- },
445
- "metadata": {},
446
- "output_type": "display_data"
447
- },
448
- {
449
- "name": "stdout",
450
- "output_type": "stream",
451
- "text": [
452
- "\n",
453
- "{'top1': 0.958, 'top5': 4.688, 'top10': 9.202, 'top100': 58.29}\n"
454
- ]
455
- }
456
- ],
457
- "source": [
458
- "top_ns = [1, 5, 10, 100]\n",
459
- "acc_counters = [0. for _ in top_ns]\n",
460
- "n = 0.\n",
461
- "\n",
462
- "for i, (images, target) in enumerate(tqdm(loader)):\n",
463
- " images = images.permute(0, 2, 3, 1).numpy()\n",
464
- " target = target.numpy()\n",
465
- " # predict\n",
466
- " image_features = model.get_image_features(images,)\n",
467
- " image_features /= jnp.linalg.norm(image_features, axis=-1, keepdims=True)\n",
468
- " logits = 100. * image_features @ zeroshot_weights\n",
469
- "\n",
470
- " # measure accuracy\n",
471
- " accs = accuracy(logits, target, topk=top_ns)\n",
472
- " for j in range(len(top_ns)):\n",
473
- " acc_counters[j] += accs[j]\n",
474
- " n += images.shape[0]\n",
475
- "\n",
476
- "tops = {f'top{top_ns[i]}': acc_counters[i] / n * 100 for i in range(len(top_ns))}\n",
477
- "\n",
478
- "print(tops)"
479
- ]
480
- }
481
- ],
482
- "metadata": {
483
- "accelerator": "GPU",
484
- "colab": {
485
- "collapsed_sections": [],
486
- "name": "Prompt Engineering for ImageNet.ipynb",
487
- "provenance": []
488
- },
489
- "kernelspec": {
490
- "display_name": "Python 3 (ipykernel)",
491
- "language": "python",
492
- "name": "python3"
493
- },
494
- "language_info": {
495
- "codemirror_mode": {
496
- "name": "ipython",
497
- "version": 3
498
- },
499
- "file_extension": ".py",
500
- "mimetype": "text/x-python",
501
- "name": "python",
502
- "nbconvert_exporter": "python",
503
- "pygments_lexer": "ipython3",
504
- "version": "3.8.10"
505
- },
506
- "widgets": {
507
- "application/vnd.jupyter.widget-state+json": {
508
- "029e6eadacb8480193aab52ff073be8f": {
509
- "model_module": "@jupyter-widgets/base",
510
- "model_name": "LayoutModel",
511
- "state": {
512
- "_model_module": "@jupyter-widgets/base",
513
- "_model_module_version": "1.2.0",
514
- "_model_name": "LayoutModel",
515
- "_view_count": null,
516
- "_view_module": "@jupyter-widgets/base",
517
- "_view_module_version": "1.2.0",
518
- "_view_name": "LayoutView",
519
- "align_content": null,
520
- "align_items": null,
521
- "align_self": null,
522
- "border": null,
523
- "bottom": null,
524
- "display": null,
525
- "flex": null,
526
- "flex_flow": null,
527
- "grid_area": null,
528
- "grid_auto_columns": null,
529
- "grid_auto_flow": null,
530
- "grid_auto_rows": null,
531
- "grid_column": null,
532
- "grid_gap": null,
533
- "grid_row": null,
534
- "grid_template_areas": null,
535
- "grid_template_columns": null,
536
- "grid_template_rows": null,
537
- "height": null,
538
- "justify_content": null,
539
- "justify_items": null,
540
- "left": null,
541
- "margin": null,
542
- "max_height": null,
543
- "max_width": null,
544
- "min_height": null,
545
- "min_width": null,
546
- "object_fit": null,
547
- "object_position": null,
548
- "order": null,
549
- "overflow": null,
550
- "overflow_x": null,
551
- "overflow_y": null,
552
- "padding": null,
553
- "right": null,
554
- "top": null,
555
- "visibility": null,
556
- "width": null
557
- }
558
- },
559
- "075d6545e02e419ca565589eb5ffc318": {
560
- "model_module": "@jupyter-widgets/base",
561
- "model_name": "LayoutModel",
562
- "state": {
563
- "_model_module": "@jupyter-widgets/base",
564
- "_model_module_version": "1.2.0",
565
- "_model_name": "LayoutModel",
566
- "_view_count": null,
567
- "_view_module": "@jupyter-widgets/base",
568
- "_view_module_version": "1.2.0",
569
- "_view_name": "LayoutView",
570
- "align_content": null,
571
- "align_items": null,
572
- "align_self": null,
573
- "border": null,
574
- "bottom": null,
575
- "display": null,
576
- "flex": null,
577
- "flex_flow": null,
578
- "grid_area": null,
579
- "grid_auto_columns": null,
580
- "grid_auto_flow": null,
581
- "grid_auto_rows": null,
582
- "grid_column": null,
583
- "grid_gap": null,
584
- "grid_row": null,
585
- "grid_template_areas": null,
586
- "grid_template_columns": null,
587
- "grid_template_rows": null,
588
- "height": null,
589
- "justify_content": null,
590
- "justify_items": null,
591
- "left": null,
592
- "margin": null,
593
- "max_height": null,
594
- "max_width": null,
595
- "min_height": null,
596
- "min_width": null,
597
- "object_fit": null,
598
- "object_position": null,
599
- "order": null,
600
- "overflow": null,
601
- "overflow_x": null,
602
- "overflow_y": null,
603
- "padding": null,
604
- "right": null,
605
- "top": null,
606
- "visibility": null,
607
- "width": null
608
- }
609
- },
610
- "0a1b6b76984349ccb36ca2fc4a4a0208": {
611
- "model_module": "@jupyter-widgets/base",
612
- "model_name": "LayoutModel",
613
- "state": {
614
- "_model_module": "@jupyter-widgets/base",
615
- "_model_module_version": "1.2.0",
616
- "_model_name": "LayoutModel",
617
- "_view_count": null,
618
- "_view_module": "@jupyter-widgets/base",
619
- "_view_module_version": "1.2.0",
620
- "_view_name": "LayoutView",
621
- "align_content": null,
622
- "align_items": null,
623
- "align_self": null,
624
- "border": null,
625
- "bottom": null,
626
- "display": null,
627
- "flex": null,
628
- "flex_flow": null,
629
- "grid_area": null,
630
- "grid_auto_columns": null,
631
- "grid_auto_flow": null,
632
- "grid_auto_rows": null,
633
- "grid_column": null,
634
- "grid_gap": null,
635
- "grid_row": null,
636
- "grid_template_areas": null,
637
- "grid_template_columns": null,
638
- "grid_template_rows": null,
639
- "height": null,
640
- "justify_content": null,
641
- "justify_items": null,
642
- "left": null,
643
- "margin": null,
644
- "max_height": null,
645
- "max_width": null,
646
- "min_height": null,
647
- "min_width": null,
648
- "object_fit": null,
649
- "object_position": null,
650
- "order": null,
651
- "overflow": null,
652
- "overflow_x": null,
653
- "overflow_y": null,
654
- "padding": null,
655
- "right": null,
656
- "top": null,
657
- "visibility": null,
658
- "width": null
659
- }
660
- },
661
- "19c57d99e7c44cbda508ce558fde435d": {
662
- "model_module": "@jupyter-widgets/base",
663
- "model_name": "LayoutModel",
664
- "state": {
665
- "_model_module": "@jupyter-widgets/base",
666
- "_model_module_version": "1.2.0",
667
- "_model_name": "LayoutModel",
668
- "_view_count": null,
669
- "_view_module": "@jupyter-widgets/base",
670
- "_view_module_version": "1.2.0",
671
- "_view_name": "LayoutView",
672
- "align_content": null,
673
- "align_items": null,
674
- "align_self": null,
675
- "border": null,
676
- "bottom": null,
677
- "display": null,
678
- "flex": null,
679
- "flex_flow": null,
680
- "grid_area": null,
681
- "grid_auto_columns": null,
682
- "grid_auto_flow": null,
683
- "grid_auto_rows": null,
684
- "grid_column": null,
685
- "grid_gap": null,
686
- "grid_row": null,
687
- "grid_template_areas": null,
688
- "grid_template_columns": null,
689
- "grid_template_rows": null,
690
- "height": null,
691
- "justify_content": null,
692
- "justify_items": null,
693
- "left": null,
694
- "margin": null,
695
- "max_height": null,
696
- "max_width": null,
697
- "min_height": null,
698
- "min_width": null,
699
- "object_fit": null,
700
- "object_position": null,
701
- "order": null,
702
- "overflow": null,
703
- "overflow_x": null,
704
- "overflow_y": null,
705
- "padding": null,
706
- "right": null,
707
- "top": null,
708
- "visibility": null,
709
- "width": null
710
- }
711
- },
712
- "285c877d4f644f3a8a58c4eb5948101c": {
713
- "model_module": "@jupyter-widgets/controls",
714
- "model_name": "ProgressStyleModel",
715
- "state": {
716
- "_model_module": "@jupyter-widgets/controls",
717
- "_model_module_version": "1.5.0",
718
- "_model_name": "ProgressStyleModel",
719
- "_view_count": null,
720
- "_view_module": "@jupyter-widgets/base",
721
- "_view_module_version": "1.2.0",
722
- "_view_name": "StyleView",
723
- "bar_color": null,
724
- "description_width": "initial"
725
- }
726
- },
727
- "2e62544c03d64d6d92b94fcfaca2fc90": {
728
- "model_module": "@jupyter-widgets/base",
729
- "model_name": "LayoutModel",
730
- "state": {
731
- "_model_module": "@jupyter-widgets/base",
732
- "_model_module_version": "1.2.0",
733
- "_model_name": "LayoutModel",
734
- "_view_count": null,
735
- "_view_module": "@jupyter-widgets/base",
736
- "_view_module_version": "1.2.0",
737
- "_view_name": "LayoutView",
738
- "align_content": null,
739
- "align_items": null,
740
- "align_self": null,
741
- "border": null,
742
- "bottom": null,
743
- "display": null,
744
- "flex": null,
745
- "flex_flow": null,
746
- "grid_area": null,
747
- "grid_auto_columns": null,
748
- "grid_auto_flow": null,
749
- "grid_auto_rows": null,
750
- "grid_column": null,
751
- "grid_gap": null,
752
- "grid_row": null,
753
- "grid_template_areas": null,
754
- "grid_template_columns": null,
755
- "grid_template_rows": null,
756
- "height": null,
757
- "justify_content": null,
758
- "justify_items": null,
759
- "left": null,
760
- "margin": null,
761
- "max_height": null,
762
- "max_width": null,
763
- "min_height": null,
764
- "min_width": null,
765
- "object_fit": null,
766
- "object_position": null,
767
- "order": null,
768
- "overflow": null,
769
- "overflow_x": null,
770
- "overflow_y": null,
771
- "padding": null,
772
- "right": null,
773
- "top": null,
774
- "visibility": null,
775
- "width": null
776
- }
777
- },
778
- "30178355f76742898d37966b3875ef0a": {
779
- "model_module": "@jupyter-widgets/controls",
780
- "model_name": "DescriptionStyleModel",
781
- "state": {
782
- "_model_module": "@jupyter-widgets/controls",
783
- "_model_module_version": "1.5.0",
784
- "_model_name": "DescriptionStyleModel",
785
- "_view_count": null,
786
- "_view_module": "@jupyter-widgets/base",
787
- "_view_module_version": "1.2.0",
788
- "_view_name": "StyleView",
789
- "description_width": ""
790
- }
791
- },
792
- "467a151e73744eccb199fe72aa352e5b": {
793
- "model_module": "@jupyter-widgets/controls",
794
- "model_name": "HTMLModel",
795
- "state": {
796
- "_dom_classes": [],
797
- "_model_module": "@jupyter-widgets/controls",
798
- "_model_module_version": "1.5.0",
799
- "_model_name": "HTMLModel",
800
- "_view_count": null,
801
- "_view_module": "@jupyter-widgets/controls",
802
- "_view_module_version": "1.5.0",
803
- "_view_name": "HTMLView",
804
- "description": "",
805
- "description_tooltip": null,
806
- "layout": "IPY_MODEL_2e62544c03d64d6d92b94fcfaca2fc90",
807
- "placeholder": "​",
808
- "style": "IPY_MODEL_30178355f76742898d37966b3875ef0a",
809
- "value": " 313/313 [01:26<00:00, 3.62it/s]"
810
- }
811
- },
812
- "4e3a3f83649f45f8bef3434980634664": {
813
- "model_module": "@jupyter-widgets/controls",
814
- "model_name": "HBoxModel",
815
- "state": {
816
- "_dom_classes": [],
817
- "_model_module": "@jupyter-widgets/controls",
818
- "_model_module_version": "1.5.0",
819
- "_model_name": "HBoxModel",
820
- "_view_count": null,
821
- "_view_module": "@jupyter-widgets/controls",
822
- "_view_module_version": "1.5.0",
823
- "_view_name": "HBoxView",
824
- "box_style": "",
825
- "children": [
826
- "IPY_MODEL_4e7a7427d28a4ae684e0be4548eb9944",
827
- "IPY_MODEL_cc9dc019c1334a46b2558ffa6c0dd6e6"
828
- ],
829
- "layout": "IPY_MODEL_f066bdb766664c788ba1e9de8d311e22"
830
- }
831
- },
832
- "4e7a7427d28a4ae684e0be4548eb9944": {
833
- "model_module": "@jupyter-widgets/controls",
834
- "model_name": "FloatProgressModel",
835
- "state": {
836
- "_dom_classes": [],
837
- "_model_module": "@jupyter-widgets/controls",
838
- "_model_module_version": "1.5.0",
839
- "_model_name": "FloatProgressModel",
840
- "_view_count": null,
841
- "_view_module": "@jupyter-widgets/controls",
842
- "_view_module_version": "1.5.0",
843
- "_view_name": "ProgressView",
844
- "bar_style": "success",
845
- "description": "100%",
846
- "description_tooltip": null,
847
- "layout": "IPY_MODEL_075d6545e02e419ca565589eb5ffc318",
848
- "max": 1000,
849
- "min": 0,
850
- "orientation": "horizontal",
851
- "style": "IPY_MODEL_285c877d4f644f3a8a58c4eb5948101c",
852
- "value": 1000
853
- }
854
- },
855
- "53f9106c80e84d5b8c3ec96162d1db98": {
856
- "model_module": "@jupyter-widgets/controls",
857
- "model_name": "DescriptionStyleModel",
858
- "state": {
859
- "_model_module": "@jupyter-widgets/controls",
860
- "_model_module_version": "1.5.0",
861
- "_model_name": "DescriptionStyleModel",
862
- "_view_count": null,
863
- "_view_module": "@jupyter-widgets/base",
864
- "_view_module_version": "1.2.0",
865
- "_view_name": "StyleView",
866
- "description_width": ""
867
- }
868
- },
869
- "c136afb47aa14ac2832093ee415c6f3e": {
870
- "model_module": "@jupyter-widgets/controls",
871
- "model_name": "FloatProgressModel",
872
- "state": {
873
- "_dom_classes": [],
874
- "_model_module": "@jupyter-widgets/controls",
875
- "_model_module_version": "1.5.0",
876
- "_model_name": "FloatProgressModel",
877
- "_view_count": null,
878
- "_view_module": "@jupyter-widgets/controls",
879
- "_view_module_version": "1.5.0",
880
- "_view_name": "ProgressView",
881
- "bar_style": "success",
882
- "description": "100%",
883
- "description_tooltip": null,
884
- "layout": "IPY_MODEL_029e6eadacb8480193aab52ff073be8f",
885
- "max": 313,
886
- "min": 0,
887
- "orientation": "horizontal",
888
- "style": "IPY_MODEL_f6d637c3fc3c46928d023441227130e5",
889
- "value": 313
890
- }
891
- },
892
- "cc9dc019c1334a46b2558ffa6c0dd6e6": {
893
- "model_module": "@jupyter-widgets/controls",
894
- "model_name": "HTMLModel",
895
- "state": {
896
- "_dom_classes": [],
897
- "_model_module": "@jupyter-widgets/controls",
898
- "_model_module_version": "1.5.0",
899
- "_model_name": "HTMLModel",
900
- "_view_count": null,
901
- "_view_module": "@jupyter-widgets/controls",
902
- "_view_module_version": "1.5.0",
903
- "_view_name": "HTMLView",
904
- "description": "",
905
- "description_tooltip": null,
906
- "layout": "IPY_MODEL_19c57d99e7c44cbda508ce558fde435d",
907
- "placeholder": "​",
908
- "style": "IPY_MODEL_53f9106c80e84d5b8c3ec96162d1db98",
909
- "value": " 1000/1000 [01:09<00:00, 14.35it/s]"
910
- }
911
- },
912
- "f066bdb766664c788ba1e9de8d311e22": {
913
- "model_module": "@jupyter-widgets/base",
914
- "model_name": "LayoutModel",
915
- "state": {
916
- "_model_module": "@jupyter-widgets/base",
917
- "_model_module_version": "1.2.0",
918
- "_model_name": "LayoutModel",
919
- "_view_count": null,
920
- "_view_module": "@jupyter-widgets/base",
921
- "_view_module_version": "1.2.0",
922
- "_view_name": "LayoutView",
923
- "align_content": null,
924
- "align_items": null,
925
- "align_self": null,
926
- "border": null,
927
- "bottom": null,
928
- "display": null,
929
- "flex": null,
930
- "flex_flow": null,
931
- "grid_area": null,
932
- "grid_auto_columns": null,
933
- "grid_auto_flow": null,
934
- "grid_auto_rows": null,
935
- "grid_column": null,
936
- "grid_gap": null,
937
- "grid_row": null,
938
- "grid_template_areas": null,
939
- "grid_template_columns": null,
940
- "grid_template_rows": null,
941
- "height": null,
942
- "justify_content": null,
943
- "justify_items": null,
944
- "left": null,
945
- "margin": null,
946
- "max_height": null,
947
- "max_width": null,
948
- "min_height": null,
949
- "min_width": null,
950
- "object_fit": null,
951
- "object_position": null,
952
- "order": null,
953
- "overflow": null,
954
- "overflow_x": null,
955
- "overflow_y": null,
956
- "padding": null,
957
- "right": null,
958
- "top": null,
959
- "visibility": null,
960
- "width": null
961
- }
962
- },
963
- "f6d637c3fc3c46928d023441227130e5": {
964
- "model_module": "@jupyter-widgets/controls",
965
- "model_name": "ProgressStyleModel",
966
- "state": {
967
- "_model_module": "@jupyter-widgets/controls",
968
- "_model_module_version": "1.5.0",
969
- "_model_name": "ProgressStyleModel",
970
- "_view_count": null,
971
- "_view_module": "@jupyter-widgets/base",
972
- "_view_module_version": "1.2.0",
973
- "_view_name": "StyleView",
974
- "bar_color": null,
975
- "description_width": "initial"
976
- }
977
- },
978
- "fbb2b937b22049f5987f39f48c652a86": {
979
- "model_module": "@jupyter-widgets/controls",
980
- "model_name": "HBoxModel",
981
- "state": {
982
- "_dom_classes": [],
983
- "_model_module": "@jupyter-widgets/controls",
984
- "_model_module_version": "1.5.0",
985
- "_model_name": "HBoxModel",
986
- "_view_count": null,
987
- "_view_module": "@jupyter-widgets/controls",
988
- "_view_module_version": "1.5.0",
989
- "_view_name": "HBoxView",
990
- "box_style": "",
991
- "children": [
992
- "IPY_MODEL_c136afb47aa14ac2832093ee415c6f3e",
993
- "IPY_MODEL_467a151e73744eccb199fe72aa352e5b"
994
- ],
995
- "layout": "IPY_MODEL_0a1b6b76984349ccb36ca2fc4a4a0208"
996
- }
997
- }
998
- }
999
- }
1000
- },
1001
- "nbformat": 4,
1002
- "nbformat_minor": 4
1003
- }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32679e84a5c7c8148def9a503553b3dddab3529928c3d069b3146f7257e577fe
+ size 795766616
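
The added `flax_model.msgpack` is a Git LFS pointer to roughly 800 MB of Flax weights. Below is a minimal sketch of loading them, again mirroring the deleted evaluation notebook; the checkpoint path is a placeholder and `modeling_hybrid_clip` is assumed to come from the `hybrid_clip` example code.

```python
# Minimal sketch (not part of this commit): load the added Flax weights into
# FlaxHybridCLIP, as the deleted evaluation notebook does. `config` is the
# HybridCLIPConfig built in the previous sketch; the checkpoint path is a placeholder.
from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("./flax_model.msgpack", config=config)

# Text and image features can then be computed as in the notebook, e.g.
#   text_emb = model.get_text_features(input_ids, attention_mask)
#   image_emb = model.get_image_features(pixel_values)
```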
hybrid_clip/README.md DELETED
@@ -1,172 +0,0 @@
1
- <!---
2
- Copyright 2021 The HuggingFace Team. All rights reserved.
3
-
4
- Licensed under the Apache License, Version 2.0 (the "License");
5
- you may not use this file except in compliance with the License.
6
- You may obtain a copy of the License at
7
-
8
- http://www.apache.org/licenses/LICENSE-2.0
9
-
10
- Unless required by applicable law or agreed to in writing, software
11
- distributed under the License is distributed on an "AS IS" BASIS,
12
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- See the License for the specific language governing permissions and
14
- limitations under the License.
15
- -->
16
-
17
- # Vision-Text dual encoder model training examples
18
-
19
- > Note: This example is experimental and might not give the best possible results
20
-
21
- The following example showcases how to train a CLIP like vision-text dual encoder model
22
- using a pre-trained vision and text encoder using the JAX/Flax backend.
23
-
24
- Such a model can be used for natural language image search and potentially zero-shot image classification.
25
- The model is inspired by the [CLIP](https://openai.com/blog/clip/) approach, introduced by Alec Radford et al.
26
- The idea is to train a vision encoder and a text encoder jointly to project the representation of images and their
27
- captions into the same embedding space, such that the caption embeddings are located near the embeddings
28
- of the images they describe.
29
-
30
- JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
31
- Models written in JAX/Flax are **immutable** and updated in a purely functional
32
- way which enables simple and efficient model parallelism.
33
-
34
- In this example we will use the vision model from [CLIP](https://huggingface.co/models?filter=clip)
35
- as the image encoder and [`roberta-base`](https://huggingface.co/roberta-base) as the text encoder.
36
- Note that one can also use the [ViT](https://huggingface.co/models?filter=vit) model as image encoder and any other BERT or ROBERTa model as text encoder.
37
- To train the model on languages other than English one should choose a text encoder trained on the desired
38
- language and a image-text dataset in that language. One such dataset is [WIT](https://github.com/google-research-datasets/wit).
39
-
40
- Let's start by creating a model repository to save the trained model and logs.
41
- Here we call the model `"clip-roberta-base"`, but you can change the model name as you like.
42
-
43
- You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
44
- you are logged in) or via the command line:
45
-
46
- ```
47
- huggingface-cli repo create clip-roberta-base
48
- ```
49
- Next we clone the model repository to add the tokenizer and model files.
50
- ```
51
- git clone https://huggingface.co/<your-username>/clip-roberta-base
52
- ```
53
- To ensure that all tensorboard traces will be uploaded correctly, we need to
54
- track them. You can run the following command inside your model repo to do so.
55
-
56
- ```
57
- cd clip-roberta-base
58
- git lfs track "*tfevents*"
59
- ```
60
-
61
- Great, we have set up our model repository. During training, we will automatically
62
- push the training logs and model weights to the repo.
63
-
64
- Next, let's add a symbolic link to the `run_hybrid_clip.py`.
65
-
66
- ```bash
67
- export MODEL_DIR="./clip-roberta-base"
68
- ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
69
- ```
70
-
71
- ## How to use the `FlaxHybridCLIP` model:
72
-
73
- The `FlaxHybridCLIP` class lets you load any text and vision encoder model to create a dual encoder.
74
- Here is an example of how to load the model using pre-trained text and vision models.
75
-
76
- ```python
77
- from modeling_hybrid_clip import FlaxHybridCLIP
78
-
79
- model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
80
-
81
- # save the model
82
- model.save_pretrained("bert-clip")
83
-
84
- # load the saved model
85
- model = FlaxHybridCLIP.from_pretrained("bert-clip")
86
- ```
87
-
88
- If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the
89
- PyTorch checkpoints, convert them to Flax, and load the resulting model.
90
-
91
- ```python
92
- model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
93
- ```
94
-
95
- This loads both the text and vision encoders using pre-trained weights. The projection layers are randomly
96
- initialized, except when CLIP is used as the vision model: in that case the vision projection weights are also
97
- loaded from the pre-trained checkpoint.
98
-
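After loading, the projected embeddings can be computed with `get_text_features` and `get_image_features` (both defined in `modeling_hybrid_clip.py`). A minimal sketch, assuming the `bert-clip` checkpoint saved above, a BERT tokenizer, and channels-last pixel values of the size expected by the vision config (the dummy image below is only a placeholder):

```python
import jax.numpy as jnp
from transformers import AutoTokenizer
from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("bert-clip")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["a photo of a cat"], padding="max_length", max_length=72, return_tensors="np")
text_features = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])

image_size = model.config.vision_config.image_size
pixel_values = jnp.zeros((1, image_size, image_size, 3), dtype=jnp.float32)  # channels-last, as in training
image_features = model.get_image_features(pixel_values)
```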
99
- ## Prepare the dataset
100
-
101
- We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
102
-
103
- ### Download and extract the data.
104
-
105
- It consists of two compressed folders: one with the images and the other with the associated image captions. Note that the compressed images folder is 13GB in size.
106
-
107
- ```bash
108
- wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
109
- wget http://images.cocodataset.org/zips/train2014.zip
110
-
111
- unzip annotations_trainval2014.zip
112
- unzip train2014.zip
113
-
114
- mkdir coco_dataset
115
- mv train2014 coco_dataset/
116
- mv annotations coco_dataset/
117
- ```
118
-
119
- ### Prepare dataset files and split the dataset.
120
-
121
- ```python
122
- import json
123
- import collections
124
-
125
- images_dir = "coco_dataset/train2014"
126
- annotation_file = "coco_dataset/annotations/captions_train2014.json"
127
- with open(annotation_file, "r") as f:
128
- annotations = json.load(f)["annotations"]
129
-
130
- image_path_to_caption = collections.defaultdict(list)
131
- for element in annotations:
132
- caption = f"{element['caption'].lower().rstrip('.')}"
133
- image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
134
- image_path_to_caption[image_path].append(caption)
135
-
136
- lines = []
137
- for image_path, captions in image_path_to_caption.items():
138
- lines.append(json.dumps({"image_path": image_path, "captions": captions}))
139
-
140
- train_lines = lines[:-8000]
141
- valid_lines = lines[-8000:]
142
- with open("coco_dataset/train_dataset.json", "w") as f:
143
- f.write("\n".join(train_lines))
144
-
145
- with open("coco_dataset/valid_dataset.json", "w") as f:
146
- f.write("\n".join(valid_lines))
147
- ```
148
-
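Each line of the resulting `train_dataset.json` and `valid_dataset.json` is a JSON object with an `image_path` and its list of `captions`, along these lines (the path and captions are made-up examples):

```
{"image_path": "coco_dataset/train2014/COCO_train2014_000000000001.jpg", "captions": ["a dog running on the beach", "a brown dog plays in the sand"]}
```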
149
- > Note: The data loading and processing part of this script can still be improved for maximum performance. In particular, one should decode the images beforehand and use those instead of decoding them each time. If the dataset is small, or if you have plenty of disk space, you could also pre-process the whole dataset beforehand and then use it.
150
-
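One possible way to implement the pre-decoding suggested in the note, sketched with torchvision (the helper name, target-directory handling, and the 224 pixel size are assumptions for illustration):

```python
import os
import torch
from torchvision.io import ImageReadMode, read_image, write_jpeg
from torchvision.transforms import Resize
from torchvision.transforms.functional import InterpolationMode

resize = Resize([224], interpolation=InterpolationMode.BICUBIC)  # roughly matches the training-time Transform

def predecode(src_path, dst_dir):
    # Decode, resize and re-encode one image ahead of training so the data loader has less work to do
    image = read_image(src_path, mode=ImageReadMode.RGB)                   # uint8 tensor, C x H x W
    image = resize(image.float()).round().clamp(0, 255).to(torch.uint8)
    write_jpeg(image, os.path.join(dst_dir, os.path.basename(src_path)))
```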
151
- ## Train the model
152
- Next we can run the example script to train the model:
153
-
154
- ```bash
155
- python run_hybrid_clip.py \
156
- --output_dir ${MODEL_DIR} \
157
- --text_model_name_or_path="roberta-base" \
158
- --vision_model_name_or_path="openai/clip-vit-base-patch32" \
159
- --tokenizer_name="roberta-base" \
160
- --train_file="coco_dataset/train_dataset.json" \
161
- --validation_file="coco_dataset/valid_dataset.json" \
162
- --do_train --do_eval \
163
- --num_train_epochs="40" --max_seq_length 96 \
164
- --per_device_train_batch_size="64" \
165
- --per_device_eval_batch_size="64" \
166
- --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
167
- --overwrite_output_dir \
168
- --preprocessing_num_workers 32 \
169
- --push_to_hub
170
- ```
171
-
172
- This should finish in roughly 1 hour 50 minutes, with a minimum validation loss of 2.43. Training statistics can be viewed on [tensorboard.dev](https://tensorboard.dev/experiment/RUNPYd1yRgSD5kZSb9hDig/#scalars).
 
hybrid_clip/configuration_hybrid_clip.py DELETED
@@ -1,108 +0,0 @@
1
- import copy
2
-
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers.utils import logging
5
-
6
-
7
- logger = logging.get_logger(__name__)
8
-
9
-
10
- class HybridCLIPConfig(PretrainedConfig):
11
- r"""
12
- :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
13
- :class:`~FlaxHybridCLIP` model. It is used to instantiate the model according to the specified arguments,
14
- defining the text model and vision model configs.
15
-
16
- Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
17
- outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
18
-
19
- Args:
20
- text_config_dict (:obj:`dict`):
21
- Dictionary of configuration options that defines text model config.
22
- vision_config_dict (:obj:`dict`):
23
- Dictionary of configuration options that defines vision model config.
24
- projection_dim (:obj:`int`, `optional`, defaults to 512):
25
- Dimensionality of text and vision projection layers.
26
- kwargs (`optional`):
27
- Dictionary of keyword arguments.
28
-
29
- Examples::
30
-
31
- >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
32
-
33
- >>> # Initializing a BERT and CLIP configuration
34
- >>> config_text = BertConfig()
35
- >>> config_vision = CLIPConfig()
36
-
37
- >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
38
-
39
- >>> # Initializing a FlaxHybridCLIP model from the BERT and CLIP configurations
40
- >>> model = FlaxHybridCLIP(config=config)
41
-
42
- >>> # Accessing the model configuration
43
- >>> config_text = model.config.text_config
44
- >>> config_vision = model.config.vision_config
45
-
46
- >>> # Saving the model, including its configuration
47
- >>> model.save_pretrained('my-model')
48
-
49
- >>> # loading model and config from pretrained folder
50
- >>> hybrid_clip_config = HybridCLIPConfig.from_pretrained('my-model')
51
- >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=hybrid_clip_config)
52
- """
53
-
54
- model_type = "hybrid-clip"
55
- is_composition = True
56
-
57
- def __init__(self, projection_dim=512, **kwargs):
58
- super().__init__(**kwargs)
59
-
60
- if "text_config" not in kwargs:
61
- raise ValueError("`text_config` can not be `None`.")
62
-
63
- if "vision_config" not in kwargs:
64
- raise ValueError("`vision_config` can not be `None`.")
65
-
66
- text_config = kwargs.pop("text_config")
67
- vision_config = kwargs.pop("vision_config")
68
-
69
- text_model_type = text_config.pop("model_type")
70
- vision_model_type = vision_config.pop("model_type")
71
-
72
- from transformers import AutoConfig
73
-
74
- self.text_config = AutoConfig.for_model(text_model_type, **text_config)
75
-
76
- if vision_model_type == "clip":
77
- self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
78
- else:
79
- self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
80
-
81
- self.projection_dim = projection_dim
82
- self.initializer_factor = 1.0
83
-
84
- @classmethod
85
- def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
86
- r"""
87
- Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
88
- vision model configuration.
89
-
90
- Returns:
91
- :class:`HybridCLIPConfig`: An instance of a configuration object
92
- """
93
-
94
- return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
95
-
96
- def to_dict(self):
97
- """
98
- Serializes this instance to a Python dictionary. Overrides the default
99
- :meth:`~transformers.PretrainedConfig.to_dict`.
100
-
101
- Returns:
102
- :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
103
- """
104
- output = copy.deepcopy(self.__dict__)
105
- output["text_config"] = self.text_config.to_dict()
106
- output["vision_config"] = self.vision_config.to_dict()
107
- output["model_type"] = self.__class__.model_type
108
- return output
 
hybrid_clip/modeling_hybrid_clip.py DELETED
@@ -1,420 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional, Tuple
17
-
18
- import flax.linen as nn
19
- import jax
20
- import jax.numpy as jnp
21
- from configuration_hybrid_clip import HybridCLIPConfig
22
- from flax.core.frozen_dict import FrozenDict
23
- from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
24
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
25
- from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
26
- from transformers.utils import logging
27
-
28
-
29
- logger = logging.get_logger(__name__)
30
-
31
-
32
- class FlaxHybridCLIPModule(nn.Module):
33
- config: HybridCLIPConfig
34
- dtype: jnp.dtype = jnp.float32
35
-
36
- def setup(self):
37
- text_config = self.config.text_config
38
- vision_config = self.config.vision_config
39
-
40
- self.projection_dim = self.config.projection_dim
41
- self.text_embed_dim = text_config.hidden_size
42
- self.vision_embed_dim = vision_config.hidden_size
43
-
44
- text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
45
- vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
46
-
47
- self.text_model = text_module(text_config, dtype=self.dtype)
48
- self.vision_model = vision_module(vision_config, dtype=self.dtype)
49
-
50
- self.visual_projection = nn.Dense(
51
- self.projection_dim,
52
- dtype=self.dtype,
53
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
54
- use_bias=False,
55
- )
56
- self.text_projection = nn.Dense(
57
- self.projection_dim,
58
- dtype=self.dtype,
59
- kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
60
- use_bias=False,
61
- )
62
- self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
63
-
64
- def __call__(
65
- self,
66
- input_ids=None,
67
- pixel_values=None,
68
- attention_mask=None,
69
- position_ids=None,
70
- token_type_ids=None,
71
- deterministic: bool = True,
72
- output_attentions=None,
73
- output_hidden_states=None,
74
- return_dict=None,
75
- ):
76
- return_dict = return_dict if return_dict is not None else self.config.return_dict
77
-
78
- vision_outputs = self.vision_model(
79
- pixel_values=pixel_values,
80
- deterministic=deterministic,
81
- output_attentions=output_attentions,
82
- output_hidden_states=output_hidden_states,
83
- return_dict=return_dict,
84
- )
85
-
86
- text_outputs = self.text_model(
87
- input_ids=input_ids,
88
- attention_mask=attention_mask,
89
- token_type_ids=token_type_ids,
90
- position_ids=position_ids,
91
- deterministic=deterministic,
92
- output_attentions=output_attentions,
93
- output_hidden_states=output_hidden_states,
94
- return_dict=return_dict,
95
- )
96
-
97
- image_embeds = vision_outputs[1]
98
- image_embeds = self.visual_projection(image_embeds)
99
-
100
- text_embeds = text_outputs[1]
101
- text_embeds = self.text_projection(text_embeds)
102
-
103
- # normalized features
104
- image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
105
- text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
106
-
107
- # cosine similarity as logits
108
- logit_scale = jnp.exp(self.logit_scale)
109
- logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
110
- logits_per_image = logits_per_text.T
111
-
112
- if not return_dict:
113
- return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
114
-
115
- return FlaxCLIPOutput(
116
- logits_per_image=logits_per_image,
117
- logits_per_text=logits_per_text,
118
- text_embeds=text_embeds,
119
- image_embeds=image_embeds,
120
- text_model_output=text_outputs,
121
- vision_model_output=vision_outputs,
122
- )
123
-
124
-
125
- class FlaxHybridCLIP(FlaxPreTrainedModel):
126
- config_class = HybridCLIPConfig
127
- module_class = FlaxHybridCLIPModule
128
-
129
- def __init__(
130
- self,
131
- config: HybridCLIPConfig,
132
- input_shape: Optional[Tuple] = None,
133
- seed: int = 0,
134
- dtype: jnp.dtype = jnp.float32,
135
- **kwargs
136
- ):
137
- if input_shape is None:
138
- input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
139
-
140
- module = self.module_class(config=config, dtype=dtype, **kwargs)
141
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
142
-
143
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
144
- # init input tensor
145
- input_ids = jnp.zeros(input_shape[0], dtype="i4")
146
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
147
- token_type_ids = jnp.ones_like(input_ids)
148
- attention_mask = jnp.ones_like(input_ids)
149
-
150
- pixel_values = jax.random.normal(rng, input_shape[1])
151
-
152
- params_rng, dropout_rng = jax.random.split(rng)
153
- rngs = {"params": params_rng, "dropout": dropout_rng}
154
-
155
- return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
156
-
157
- def __call__(
158
- self,
159
- input_ids,
160
- pixel_values,
161
- attention_mask=None,
162
- position_ids=None,
163
- token_type_ids=None,
164
- params: dict = None,
165
- dropout_rng: jax.random.PRNGKey = None,
166
- train: bool = False,
167
- output_attentions: Optional[bool] = None,
168
- output_hidden_states: Optional[bool] = None,
169
- return_dict: Optional[bool] = None,
170
- ):
171
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
172
- output_hidden_states = (
173
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
174
- )
175
- return_dict = return_dict if return_dict is not None else self.config.return_dict
176
-
177
- if position_ids is None:
178
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
179
-
180
- if token_type_ids is None:
181
- token_type_ids = jnp.zeros_like(input_ids)
182
-
183
- if attention_mask is None:
184
- attention_mask = jnp.ones_like(input_ids)
185
-
186
- # Handle any PRNG if needed
187
- rngs = {}
188
- if dropout_rng is not None:
189
- rngs["dropout"] = dropout_rng
190
-
191
- return self.module.apply(
192
- {"params": params or self.params},
193
- jnp.array(input_ids, dtype="i4"),
194
- jnp.array(pixel_values, dtype=jnp.float32),
195
- jnp.array(attention_mask, dtype="i4"),
196
- jnp.array(position_ids, dtype="i4"),
197
- jnp.array(token_type_ids, dtype="i4"),
198
- not train,
199
- output_attentions,
200
- output_hidden_states,
201
- return_dict,
202
- rngs=rngs,
203
- )
204
-
205
- def get_text_features(
206
- self,
207
- input_ids,
208
- attention_mask=None,
209
- position_ids=None,
210
- token_type_ids=None,
211
- dropout_rng: jax.random.PRNGKey = None,
212
- train=False,
213
- ):
214
- r"""
215
- Args:
216
- input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
217
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
218
- provide it.
219
-
220
- Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
221
- :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
222
- for details.
223
-
224
- `What are input IDs? <../glossary.html#input-ids>`__
225
-
226
- Returns:
227
- text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
228
- obtained by applying the projection layer to the pooled output of the text model.
229
- """
230
- if position_ids is None:
231
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
232
-
233
- if token_type_ids is None:
234
- token_type_ids = jnp.zeros_like(input_ids)
235
-
236
- if attention_mask is None:
237
- attention_mask = jnp.ones_like(input_ids)
238
-
239
- # Handle any PRNG if needed
240
- rngs = {}
241
- if dropout_rng is not None:
242
- rngs["dropout"] = dropout_rng
243
-
244
- def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
245
- text_outputs = module.text_model(
246
- input_ids=input_ids,
247
- attention_mask=attention_mask,
248
- position_ids=position_ids,
249
- token_type_ids=token_type_ids,
250
- deterministic=deterministic,
251
- )
252
- pooled_output = text_outputs[1]
253
- text_features = module.text_projection(pooled_output)
254
- return text_features
255
-
256
- return self.module.apply(
257
- {"params": self.params},
258
- jnp.array(input_ids, dtype="i4"),
259
- jnp.array(attention_mask, dtype="i4"),
260
- jnp.array(position_ids, dtype="i4"),
261
- jnp.array(token_type_ids, dtype="i4"),
262
- not train,
263
- method=_get_features,
264
- rngs=rngs,
265
- )
266
-
267
- def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
268
- r"""
269
- Args:
270
- pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
271
- Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
272
- using :class:`~transformers.ImageFeatureExtractionMixin`. See
273
- :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
274
-
275
- Returns:
276
- image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
277
- obtained by applying the projection layer to the pooled output of the vision model.
278
- """
279
-
280
- # Handle any PRNG if needed
281
- rngs = {}
282
- if dropout_rng is not None:
283
- rngs["dropout"] = dropout_rng
284
-
285
- def _get_features(module, pixel_values, deterministic):
286
- vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
287
- pooled_output = vision_outputs[1] # pooled_output
288
- image_features = module.visual_projection(pooled_output)
289
- return image_features
290
-
291
- return self.module.apply(
292
- {"params": self.params},
293
- jnp.array(pixel_values, dtype=jnp.float32),
294
- not train,
295
- method=_get_features,
296
- rngs=rngs,
297
- )
298
-
299
- @classmethod
300
- def from_text_vision_pretrained(
301
- cls,
302
- text_model_name_or_path: str = None,
303
- vision_model_name_or_path: str = None,
304
- *model_args,
305
- **kwargs,
306
- ) -> FlaxPreTrainedModel:
307
- """
308
- Params:
309
- text_model_name_or_path (:obj: `str`, `optional`):
310
- Information necessary to initiate the text model. Can be either:
311
-
312
- - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
313
- Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
314
- a user or organization name, like ``dbmdz/bert-base-german-cased``.
315
- - A path to a `directory` containing model weights saved using
316
- :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
317
- - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
318
- this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
319
- as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
320
- a Flax model using the provided conversion scripts and loading the Flax model afterwards.
321
-
322
- vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
323
- Information necessary to initiate the vision model. Can be either:
324
-
325
- - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
326
- Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
327
- a user or organization name, like ``dbmdz/bert-base-german-cased``.
328
- - A path to a `directory` containing model weights saved using
329
- :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
330
- - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
331
- this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
332
- as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
333
- a Flax model using the provided conversion scripts and loading the Flax model afterwards.
334
-
335
- model_args (remaining positional arguments, `optional`):
336
- All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
337
-
338
- kwargs (remaining dictionary of keyword arguments, `optional`):
339
- Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g.,
340
- :obj:`output_attentions=True`).
341
-
342
- - To update the text configuration, use the prefix `text_` for each configuration parameter.
343
- - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
344
- - To update the parent model configuration, do not use a prefix for each configuration parameter.
345
-
346
- Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
347
-
348
- Example::
349
-
350
- >>> from transformers import FlaxHybridCLIP
351
- >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
352
- >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
353
- >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
354
- >>> # saving model after fine-tuning
355
- >>> model.save_pretrained("./bert-clip")
356
- >>> # load fine-tuned model
357
- >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
358
- """
359
-
360
- kwargs_text = {
361
- argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
362
- }
363
-
364
- kwargs_vision = {
365
- argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
366
- }
367
-
368
- # remove text, vision kwargs from kwargs
369
- for key in kwargs_text.keys():
370
- del kwargs["text_" + key]
371
- for key in kwargs_vision.keys():
372
- del kwargs["vision_" + key]
373
-
374
- # Load and initialize the text and vision model
375
- text_model = kwargs_text.pop("model", None)
376
- if text_model is None:
377
- assert (
378
- text_model_name_or_path is not None
379
- ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
380
- from transformers import FlaxAutoModel
381
-
382
- if "config" not in kwargs_text:
383
- from transformers import AutoConfig
384
-
385
- text_config = AutoConfig.from_pretrained(text_model_name_or_path)
386
- kwargs_text["config"] = text_config
387
-
388
- text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
389
-
390
- vision_model = kwargs_vision.pop("model", None)
391
- if vision_model is None:
392
- assert (
393
- vision_model_name_or_path is not None
394
- ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
395
- from transformers import FlaxAutoModel
396
-
397
- if "config" not in kwargs_vision:
398
- from transformers import AutoConfig
399
-
400
- vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
401
- kwargs_vision["config"] = vision_config
402
-
403
- vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
404
-
405
- # instantiate config with corresponding kwargs
406
- dtype = kwargs.pop("dtype", jnp.float32)
407
- config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
408
-
409
- # init model
410
- model = cls(config, *model_args, dtype=dtype, **kwargs)
411
-
412
- if vision_model.config.model_type == "clip":
413
- model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
414
- model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
415
- else:
416
- model.params["vision_model"] = vision_model.params
417
-
418
- model.params["text_model"] = text_model.params
419
-
420
- return model
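For a repository like this one, which pairs an Italian text encoder with CLIP's vision tower, the class above would typically be used along the following lines. This is a hedged sketch: the `dbmdz/bert-base-italian-xxl-cased` checkpoint and the output directory name are illustrative choices, not necessarily the ones used to produce the weights in this repo.

```python
from modeling_hybrid_clip import FlaxHybridCLIP

# Pair a (PyTorch) Italian BERT checkpoint with CLIP's vision encoder.
model = FlaxHybridCLIP.from_text_vision_pretrained(
    "dbmdz/bert-base-italian-xxl-cased",
    "openai/clip-vit-base-patch32",
    text_from_pt=True,     # convert the PyTorch BERT weights to Flax
    vision_from_pt=True,   # convert the PyTorch CLIP weights to Flax
)
model.save_pretrained("clip-italian-init")
```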
 
hybrid_clip/requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- jax>=0.2.8
2
- jaxlib>=0.1.59
3
- flax>=0.3.4
4
- optax>=0.0.8
5
- -f https://download.pytorch.org/whl/torch_stable.html
6
- torch==1.9.0+cpu
7
- -f https://download.pytorch.org/whl/torch_stable.html
8
- torchvision==0.10.0+cpu
 
hybrid_clip/run_hybrid_clip.py DELETED
@@ -1,562 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Training CLIP-like dual encoder models using text and vision encoders in the library.
18
-
19
- The script can be used to train CLIP-like models for languages other than English by using
20
- a text encoder pre-trained in the desired language. Currently this script supports the following vision
21
- and text models:
22
- Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
23
- Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
24
- """
25
-
26
- import json
27
- import logging
28
- import os
29
- import sys
30
- import time
31
- from dataclasses import dataclass, field
32
- from pathlib import Path
33
- from typing import Callable, Optional
34
-
35
- import torch
36
- from torchvision.datasets import VisionDataset
37
- from torchvision.io import ImageReadMode, read_image
38
- from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
39
- from torchvision.transforms.functional import InterpolationMode
40
- from tqdm import tqdm
41
-
42
- import jax
43
- import jax.numpy as jnp
44
- import optax
45
- import transformers
46
- from flax import jax_utils
47
- from flax.jax_utils import unreplicate
48
- from flax.training import train_state
49
- from flax.training.common_utils import get_metrics, shard, shard_prng_key
50
- from modeling_hybrid_clip import FlaxHybridCLIP
51
- from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
52
-
53
-
54
- logger = logging.getLogger(__name__)
55
-
56
- # Cache the result
57
- has_tensorboard = is_tensorboard_available()
58
- if has_tensorboard:
59
- try:
60
- from flax.metrics.tensorboard import SummaryWriter
61
- except ImportError as ie:
62
- has_tensorboard = False
63
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
64
-
65
- else:
66
- print(
67
- "Unable to display metrics through TensorBoard because the package is not installed: "
68
- "Please run pip install tensorboard to enable."
69
- )
70
-
71
-
72
- @dataclass
73
- class ModelArguments:
74
- """
75
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
76
- """
77
-
78
- text_model_name_or_path: str = field(
79
- metadata={
80
- "help": "The text model checkpoint for weights initialization."
81
- "Don't set if you want to train a model from scratch."
82
- },
83
- )
84
- vision_model_name_or_path: str = field(
85
- metadata={
86
- "help": "The vision model checkpoint for weights initialization."
87
- "Don't set if you want to train a model from scratch."
88
- },
89
- )
90
- from_pt: bool = field(
91
- default=True,
92
- metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
93
- )
94
- config_name: Optional[str] = field(
95
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
96
- )
97
- tokenizer_name: Optional[str] = field(
98
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99
- )
100
- cache_dir: Optional[str] = field(
101
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
102
- )
103
- use_fast_tokenizer: bool = field(
104
- default=True,
105
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
- )
107
- dtype: Optional[str] = field(
108
- default="float32",
109
- metadata={
110
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
111
- },
112
- )
113
-
114
-
115
- @dataclass
116
- class DataTrainingArguments:
117
- """
118
- Arguments pertaining to what data we are going to input our model for training and eval.
119
- """
120
-
121
- data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
122
- train_file: Optional[str] = field(
123
- default=None, metadata={"help": "The input training data file (a jsonlines file)."}
124
- )
125
- validation_file: Optional[str] = field(
126
- default=None,
127
- metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
128
- )
129
- max_seq_length: Optional[int] = field(
130
- default=72,
131
- metadata={
132
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
133
- "than this will be truncated, sequences shorter will be padded."
134
- },
135
- )
136
- max_train_samples: Optional[int] = field(
137
- default=None,
138
- metadata={
139
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
140
- "value if set."
141
- },
142
- )
143
- max_eval_samples: Optional[int] = field(
144
- default=None,
145
- metadata={
146
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
147
- "value if set."
148
- },
149
- )
150
- overwrite_cache: bool = field(
151
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
152
- )
156
- preprocessing_num_workers: Optional[int] = field(
157
- default=None,
158
- metadata={"help": "The number of processes to use for the preprocessing."},
159
- )
160
-
161
- def __post_init__(self):
162
- if self.train_file is None and self.validation_file is None:
163
- raise ValueError("Need either a dataset name or a training/validation file.")
164
- else:
165
- if self.train_file is not None:
166
- extension = self.train_file.split(".")[-1]
167
- assert extension == "json", "`train_file` should be a json file."
168
- if self.validation_file is not None:
169
- extension = self.validation_file.split(".")[-1]
170
- assert extension == "json", "`validation_file` should be a json file."
171
-
172
-
173
- # We use torchvision for faster image pre-processing.
174
- # We need to ensure faster processing speed as it can become a bottleneck on TPU
175
- class Transform(torch.nn.Module):
176
- def __init__(self, image_size):
177
- super().__init__()
178
- self.transforms = torch.nn.Sequential(
179
- Resize([image_size], interpolation=InterpolationMode.BICUBIC),
180
- CenterCrop(image_size),
181
- ConvertImageDtype(torch.float),
182
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
183
- )
184
-
185
- def forward(self, x: torch.Tensor) -> torch.Tensor:
186
- with torch.no_grad():
187
- x = self.transforms(x)
188
- return x
189
-
190
-
191
- class ImageTextDataset(VisionDataset):
192
- """
193
- Dataset for loading image-text data for tasks like CLIP training and image captioning.
194
-
195
- Args:
196
- root: (string): The root path where the dataset is stored
197
- file_path: (string): Path to the file containing the image_paths and associated captions.
198
- The expected format is jsonlines where each line is a json object containing two keys:
199
- `image_path`: The path to the image.
200
- `captions`: An `array` of captions.
201
- transform (callable, optional): A function/transform that takes in a decoded image
202
- and returns a transformed version. E.g., ``transforms.ToTensor``
203
- target_transform (callable, optional): A function/transform that takes in the
204
- target and transforms it.
205
- transforms (callable, optional): A function/transform that takes input sample and its target as entry
206
- and returns a transformed version.
207
- """
208
-
209
- def __init__(
210
- self,
211
- root: str,
212
- file_path: str,
213
- captions_per_image=2,
214
- transform: Optional[Callable] = None,
215
- target_transform: Optional[Callable] = None,
216
- transforms: Optional[Callable] = None,
217
- ):
218
- super().__init__(root, transforms, transform, target_transform)
219
-
220
- with open(file_path, "r") as f:
221
- examples = [json.loads(line) for line in f.readlines()]
222
-
223
- self.captions = []
224
- self.image_paths = []
225
-
226
- for example in examples:
227
- self.captions.extend(example["captions"][:captions_per_image])
228
- self.image_paths.extend([example["image_path"]] * captions_per_image)
229
-
230
- def _load_image(self, idx: int):
231
- path = self.image_paths[idx]
232
- return read_image(path, mode=ImageReadMode.RGB)
233
-
234
- def _load_target(self, idx):
235
- return self.captions[idx]
236
-
237
- def __getitem__(self, index: int):
238
- image = self._load_image(index)
239
- target = self._load_target(index)
240
-
241
- if self.transforms is not None:
242
- image, target = self.transforms(image, target)
243
-
244
- return image, target
245
-
246
- def __len__(self) -> int:
247
- return len(self.captions)
248
-
249
-
250
- class TrainState(train_state.TrainState):
251
- dropout_rng: jnp.ndarray
252
-
253
- def replicate(self):
254
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
255
-
256
-
257
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
258
- summary_writer.scalar("train_time", train_time, step)
259
-
260
- train_metrics = get_metrics(train_metrics)
261
- for key, vals in train_metrics.items():
262
- tag = f"train_{key}"
263
- for i, val in enumerate(vals):
264
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
265
-
266
- for metric_name, value in eval_metrics.items():
267
- summary_writer.scalar(f"eval_{metric_name}", value, step)
268
-
269
-
270
- def create_learning_rate_fn(
271
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
272
- ) -> Callable[[int], jnp.array]:
273
- """Returns a linear warmup, linear_decay learning rate function."""
274
- steps_per_epoch = train_ds_size // train_batch_size
275
- num_train_steps = steps_per_epoch * num_train_epochs
276
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
277
- decay_fn = optax.linear_schedule(
278
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
279
- )
280
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
281
- return schedule_fn
282
-
283
-
284
- def main():
285
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
286
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
287
- # If we pass only one argument to the script and it's the path to a json file,
288
- # let's parse it to get our arguments.
289
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
290
- else:
291
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
292
-
293
- if (
294
- os.path.exists(training_args.output_dir)
295
- and os.listdir(training_args.output_dir)
296
- and training_args.do_train
297
- and not training_args.overwrite_output_dir
298
- ):
299
- raise ValueError(
300
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
301
- "Use --overwrite_output_dir to overcome."
302
- )
303
-
304
- # Make one log on every process with the configuration for debugging.
305
- logging.basicConfig(
306
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
307
- datefmt="%m/%d/%Y %H:%M:%S",
308
- level=logging.INFO,
309
- )
310
- # Setup logging, we only want one process per machine to log things on the screen.
311
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
312
- if jax.process_index() == 0:
313
- transformers.utils.logging.set_verbosity_info()
314
- else:
315
- transformers.utils.logging.set_verbosity_error()
316
-
317
- # Set the verbosity to info of the Transformers logger (on main process only):
318
- logger.info(f"Training/evaluation parameters {training_args}")
319
-
320
- if model_args.tokenizer_name:
321
- tokenizer = AutoTokenizer.from_pretrained(
322
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
323
- )
324
- elif model_args.text_model_name_or_path:
325
- tokenizer = AutoTokenizer.from_pretrained(
326
- model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
327
- )
328
- else:
329
- raise ValueError(
330
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
331
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
332
- )
333
-
334
- model = FlaxHybridCLIP.from_text_vision_pretrained(
335
- model_args.text_model_name_or_path,
336
- model_args.vision_model_name_or_path,
337
- seed=training_args.seed,
338
- dtype=getattr(jnp, model_args.dtype),
339
- text_from_pt=model_args.from_pt,
340
- vision_from_pt=model_args.from_pt,
341
- )
342
- config = model.config
343
- # set seed for torch dataloaders
344
- set_seed(training_args.seed)
345
-
346
- # Initialize torchvision transforms and jit them for faster processing
347
- preprocess = Transform(config.vision_config.image_size)
348
- preprocess = torch.jit.script(preprocess)
349
-
350
- # Initialize the image-text dataset
351
- train_dataset = ImageTextDataset(
352
- data_args.data_dir,
353
- data_args.train_file,
354
- captions_per_image=2,
355
- transform=preprocess,
356
- )
357
-
358
- eval_dataset = ImageTextDataset(
359
- data_args.data_dir,
360
- data_args.validation_file,
361
- captions_per_image=1,
362
- transform=preprocess,
363
- )
364
-
365
- # Store some constant
366
- num_epochs = int(training_args.num_train_epochs)
367
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
368
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
369
- steps_per_epoch = len(train_dataset) // train_batch_size
370
- total_train_steps = steps_per_epoch * num_epochs
371
-
372
- # Use a collate function to tokenize the text and convert the processed images to numpy
373
- def collate_fn(examples):
374
- pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
375
- captions = [example[1] for example in examples]
376
- inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
377
-
378
- batch = {
379
- "pixel_values": pixel_values,
380
- "input_ids": inputs["input_ids"],
381
- "attention_mask": inputs["attention_mask"],
382
- }
383
-
384
- return batch
385
-
386
- # Create data loaders
387
- train_loader = torch.utils.data.DataLoader(
388
- train_dataset,
389
- batch_size=train_batch_size,
390
- shuffle=True,
391
- num_workers=data_args.preprocessing_num_workers,
392
- persistent_workers=True,
393
- drop_last=True,
394
- collate_fn=collate_fn,
395
- )
396
-
397
- eval_loader = torch.utils.data.DataLoader(
398
- eval_dataset,
399
- batch_size=eval_batch_size,
400
- shuffle=False,
401
- num_workers=data_args.preprocessing_num_workers,
402
- persistent_workers=True,
403
- drop_last=True,
404
- collate_fn=collate_fn,
405
- )
406
-
407
- # Enable tensorboard only on the master node
408
- if has_tensorboard and jax.process_index() == 0:
409
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
410
-
411
- # Initialize our training
412
- rng = jax.random.PRNGKey(training_args.seed)
413
- rng, dropout_rng = jax.random.split(rng)
414
-
415
- # Create learning rate schedule
416
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
417
- len(train_dataset),
418
- train_batch_size,
419
- training_args.num_train_epochs,
420
- training_args.warmup_steps,
421
- training_args.learning_rate,
422
- )
423
-
424
- # create adam optimizer
425
- adamw = optax.adamw(
426
- learning_rate=linear_decay_lr_schedule_fn,
427
- b1=training_args.adam_beta1,
428
- b2=training_args.adam_beta2,
429
- eps=training_args.adam_epsilon,
430
- weight_decay=training_args.weight_decay,
431
- )
432
-
433
- # Setup train state
434
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
435
-
436
- def cross_entropy(logits, axis):
437
- logprobs = jax.nn.log_softmax(logits, axis=axis)
438
- nll = jnp.diag(logprobs)
439
- ce = -jnp.mean(nll)
440
- return ce
441
-
442
- def clip_loss(similarity):
443
- loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
444
- return loss
445
-
446
- # Define gradient update step fn
447
- def train_step(state, batch):
448
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
449
-
450
- def compute_loss(params):
451
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
452
- loss = clip_loss(logits)
453
- return loss
454
-
455
- grad_fn = jax.value_and_grad(compute_loss)
456
- loss, grad = grad_fn(state.params)
457
- grad = jax.lax.pmean(grad, "batch")
458
-
459
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
460
-
461
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
462
- metrics = jax.lax.pmean(metrics, axis_name="batch")
463
-
464
- return new_state, metrics
465
-
466
- # Define eval fn
467
- def eval_step(params, batch):
468
- logits = model(**batch, params=params, train=False)[0]
469
- loss = clip_loss(logits)
470
-
471
- # summarize metrics
472
- metrics = {"loss": loss}
473
- metrics = jax.lax.pmean(metrics, axis_name="batch")
474
- return metrics
475
-
476
- # Create parallel version of the train and eval step
477
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
478
- p_eval_step = jax.pmap(eval_step, "batch")
479
-
480
- # Replicate the train state on each device
481
- state = state.replicate()
482
-
483
- logger.info("***** Running training *****")
484
- logger.info(f" Num examples = {len(train_dataset)}")
485
- logger.info(f" Num Epochs = {num_epochs}")
486
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
487
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
488
- logger.info(f" Total optimization steps = {total_train_steps}")
489
-
490
- train_time = 0
491
- # Create sampling rng
492
- rng, input_rng = jax.random.split(rng)
493
-
494
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
495
- for epoch in epochs:
496
- # ======================== Training ================================
497
- train_start = time.time()
498
-
499
- # Create sampling rng
500
- rng, input_rng = jax.random.split(rng)
501
- train_metrics = []
502
-
503
- steps_per_epoch = len(train_dataset) // train_batch_size
504
- train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
505
- # train
506
- for batch in train_loader:
507
- batch = shard(batch)
508
- state, train_metric = p_train_step(state, batch)
509
- train_metrics.append(train_metric)
510
-
511
- train_step_progress_bar.update(1)
512
-
513
- train_time += time.time() - train_start
514
-
515
- train_metric = unreplicate(train_metric)
516
-
517
- train_step_progress_bar.close()
518
- epochs.write(
519
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
520
- )
521
-
522
- # ======================== Evaluating ==============================
523
- eval_metrics = []
524
- eval_steps = len(eval_dataset) // eval_batch_size
525
- eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
526
- for batch in eval_loader:
527
- # Model forward
528
- batch = shard(batch)
529
- metrics = p_eval_step(state.params, batch)
530
- eval_metrics.append(metrics)
531
-
532
- eval_step_progress_bar.update(1)
533
-
534
- # normalize eval metrics
535
- eval_metrics = get_metrics(eval_metrics)
536
-
537
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
538
-
539
- # Print metrics and update progress bar
540
- eval_step_progress_bar.close()
541
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
542
- epochs.write(desc)
543
- epochs.desc = desc
544
-
545
- # Save metrics
546
- if has_tensorboard and jax.process_index() == 0:
547
- cur_step = epoch * (len(train_dataset) // train_batch_size)
548
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
549
-
550
- # save checkpoint after each epoch and push checkpoint to the hub
551
- if jax.process_index() == 0:
552
- params = jax.device_get(unreplicate(state.params))
553
- model.save_pretrained(
554
- training_args.output_dir,
555
- params=params,
556
- push_to_hub=training_args.push_to_hub,
557
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
558
- )
559
-
560
-
561
- if __name__ == "__main__":
562
- main()
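As a quick sanity check of the linear warmup plus linear decay schedule built by `create_learning_rate_fn` above, the following self-contained probe reproduces the same optax schedule with made-up numbers (5e-5 peak learning rate, 500 warmup steps, 10,000 total steps):

```python
import optax

lr, warmup, total = 5e-5, 500, 10_000
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup)
decay_fn = optax.linear_schedule(init_value=lr, end_value=0.0, transition_steps=total - warmup)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup])

print(schedule(0))       # 0.0, start of warmup
print(schedule(warmup))  # ~5e-5, peak value at the end of warmup
print(schedule(total))   # ~0.0, fully decayed
```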