Transformers
Safetensors
ijepa
Inference Endpoints
jmtzt commited on
Commit
7929d14
1 Parent(s): 110f8d6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +20 -13
README.md CHANGED
@@ -31,27 +31,34 @@ I-JEPA can be used for image classification or feature extraction. This checkpoi
31
 
32
  ## How to use
33
 
34
- Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
35
 
36
  ```python
37
  import requests
38
-
39
  from PIL import Image
40
- from transformers import AutoProcessor, IJepaForImageClassification
 
 
41
 
42
- url = "http://images.cocodataset.org/val2017/000000039769.jpg"
43
- image = Image.open(requests.get(url, stream=True).raw)
 
 
44
 
45
  model_id = "jmtzt/ijepa_vith14_1k"
46
  processor = AutoProcessor.from_pretrained(model_id)
47
- model = IJepaForImageClassification.from_pretrained(model_id)
48
-
49
- inputs = processor(images=image, return_tensors="pt")
50
- outputs = model(**inputs)
51
- logits = outputs.logits
52
- # model predicts one of the 1000 ImageNet classes
53
- predicted_class_idx = logits.argmax(-1).item()
54
- print("Predicted class:", model.config.id2label[predicted_class_idx])
 
 
 
 
55
  ```
56
 
57
  ### BibTeX entry and citation info
 
31
 
32
  ## How to use
33
 
34
+ Here is how to use this model for image feature extraction:
35
 
36
  ```python
37
  import requests
 
38
  from PIL import Image
39
+ from torch.nn.functional import cosine_similarity
40
+
41
+ from transformers import AutoModel, AutoProcessor
42
 
43
+ url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
44
+ url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
45
+ image_1 = Image.open(requests.get(url_1, stream=True).raw)
46
+ image_2 = Image.open(requests.get(url_2, stream=True).raw)
47
 
48
  model_id = "jmtzt/ijepa_vith14_1k"
49
  processor = AutoProcessor.from_pretrained(model_id)
50
+ model = AutoModel.from_pretrained(model_id)
51
+
52
+ def infer(image):
53
+ inputs = processor(image, return_tensors="pt")
54
+ outputs = model(**inputs)
55
+ return outputs.pooler_output
56
+
57
+ embed_1 = infer(image_1)
58
+ embed_2 = infer(image_2)
59
+
60
+ similarity = cosine_similarity(embed_1, embed_2)
61
+ print(similarity)
62
  ```
63
 
64
  ### BibTeX entry and citation info