egeozsoy committed
Commit
100a0c2
1 Parent(s): 74033b8

Create README.md

Files changed (1):
  1. README.md +103 -0
README.md ADDED
@@ -0,0 +1,103 @@
---
license: apache-2.0
tags:
- medical
---

<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->

<div align="center">
<h1>
EndoViT
</h1>
</div>

<p align="center">
<a href="https://link.springer.com/article/10.1007/s11548-024-03091-5" target="_blank">Paper</a> | <a href="https://github.com/DominikBatic/EndoViT" target="_blank">GitHub</a>
</p>

## Get Started

This section provides a quick-start example for using the EndoViT model.

Installation:

```bash
pip install torch==2.0.1 timm==0.9.16 huggingface-hub==0.22.2
```

Extracting features from a list of images (this can also serve as a starting point for using EndoViT as a backbone):

```python
import torch
import torchvision.transforms as T
from PIL import Image
from pathlib import Path
from timm.models.vision_transformer import VisionTransformer
from functools import partial
from torch import nn
from huggingface_hub import snapshot_download


def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.2280, 0.2228], dataset_std=[0.2520, 0.2128, 0.2093]):
    # Define the transformations (mean/std are the EndoViT pretraining dataset statistics)
    transform = T.Compose([
        T.Resize((input_size, input_size)),
        T.ToTensor(),
        T.Normalize(mean=dataset_mean, std=dataset_std)
    ])

    # Open the image
    image = Image.open(image_path).convert('RGB')

    # Apply the transformations
    processed_image = transform(image)

    return processed_image


def load_model_from_huggingface(repo_id, model_filename):
    # Download the model files from the Hugging Face Hub
    model_path = snapshot_download(repo_id=repo_id, revision="main")
    model_weights_path = Path(model_path) / model_filename

    # Load the checkpoint; map to CPU first so this also works on machines without a GPU
    model_weights = torch.load(model_weights_path, map_location="cpu")['model']

    # Define the model, a ViT-B/16 (ensure this matches the checkpoint's architecture)
    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).eval()

    # Load the weights into the model
    loading = model.load_state_dict(model_weights, strict=False)

    return model, loading


image_paths = sorted(Path('demo_images').glob('*.png'))  # TODO: replace with your image paths
images = torch.stack([process_single_image(image_path) for image_path in image_paths])

# Fall back to CPU and full precision if no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
model, loading_info = load_model_from_huggingface("egeozsoy/EndoViT", "endovit.pth")
model = model.to(device, dtype)
print(loading_info)  # reports any missing or unexpected keys from loading the checkpoint
output = model.forward_features(images.to(device, dtype))
print(output.shape)  # torch.Size([num_images, 197, 768]): CLS token + 196 patch tokens per image
```
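
With timm 0.9.x, `forward_features` returns the full token sequence: for 224x224 inputs and patch size 16 this is `[num_images, 197, 768]`, one CLS token followed by 196 patch tokens per image. Below is a minimal sketch of how these tokens could be pooled into per-image embeddings or fed to a downstream head when using EndoViT as a backbone; the head and `num_classes` are hypothetical placeholders, not part of EndoViT:

```python
from torch import nn

# A minimal sketch, assuming `model`, `output`, `device`, and `dtype` from the snippet above.
# Token 0 is the CLS token; the remaining 196 tokens correspond to image patches.
cls_embedding = output[:, 0]               # [num_images, 768] per-image embedding
patch_embedding = output[:, 1:].mean(dim=1)  # [num_images, 768] mean-pooled patch tokens

# Hypothetical downstream head (num_classes is a placeholder, not part of EndoViT)
num_classes = 7
head = nn.Linear(768, num_classes).to(device, dtype)
logits = head(cls_embedding)               # [num_images, num_classes]
```

Whether the CLS token or mean-pooled patch tokens work better as the image representation is task-dependent, so it can be worth trying both.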

## ✏️ Citation

```bibtex
@article{batic2024endovit,
  title={EndoViT: pretraining vision transformers on a large collection of endoscopic images},
  author={Bati{\'c}, Dominik and Holm, Felix and {\"O}zsoy, Ege and Czempiel, Tobias and Navab, Nassir},
  journal={International Journal of Computer Assisted Radiology and Surgery},
  pages={1--7},
  year={2024},
  publisher={Springer}
}
```