timm
PyTorch
medical
Image Feature Extraction
Ege Oezsoy commited on
Commit
74033b8
1 Parent(s): 372b042

Initial model commit

Browse files
Files changed (4) hide show
  1. README.MD +0 -0
  2. endovit.pth +3 -0
  3. endovit_demo.py +39 -0
  4. requirements.txt +2 -0
README.MD ADDED
File without changes
endovit.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec41888fd928eb404e518d61344c25822b3c4a776f997b876db040b86a5ba21a
3
+ size 1341188699
endovit_demo.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ from timm.models.vision_transformer import VisionTransformer
6
+ from functools import partial
7
+ from torch import nn
8
+
9
+ # requires: pytorch 2.0.1, timm 0.9.16
10
+ def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.2280, 0.2228], dataset_std=[0.2520, 0.2128, 0.2093]):
11
+ # Define the transformations
12
+ transform = T.Compose([
13
+ T.Resize((input_size, input_size)),
14
+ T.ToTensor(),
15
+ T.Normalize(mean=dataset_mean, std=dataset_std)
16
+ ])
17
+
18
+ # Open the image
19
+ image = Image.open(image_path).convert('RGB')
20
+
21
+ # Apply the transformations
22
+ processed_image = transform(image)
23
+
24
+ return processed_image
25
+
26
+
27
+ image_paths = sorted(Path('demo_images').glob('*.png'))
28
+ images = torch.stack([process_single_image(image_path) for image_path in image_paths])
29
+
30
+ device = "cuda"
31
+ dtype = torch.float16
32
+
33
+ model_weights = torch.load('endovit_seg.pth')['model']
34
+
35
+ model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).to(device, dtype).eval()
36
+ loading = model.load_state_dict(model_weights, strict=False)
37
+ print(loading)
38
+ output = model.forward_features(images.to(device, dtype))
39
+ print(output.shape)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch==2.0.1
2
+ timm==0.9.16