# UltraSound-Lung.py (hamdan07/UltraSound-Lung)
import os

import gradio as gr
import requests
from datasets import load_dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Load the ImageNet-pretrained ViT checkpoint once (feature extractor + model).
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Load the ultrasound dataset and take a single example image from it.
# The split and column names below are assumptions for a typical image dataset.
dataset = load_dataset("hamdan07/UltraSound-lung")
image = dataset["train"][0]["image"]
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
API_URL = "https://api-inference.huggingface.co/models/hamdan07/UltraSound-Lung"
headers = {"Authorization": "Bearer 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"}
# Clone repository and pull latest changes.
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull
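# Setup for the vit_jax cells below. This mirrors the upstream
# vision_transformer colab; the concrete choices (model_name, dataset,
# batch_size, crop size) are assumptions matching its defaults and may need
# adjusting for the ultrasound data.
import sys
if './vision_transformer' not in sys.path:
    sys.path.append('./vision_transformer')

# Install the repo's Python dependencies (path from the upstream repo layout).
!pip install -qr vision_transformer/vit_jax/requirements.txt

import flax
import flax.jax_utils
import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
import tqdm

from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import models
from vit_jax import train
from vit_jax import utils
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

# NOTE: this reuses the name `dataset` for the TFDS dataset id expected by the
# vit_jax cells below, shadowing the Hugging Face dataset loaded above.
model_name = 'ViT-B_32'  # assumed; any key of models_config.MODEL_CONFIGS works
dataset = 'cifar10'      # assumed; used below by make_label_getter() and get_dataset_info()
batch_size = 512         # assumed; reduce on smaller accelerators
num_classes = 10         # CIFAR-10; derive from the dataset info for other datasets

config = common_config.with_dataset(common_config.get_config(), dataset)
config.batch = batch_size
config.pp.crop = 224     # assumed input resolution

# Download the pre-trained checkpoint referenced further below (GCS path as
# used by the upstream colab for ImageNet-21k weights).
![ -e "$model_name".npz ] || gsutil cp "gs://vit_models/imagenet21k/$model_name.npz" .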
# Helper functions for images.
labelnames = dict(
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar100=('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'),
)
def make_label_getter(dataset):
    """Returns a function converting label indices to names."""
    def getter(label):
        if dataset in labelnames:
            return labelnames[dataset][label]
        return f'label={label}'
    return getter


def show_img(img, ax=None, title=None):
    """Shows a single image."""
    if ax is None:
        ax = plt.gca()
    ax.imshow(img[...])
    ax.set_xticks([])
    ax.set_yticks([])
    if title:
        ax.set_title(title)


def show_img_grid(imgs, titles):
    """Shows a grid of images."""
    n = int(np.ceil(len(imgs)**.5))
    _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
    for i, (img, title) in enumerate(zip(imgs, titles)):
        img = (img + 1) / 2  # Denormalize
        show_img(img, axs[i // n][i % n], title)
# For details about setting up datasets, see input_pipeline.py on the right.
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
del config # Only needed to instantiate datasets.
# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape: [num_local_devices, local_batch_size, h, w, c]
batch['image'].shape
# Show some images with their labels.
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)
# Same as above, but with train images.
# Note how images are cropped/scaled differently.
# Check out input_pipeline.get_data() in the editor at your right to see how the
# images are preprocessed differently.
batch = next(iter(ds_train.as_numpy_iterator()))
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)
model_config = models_config.MODEL_CONFIGS[model_name]
model_config
# Load model definition & initialize random parameters.
# This also compiles the model to XLA (takes some minutes the first time).
if model_name.startswith('Mixer'):
    model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
    model = models.VisionTransformer(num_classes=num_classes, **model_config)

variables = jax.jit(lambda: model.init(
    jax.random.PRNGKey(0),
    # Discard the "num_local_devices" dimension of the batch for initialization.
    batch['image'][0, :1],
    train=False,
), backend='cpu')()
# Load and convert pretrained checkpoint.
# This involves loading the actual pre-trained model weights, but then also
# modifying the parameters a bit, e.g. changing the final layers, and resizing
# the positional embeddings.
# For details, refer to the code and to the methods of the paper.
params = checkpoint.load_pretrained(
    pretrained_path=f'{model_name}.npz',
    init_params=variables['params'],
    model_config=model_config,
)
# So far, all our data is in host memory. Let's now replicate the arrays
# onto the devices.
# This will make every array in the pytree params become a ShardedDeviceArray
# that has the same data replicated across all local devices.
# For TPU it replicates the params in every core.
# For a single GPU this simply moves the data onto the device.
# For CPU it simply creates a copy.
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
      params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
      params_repl['head']['bias'].shape)

# Then map the call to our model's forward pass onto all available devices.
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False))
def get_accuracy(params_repl):
    """Returns accuracy evaluated on the test set."""
    good = total = 0
    steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
    for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
        predicted = vit_apply_repl(params_repl, batch['image'])
        is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
        good += is_same.sum()
        total += len(is_same.flatten())
    return good / total
# Random performance without fine-tuning.
get_accuracy(params_repl)
# 100 Steps take approximately 15 minutes in the TPU runtime.
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls how many forward passes the batch is split into. 8 works well
# with a TPU runtime that has 8 devices. 64 should work on a GPU. You can of
# course also adjust the batch_size above, but that would require you to adjust
# the learning rate accordingly.
accum_steps = 8
base_lr = 0.03
# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
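# Illustrative sanity check of the schedule shape: the learning rate should
# ramp up over the warmup steps and then decay to ~0 by total_steps.
print([float(lr_fn(step)) for step in (0, warmup_steps, total_steps // 2, total_steps)])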
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
tx = optax.chain(
    optax.clip_by_global_norm(grad_norm_clip),
    optax.sgd(
        learning_rate=lr_fn,
        momentum=0.9,
        accumulator_dtype='bfloat16',
    ),
)
update_fn_repl = train.make_update_fn(
    apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state)
# Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
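# A minimal sketch of the fine-tuning loop this setup leads into, following the
# upstream notebook. It assumes update_fn_repl returns the updated replicated
# params, optimizer state, per-device loss, and PRNG state, in that order.
losses = []
lrs = []
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):
    params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
        params_repl, opt_state_repl, batch, update_rng_repl)
    losses.append(loss_repl[0])
    lrs.append(lr_fn(step))

# Accuracy after fine-tuning (compare with the value printed above).
get_accuracy(params_repl)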