ybelkada commited on
Commit
bf544e9
1 Parent(s): 8b5fd15

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -85,7 +85,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
85
 
86
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
87
 
88
- inputs = processor(images=image, text=text, return_tensors="pt")
89
 
90
  predictions = model.generate(**inputs)
91
  print(processor.decode(predictions[0], skip_special_tokens=True))
@@ -108,7 +108,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
108
 
109
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
110
 
111
- inputs = processor(images=image, text=text, return_tensors="pt").to("cuda")
112
 
113
  predictions = model.generate(**inputs)
114
  print(processor.decode(predictions[0], skip_special_tokens=True))
@@ -133,7 +133,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
133
 
134
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
135
 
136
- inputs = processor(images=image, text=text, return_tensors="pt").to("cuda", torch.bfloat16)
137
 
138
  predictions = model.generate(**inputs)
139
  print(processor.decode(predictions[0], skip_special_tokens=True))
 
85
 
86
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
87
 
88
+ inputs = processor(images=image, text=question, return_tensors="pt")
89
 
90
  predictions = model.generate(**inputs)
91
  print(processor.decode(predictions[0], skip_special_tokens=True))
 
108
 
109
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
110
 
111
+ inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
112
 
113
  predictions = model.generate(**inputs)
114
  print(processor.decode(predictions[0], skip_special_tokens=True))
 
133
 
134
  question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
135
 
136
+ inputs = processor(images=image, text=question, return_tensors="pt").to("cuda", torch.bfloat16)
137
 
138
  predictions = model.generate(**inputs)
139
  print(processor.decode(predictions[0], skip_special_tokens=True))