Wa2erGo commited on
Commit
360475e
·
1 Parent(s): 9507e9c

Update README.md

Browse files

Fix "device" and "pixel_values" not defined

Files changed (1) hide show
  1. README.md +3 -1
README.md CHANGED
@@ -22,6 +22,8 @@ import requests
22
  from PIL import Image
23
  import re
24
 
 
 
25
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
26
  image = Image.open(requests.get(url, stream=True).raw)
27
  text = "a bunch of [MASK] laying on a [MASK]."
@@ -44,7 +46,7 @@ with torch.no_grad():
44
  encoded = processor.tokenizer(inferred_token)
45
  input_ids = torch.tensor(encoded.input_ids).to(device)
46
  encoded = encoded["input_ids"][0][1:-1]
47
- outputs = model(input_ids=input_ids, pixel_values=pixel_values)
48
  mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
49
  # only take into account text features (minus CLS and SEP token)
50
  mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
 
22
  from PIL import Image
23
  import re
24
 
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
28
  image = Image.open(requests.get(url, stream=True).raw)
29
  text = "a bunch of [MASK] laying on a [MASK]."
 
46
  encoded = processor.tokenizer(inferred_token)
47
  input_ids = torch.tensor(encoded.input_ids).to(device)
48
  encoded = encoded["input_ids"][0][1:-1]
49
+ outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
50
  mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
51
  # only take into account text features (minus CLS and SEP token)
52
  mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]