pamixsun commited on
Commit
3d72fc8
1 Parent(s): 4bee0c2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +15 -8
README.md CHANGED
@@ -58,22 +58,29 @@ Use the code below to get started with the model.
58
  import cv2
59
  import torch
60
 
61
- from transformers import AutoImageProcessor, Swinv2ForImageClassification
62
 
63
  image = cv2.imread('./example.jpg')
64
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
65
 
66
- processor = AutoImageProcessor.from_pretrained("pamixsun/swinv2_tiny_for_glaucoma_classification")
67
- model = Swinv2ForImageClassification.from_pretrained("pamixsun/swinv2_tiny_for_glaucoma_classification")
68
 
69
  inputs = processor(image, return_tensors="pt")
70
 
71
  with torch.no_grad():
72
- logits = model(**inputs).logits
73
-
74
- # model predicts either glaucoma or non-glaucoma.
75
- predicted_label = logits.argmax(-1).item()
76
- print(model.config.id2label[predicted_label])
 
 
 
 
 
 
 
77
 
78
  ```
79
 
 
58
  import cv2
59
  import torch
60
 
61
+ from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
62
 
63
  image = cv2.imread('./example.jpg')
64
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
65
 
66
+ processor = AutoImageProcessor.from_pretrained("pamixsun/segformer_for_optic_disc_cup_segmentation")
67
+ model = SegformerForSemanticSegmentation.from_pretrained("pamixsun/segformer_for_optic_disc_cup_segmentation")
68
 
69
  inputs = processor(image, return_tensors="pt")
70
 
71
  with torch.no_grad():
72
+ inputs.to(self.device)
73
+ outputs = self.seg_model(**inputs)
74
+ logits = outputs.logits.cpu()
75
+
76
+ upsampled_logits = nn.functional.interpolate(
77
+ logits,
78
+ size=image.shape[:2],
79
+ mode="bilinear",
80
+ align_corners=False,
81
+ )
82
+
83
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)
84
 
85
  ```
86