flaviagiammarino commited on
Commit
c81c5ce
1 Parent(s): 58b1a78

add examples

Browse files
Files changed (2) hide show
  1. scripts/pt_example.py +21 -0
  2. scripts/tf_example.py +22 -0
scripts/pt_example.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+
5
+ from transformers import CLIPProcessor, CLIPModel
6
+
7
+ model = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
8
+ processor = CLIPProcessor.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
9
+
10
+ url = "https://encrypted-tbn3.gstatic.com/images?q=tbn:ANd9GcSjP8UWzpGqXKwlC1zPRhcJOXThfI4pXgg2Zhd1B-cstpnEDalY"
11
+ image = Image.open(requests.get(url, stream=True).raw)
12
+ text = ["Chest X-Ray", "Brain MRI"]
13
+
14
+ inputs = processor(text=text, images=image, return_tensors="pt", padding=True)
15
+ probs = model(**inputs).logits_per_image.softmax(dim=1).detach().numpy().flatten()
16
+
17
+ plt.subplots()
18
+ plt.imshow(image)
19
+ plt.title("".join([x[0] + ": " + x[1] + " " for x in zip(text, [format(prob, ".4%") for prob in probs])]))
20
+ plt.axis("off")
21
+ plt.show()
scripts/tf_example.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import tensorflow as tf
5
+
6
+ from transformers import CLIPProcessor, TFCLIPModel
7
+
8
+ model = TFCLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
9
+ processor = CLIPProcessor.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
10
+
11
+ url = "https://encrypted-tbn3.gstatic.com/images?q=tbn:ANd9GcSjP8UWzpGqXKwlC1zPRhcJOXThfI4pXgg2Zhd1B-cstpnEDalY"
12
+ image = Image.open(requests.get(url, stream=True).raw)
13
+ text = ["Chest X-Ray", "Brain MRI"]
14
+
15
+ inputs = processor(text=text, images=image, return_tensors="tf", padding=True)
16
+ probs = tf.nn.softmax(model(**inputs).logits_per_image, axis=-1).numpy().flatten()
17
+
18
+ plt.subplots()
19
+ plt.imshow(image)
20
+ plt.title("".join([x[0] + ": " + x[1] + " " for x in zip(text, [format(prob, ".4%") for prob in probs])]))
21
+ plt.axis("off")
22
+ plt.show()