# NOTE(review): scraped Hugging Face Spaces page header ("Spaces / Runtime
# error") removed -- it was not Python and broke parsing. The Space reported
# a runtime error at capture time.
import os
import subprocess

import gradio as gr
import requests
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    ViTFeatureExtractor,
    ViTForImageClassification,
)
# --- Hugging Face inference setup ---------------------------------------
# NOTE(review): the Auto* classes and the ViT* classes below resolve to the
# same checkpoint, so the first pair of loads is redundant; kept for
# compatibility with code that may reference `extractor`.
extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Ultrasound lung dataset. BUG FIX: the original did
# `Image.open(requests.get(dataset, stream=True).raw)`, but `dataset` is a
# DatasetDict, not a URL -- that call always raised. Take a sample image from
# the dataset instead (datasets' Image feature decodes to a PIL image --
# TODO confirm the split/column names against the hub dataset card).
dataset = load_dataset("hamdan07/UltraSound-lung")
image = dataset["train"][0]["image"]

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# Index of the highest-scoring class for the sample image.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

# Hosted Inference API endpoint for the fine-tuned model.
API_URL = "https://api-inference.huggingface.co/models/hamdan07/UltraSound-Lung"
# SECURITY FIX: a live bearer token was hard-coded here. Read it from the
# environment instead; the leaked token should be revoked.
headers = {"Authorization": "Bearer " + os.environ.get("HF_API_TOKEN", "")}
# Clone the vision_transformer repository (shallow) and pull latest changes.
# FIX: the original used IPython `!` shell magics, which are a syntax error in
# a plain .py file. Use subprocess with argument lists (shell=False) instead.
if not os.path.isdir("vision_transformer"):
    subprocess.run(
        ["git", "clone", "--depth=1",
         "https://github.com/google-research/vision_transformer"],
        check=True,
    )
subprocess.run(["git", "pull"], cwd="vision_transformer", check=True)
# Helper data for images: per-dataset tuples mapping class index -> name.
# (FIX: removed trailing ` | |` scraping artifacts that made these lines
# syntax errors.)
labelnames = dict(
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar100=('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'),
)
def make_label_getter(dataset):
    """Returns a function converting label indices to names.

    Args:
      dataset: dataset key, e.g. 'cifar10' -- looked up in module-level
        `labelnames`.

    Returns:
      A callable mapping an integer label index to its string name, or to
      a generic 'label=<i>' string for datasets without a name table.
    """
    def getter(label):
        if dataset in labelnames:
            return labelnames[dataset][label]
        return f'label={label}'
    return getter
def show_img(img, ax=None, title=None):
    """Shows a single image on a matplotlib axis, hiding the tick marks.

    Args:
      img: array-like image to display.
      ax: matplotlib axis; defaults to the current one.
        NOTE(review): relies on a module-level `plt` (matplotlib.pyplot)
        import that is not visible in this chunk -- confirm it exists.
      title: optional axis title.
    """
    if ax is None:
        ax = plt.gca()
    ax.imshow(img[...])
    ax.set_xticks([])
    ax.set_yticks([])
    if title:
        ax.set_title(title)
def show_img_grid(imgs, titles):
    """Shows a square grid of images with titles.

    Images are assumed normalized to [-1, 1] and are denormalized to [0, 1]
    before display. Relies on module-level `np` and `plt` -- not visible in
    this chunk; confirm they are imported.
    """
    n = int(np.ceil(len(imgs)**.5))  # smallest n with n*n >= len(imgs)
    _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
    for i, (img, title) in enumerate(zip(imgs, titles)):
        img = (img + 1) / 2  # Denormalize
        show_img(img, axs[i // n][i % n], title)
# --- Fetch datasets and preview a few examples ---------------------------
# For details about setting up datasets, see input_pipeline.py on the right.
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
del config  # Only needed to instantiate datasets.

# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape: [num_local_devices, local_batch_size, h, w, c]
batch['image'].shape

# Show some images with their labels.
# NOTE(review): `make_label_getter(dataset)` expects a string key like
# 'cifar10', but earlier in this file `dataset` was rebound to a DatasetDict
# -- confirm which `dataset` is intended here.
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)

# Same as above, but with train images.
# Note how images are cropped/scaled differently.
# Check out input_pipeline.get_data() in the editor at your right to see how
# the images are preprocessed differently.
batch = next(iter(ds_train.as_numpy_iterator()))
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)
# --- Build the model and initialize parameters ---------------------------
model_config = models_config.MODEL_CONFIGS[model_name]
model_config  # NOTE(review): bare expression only displays in a notebook.

# Load model definition & initialize random parameters.
# This also compiles the model to XLA (takes some minutes the first time).
if model_name.startswith('Mixer'):
    model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
    model = models.VisionTransformer(num_classes=num_classes, **model_config)

variables = jax.jit(lambda: model.init(
    jax.random.PRNGKey(0),
    # Discard the "num_local_devices" dimension of the batch for init.
    batch['image'][0, :1],
    train=False,
), backend='cpu')()
# --- Load pretrained checkpoint and replicate onto devices ---------------
# Load and convert pretrained checkpoint. This involves loading the actual
# pre-trained model results, but then also modifying the parameters a bit,
# e.g. changing the final layers, and resizing the positional embeddings.
# For details, refer to the code and to the methods of the paper.
params = checkpoint.load_pretrained(
    pretrained_path=f'{model_name}.npz',
    init_params=variables['params'],
    model_config=model_config,
)

# So far, all our data is in the host memory. Let's now replicate the arrays
# into the devices: every array in the pytree becomes a ShardedDeviceArray
# with the same data on all local devices.
# - TPU: replicates the params in every core.
# - single GPU: simply moves the data onto the device.
# - CPU: simply creates a copy.
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
      params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
      params_repl['head']['bias'].shape)
# Map the model's forward pass onto all available devices (inference mode).
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False))
def get_accuracy(params_repl):
    """Returns accuracy evaluated on the test set.

    Iterates the test set for a fixed number of steps, runs the pmapped
    forward pass and compares argmax predictions against one-hot labels.
    NOTE(review): relies on module-level `vit_apply_repl`, `input_pipeline`,
    `dataset`, `batch_size`, `ds_test` and `tqdm` -- confirm all are defined
    at call time.
    """
    good = total = 0
    steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
    for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
        predicted = vit_apply_repl(params_repl, batch['image'])
        is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
        good += is_same.sum()
        total += len(is_same.flatten())
    return good / total
# --- Fine-tuning setup ----------------------------------------------------
# Random performance without fine-tuning.
get_accuracy(params_repl)

# 100 steps take approximately 15 minutes in the TPU runtime.
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls in how many forward passes the batch is split. 8 works well
# with a TPU runtime that has 8 devices. 64 should work on a GPU. You can of
# course also adjust the batch_size above, but that would require you to
# adjust the learning rate accordingly.
accum_steps = 8
base_lr = 0.03

# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
tx = optax.chain(
    optax.clip_by_global_norm(grad_norm_clip),
    optax.sgd(
        learning_rate=lr_fn,
        momentum=0.9,
        accumulator_dtype='bfloat16',
    ),
)
update_fn_repl = train.make_update_fn(
    apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state)

# Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))