import gradio as gr
import requests
from datasets import load_dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image

# Load the ViT feature extractor and classification model from the Hub.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Load the lung-ultrasound dataset and take one sample image for a quick check
# (assumes a "train" split with an "image" column).
ultrasound_ds = load_dataset("hamdan07/UltraSound-lung")
image = ultrasound_ds["train"][0]["image"]

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

API_URL = "https://api-inference.huggingface.co/models/hamdan07/UltraSound-Lung"
headers = {"Authorization": "Bearer hf_xxx"}  # Use your own Hugging Face API token here.

# Clone repository and pull latest changes.
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull

# Imports for the fine-tuning part. The vit_jax modules come from the cloned
# repository; `config`, `dataset` (a TFDS dataset name such as 'cifar10'),
# `batch_size`, `model_name` and `num_classes` are assumed to be defined in an
# earlier setup cell.
import sys
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

import flax
import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
import tqdm

from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import models
from vit_jax import train
from vit_jax import utils
from vit_jax.configs import models as models_config

# Helper functions for images.
labelnames = dict(
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
             'horse', 'ship', 'truck'),
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar100=('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
              'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
              'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
              'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
              'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
              'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
              'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard',
              'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle',
              'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
              'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
              'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit',
              'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal',
              'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
              'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
              'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
              'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
              'willow_tree', 'wolf', 'woman', 'worm'),
)

def make_label_getter(dataset):
  """Returns a function converting label indices to names."""
  def getter(label):
    if dataset in labelnames:
      return labelnames[dataset][label]
    return f'label={label}'
  return getter

def show_img(img, ax=None, title=None):
  """Shows a single image."""
  if ax is None:
    ax = plt.gca()
  ax.imshow(img[...])
  ax.set_xticks([])
  ax.set_yticks([])
  if title:
    ax.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a grid of images."""
  n = int(np.ceil(len(imgs)**.5))
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
  for i, (img, title) in enumerate(zip(imgs, titles)):
    img = (img + 1) / 2  # Denormalize
    show_img(img, axs[i // n][i % n], title)

# For details about setting up datasets, see input_pipeline.py on the right.
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
del config  # Only needed to instantiate datasets.
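# The API_URL and headers defined earlier in this file are declared but never
# used. The short sketch below shows one way they could be used to query the
# hosted model through the Hugging Face Inference API, which accepts raw image
# bytes and returns a list of {"label", "score"} dicts. The helper name and the
# sample file name are illustrative placeholders, not part of the original code.
def query_inference_api(image_path):
  """Hypothetical helper: sends an image file to the Inference API."""
  with open(image_path, "rb") as f:
    data = f.read()
  response = requests.post(API_URL, headers=headers, data=data)
  response.raise_for_status()
  return response.json()  # e.g. [{"label": "...", "score": 0.97}, ...]

# Example usage (hypothetical file name):
# print(query_inference_api("sample_ultrasound.png"))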
# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape: [num_local_devices, local_batch_size, h, w, c]
batch['image'].shape

# Show some images with their labels.
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)

# Same as above, but with train images.
# Note how the images are cropped/scaled differently.
# Check out input_pipeline.get_data() in the editor on your right to see how the
# images are preprocessed differently.
batch = next(iter(ds_train.as_numpy_iterator()))
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)

model_config = models_config.MODEL_CONFIGS[model_name]
model_config

# Load model definition & initialize random parameters.
# This also compiles the model to XLA (takes some minutes the first time).
if model_name.startswith('Mixer'):
  model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
  model = models.VisionTransformer(num_classes=num_classes, **model_config)
variables = jax.jit(lambda: model.init(
    jax.random.PRNGKey(0),
    # Discard the "num_local_devices" dimension of the batch for initialization.
    batch['image'][0, :1],
    train=False,
), backend='cpu')()

# Load and convert the pretrained checkpoint.
# This involves loading the actual pre-trained model results, but then also
# modifying the parameters a bit, e.g. changing the final layers and resizing
# the positional embeddings.
# For details, refer to the code and to the methods of the paper.
params = checkpoint.load_pretrained(
    pretrained_path=f'{model_name}.npz',
    init_params=variables['params'],
    model_config=model_config,
)

# So far, all our data is in host memory. Let's now replicate the arrays onto
# the devices.
# This will make every array in the pytree params become a ShardedDeviceArray
# that has the same data replicated across all local devices.
# For TPU it replicates the params in every core.
# For a single GPU this simply moves the data onto the device.
# For CPU it simply creates a copy.
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
      params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
      params_repl['head']['bias'].shape)

# Then map the call to our model's forward pass onto all available devices.
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False))

def get_accuracy(params_repl):
  """Returns accuracy evaluated on the test set."""
  good = total = 0
  steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
  for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
    predicted = vit_apply_repl(params_repl, batch['image'])
    is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
    good += is_same.sum()
    total += len(is_same.flatten())
  return good / total

# Random performance without fine-tuning.
get_accuracy(params_repl)

# 100 steps take approximately 15 minutes in the TPU runtime.
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls into how many forward passes the batch is split. 8 works well
# with a TPU runtime that has 8 devices. 64 should work on a GPU. You can of
# course also adjust batch_size above, but that would require you to adjust the
# learning rate accordingly.
accum_steps = 8
base_lr = 0.03

# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
tx = optax.chain(
    optax.clip_by_global_norm(grad_norm_clip),
    optax.sgd(
        learning_rate=lr_fn,
        momentum=0.9,
        accumulator_dtype='bfloat16',
    ),
)

update_fn_repl = train.make_update_fn(
    apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state)

# Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
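# At this point everything needed for fine-tuning exists: replicated params,
# replicated optimizer state, and a replicated update function. Below is a
# minimal sketch of the training loop, not part of this section. It assumes the
# function returned by train.make_update_fn takes (params, opt_state, batch,
# rng) and returns the updated (params, opt_state, loss, rng); check train.py
# in the repository for the exact signature before relying on it.
losses = []
lrs = []
for step, batch in zip(tqdm.trange(1, total_steps + 1),
                       ds_train.as_numpy_iterator()):
  # Assumed call convention of update_fn_repl; see the note above.
  params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
      params_repl, opt_state_repl, batch, update_rng_repl)
  losses.append(loss_repl[0])  # The loss is replicated per device; keep one copy.
  lrs.append(lr_fn(step))      # Track the learning rate schedule for inspection.

# Accuracy on the test set after fine-tuning, for comparison with the
# random-initialization number computed earlier.
get_accuracy(params_repl)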