Split code in blocks
README.md
CHANGED
@@ -51,31 +51,36 @@ Fine-tuning RAD-DINO is typically not necessary to obtain good performance in do
 <!-- This section is meant to convey both technical and sociotechnical limitations. -->
 
 RAD-DINO was trained with data from three countries, therefore it might be biased towards population in the training data.
-Underlying biases of the training datasets may not be well characterised.
+Underlying biases of the training datasets may not be well characterized.
 
 ## Getting started
 
+Let us first write an auxiliary function to download a chest X-ray.
+
 ```python
->>> import
+>>> import requests
 >>> from PIL import Image
->>> from transformers import AutoModel
->>> from transformers import AutoImageProcessor
->>>
->>> # Define a small function to get a sample image
 >>> def download_sample_image() -> Image.Image:
 ...     """Download chest X-ray with CC license."""
-...     import requests
-...     from PIL import Image
 ...     base_url = "https://upload.wikimedia.org/wikipedia/commons"
 ...     image_url = f"{base_url}/2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
-...     headers = {"User-Agent": "
+...     headers = {"User-Agent": "RAD-DINO"}
 ...     response = requests.get(image_url, headers=headers, stream=True)
 ...     return Image.open(response.raw)
 ...
+```
+
+Now let us download the model and encode an image.
+
+```python
+>>> import torch
+>>> from transformers import AutoModel
+>>> from transformers import AutoImageProcessor
+>>>
 >>> # Download the model
 >>> repo = "microsoft/rad-dino"
 >>> model = AutoModel.from_pretrained(repo)
-
+>>>
 >>> # The processor takes a PIL image, performs resizing, center-cropping, and
 >>> # intensity normalization using stats from MIMIC-CXR, and returns a
 >>> # dictionary with a PyTorch tensor ready for the encoder
@@ -95,8 +100,12 @@ Underlying biases of the training datasets may not be well characterised.
 >>> cls_embeddings = outputs.pooler_output
 >>> cls_embeddings.shape  # (batch_size, num_channels)
 torch.Size([1, 768])
-
-
+```
+
+If we are interested in the feature maps, we can reshape the patch embeddings into a grid.
+We will use [`einops`](https://einops.rocks/) (install with `pip install einops`) for this.
+
+```python
 >>> def reshape_patch_embeddings(flat_tokens: torch.Tensor) -> torch.Tensor:
 ...     """Reshape flat list of patch tokens into a nice grid."""
 ...     from einops import rearrange
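The README lines between the two hunks, where the sample image is processed and passed through the encoder, are elided by the diff. Below is a rough sketch of how those pieces typically fit together with the `transformers` API; the variable names and exact calls are assumptions for orientation, not the card's elided code.

```python
# Sketch only: the README lines elided between the two hunks are not shown in
# this diff, so names like `image`, `inputs`, and `outputs` are illustrative.
image = download_sample_image()
processor = AutoImageProcessor.from_pretrained(repo)

# The processor resizes, center-crops, and normalizes the image, returning a
# dict with a "pixel_values" tensor ready for the encoder.
inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    outputs = model(**inputs)

cls_embeddings = outputs.pooler_output                 # (batch_size, num_channels)
flat_patch_tokens = outputs.last_hidden_state[:, 1:]   # assuming token 0 is CLS
```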
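The diff also stops inside `reshape_patch_embeddings`, right after the `einops` import. Below is a minimal sketch of how such a reshape can be completed with `einops.rearrange`, assuming a square patch grid; the dummy 37 × 37 input is only an illustration, while the 768-channel width matches the shape shown in the hunk above.

```python
import math

import torch
from einops import rearrange


def reshape_patch_embeddings(flat_tokens: torch.Tensor) -> torch.Tensor:
    """Reshape flat list of patch tokens into a nice grid."""
    # Assume a square grid: the side length is the integer square root of
    # the number of patch tokens.
    grid_size = math.isqrt(flat_tokens.shape[1])
    return rearrange(flat_tokens, "b (h w) c -> b c h w", h=grid_size, w=grid_size)


# Dummy tokens with an assumed 37 x 37 grid and the 768 channels from above.
dummy_tokens = torch.rand(1, 37 * 37, 768)
print(reshape_patch_embeddings(dummy_tokens).shape)  # torch.Size([1, 768, 37, 37])
```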