hamdanhh07 committed
Commit 5f07799
1 Parent(s): 9f308b5

Update UltraSound-Lung.py

Files changed (1)
  1. models/hamdan07/UltraSound-Lung.py +146 -0
models/hamdan07/UltraSound-Lung.py CHANGED
@@ -19,3 +19,149 @@ predicted_class_idx = logits.argmax(-1).item()
  print("Predicted class:", model.config.id2label[predicted_class_idx])
  API_URL = "https://api-inference.huggingface.co/models/hamdan07/UltraSound-Lung"
  headers = {"Authorization": "Bearer hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
+ # Clone repository and pull latest changes.
+ ![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
+ !cd vision_transformer && git pull
+
+ # Helper functions for images.
+
+ labelnames = dict(
+     # https://www.cs.toronto.edu/~kriz/cifar.html
+     cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
+     # https://www.cs.toronto.edu/~kriz/cifar.html
+     cifar100=('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'),
+ )
+
+ def make_label_getter(dataset):
+   """Returns a function converting label indices to names."""
+   def getter(label):
+     if dataset in labelnames:
+       return labelnames[dataset][label]
+     return f'label={label}'
+   return getter
+
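A quick usage sketch (not part of the commit) exercising both branches of the returned getter:

    get_name = make_label_getter('cifar10')
    print(get_name(3))   # 'cat' (index 3 of the CIFAR-10 tuple above)
    other = make_label_getter('imagenet')  # hypothetical key not in labelnames
    print(other(3))      # falls back to 'label=3'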
+ def show_img(img, ax=None, title=None):
+   """Shows a single image."""
+   if ax is None:
+     ax = plt.gca()
+   ax.imshow(img[...])
+   ax.set_xticks([])
+   ax.set_yticks([])
+   if title:
+     ax.set_title(title)
+
+ def show_img_grid(imgs, titles):
+   """Shows a grid of images."""
+   n = int(np.ceil(len(imgs)**.5))
+   _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
+   for i, (img, title) in enumerate(zip(imgs, titles)):
+     img = (img + 1) / 2  # Denormalize
+     show_img(img, axs[i // n][i % n], title)
+
+ # For details about setting up datasets, see input_pipeline.py on the right.
+ ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
+ ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
+
+ del config  # Only needed to instantiate datasets.
+
+ # Fetch a batch of test images for illustration purposes.
+ batch = next(iter(ds_test.as_numpy_iterator()))
+ # Note the shape: [num_local_devices, local_batch_size, h, w, c]
+ batch['image'].shape
+
+ # Show some images with their labels.
+ images, labels = batch['image'][0][:9], batch['label'][0][:9]
+ titles = map(make_label_getter(dataset), labels.argmax(axis=1))
+ show_img_grid(images, titles)
+
+ # Same as above, but with train images.
+ # Note how the images are cropped/scaled differently: check out
+ # input_pipeline.get_data() in the editor on the right to see how the
+ # train images are preprocessed.
+ batch = next(iter(ds_train.as_numpy_iterator()))
+ images, labels = batch['image'][0][:9], batch['label'][0][:9]
+ titles = map(make_label_getter(dataset), labels.argmax(axis=1))
+ show_img_grid(images, titles)
+
+ model_config = models_config.MODEL_CONFIGS[model_name]
+ model_config
+
+ # Load model definition & initialize random parameters.
+ # This also compiles the model to XLA (takes some minutes the first time).
+ if model_name.startswith('Mixer'):
+   model = models.MlpMixer(num_classes=num_classes, **model_config)
+ else:
+   model = models.VisionTransformer(num_classes=num_classes, **model_config)
+ variables = jax.jit(lambda: model.init(
+     jax.random.PRNGKey(0),
+     # Discard the "num_local_devices" dimension of the batch for initialization.
+     batch['image'][0, :1],
+     train=False,
+ ), backend='cpu')()
+
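Not part of the commit: one way to sanity-check the freshly initialized parameters is to map a shape query over the pytree (assuming a recent JAX where jax.tree_util.tree_map is available):

    import jax
    print(jax.tree_util.tree_map(lambda a: a.shape, variables['params']))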
+ # Load and convert pretrained checkpoint.
+ # This involves loading the actual pre-trained model results, but then also
+ # modifying the parameters a bit, e.g. changing the final layers, and resizing
+ # the positional embeddings.
+ # For details, refer to the code and to the methods of the paper.
+ params = checkpoint.load_pretrained(
+     pretrained_path=f'{model_name}.npz',
+     init_params=variables['params'],
+     model_config=model_config,
+ )
+
+ # So far, all our data is in host memory. Let's now replicate the arrays
+ # onto the devices.
+ # This will make every array in the pytree params become a ShardedDeviceArray
+ # that has the same data replicated across all local devices.
+ # For TPU it replicates the params in every core.
+ # For a single GPU this simply moves the data onto the device.
+ # For CPU it simply creates a copy.
+ params_repl = flax.jax_utils.replicate(params)
+ print('params.cls:', type(params['head']['bias']).__name__,
+       params['head']['bias'].shape)
+ print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
+       params_repl['head']['bias'].shape)
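As a sanity check (again not in the commit), flax.jax_utils.replicate should add a leading axis equal to the local device count:

    import jax
    assert params_repl['head']['bias'].shape == (
        jax.local_device_count(),) + params['head']['bias'].shape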
+ # Then map the call to our model's forward pass onto all available devices.
+ vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
+     dict(params=params), inputs, train=False))
+
+ def get_accuracy(params_repl):
+   """Returns accuracy evaluated on the test set."""
+   good = total = 0
+   steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
+   for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
+     predicted = vit_apply_repl(params_repl, batch['image'])
+     is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
+     good += is_same.sum()
+     total += len(is_same.flatten())
+   return good / total
+
+ # Random performance without fine-tuning.
+ get_accuracy(params_repl)
+
+ # 100 steps take approximately 15 minutes in the TPU runtime.
+ total_steps = 100
+ warmup_steps = 5
+ decay_type = 'cosine'
+ grad_norm_clip = 1
+ # This controls how many forward passes the batch is split into. 8 works well
+ # with a TPU runtime that has 8 devices. 64 should work on a GPU. You can of
+ # course also adjust the batch_size above, but that would require you to adjust
+ # the learning rate accordingly.
+ accum_steps = 8
+ base_lr = 0.03
+
+ # Check out train.make_update_fn in the editor on the right side for details.
+ lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
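Optionally (not in the commit), the warmup-plus-cosine schedule can be inspected by evaluating lr_fn over the step range:

    import matplotlib.pyplot as plt
    plt.plot([lr_fn(step) for step in range(total_steps)])
    plt.xlabel('step')
    plt.ylabel('learning rate')
    plt.show()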
+ # We use a momentum optimizer that keeps its state in half precision to save
+ # memory. It also implements gradient clipping.
+ tx = optax.chain(
+     optax.clip_by_global_norm(grad_norm_clip),
+     optax.sgd(
+         learning_rate=lr_fn,
+         momentum=0.9,
+         accumulator_dtype='bfloat16',
+     ),
+ )
+ update_fn_repl = train.make_update_fn(
+     apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
+ opt_state = tx.init(params)
+ opt_state_repl = flax.jax_utils.replicate(opt_state)
+
+ # Initialize PRNGs for dropout.
+ update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
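These replicated objects feed the actual fine-tuning loop; a sketch of that loop, following the upstream vit_jax Colab this file mirrors (not part of this commit):

    losses, lrs = [], []
    for step, batch in zip(tqdm.trange(1, total_steps + 1),
                           ds_train.as_numpy_iterator()):
      params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
          params_repl, opt_state_repl, batch, update_rng_repl)
      losses.append(loss_repl[0])
      lrs.append(lr_fn(step))

    get_accuracy(params_repl)  # accuracy after fine-tuning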