JustinLin610 committed on
Commit
2a4f282
1 Parent(s): 98e784b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -22,7 +22,7 @@ After, refer the path to OFA-huge to `ckpt_dir`, and prepare an image for the te
22
  >>> from generate import sequence_generator
23
 
24
  >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
25
- >>> resolution = 256
26
  >>> patch_resize_transform = transforms.Compose([
27
  lambda image: image.convert("RGB"),
28
  transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
@@ -30,7 +30,7 @@ After, refer the path to OFA-huge to `ckpt_dir`, and prepare an image for the te
30
  transforms.Normalize(mean=mean, std=std)
31
  ])
32
 
33
- >>> model = OFAModel.from_pretrained(ckpt_dir)
34
  >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
35
 
36
  >>> txt = " what does the image describe?"
@@ -40,16 +40,16 @@ After, refer the path to OFA-huge to `ckpt_dir`, and prepare an image for the te
40
 
41
 
42
  >>> # using the generator of fairseq version
43
- >>> generator = sequence_generator.SequenceGenerator(tokenizer=tokenizer,beam_size=5,
44
- max_len_b=16,
45
- min_len=0,
46
- no_repeat_ngram_size=3) # using the generator of fairseq version
47
  >>> data = {}
48
  >>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
49
  >>> gen_output = generator.generate([model], data)
50
  >>> gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))]
51
 
52
  >>> # using the generator of huggingface version
 
53
  >>> gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
54
 
55
  >>> print(tokenizer.batch_decode(gen, skip_special_tokens=True))
 
22
  >>> from generate import sequence_generator
23
 
24
  >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
25
+ >>> resolution = 480
26
  >>> patch_resize_transform = transforms.Compose([
27
  lambda image: image.convert("RGB"),
28
  transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
 
30
  transforms.Normalize(mean=mean, std=std)
31
  ])
32
 
33
+
34
  >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
35
 
36
  >>> txt = " what does the image describe?"
 
40
 
41
 
42
  >>> # using the generator of fairseq version
43
+ >>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
44
+ >>> generator = sequence_generator.SequenceGenerator(tokenizer=tokenizer,beam_size=5, max_len_b=16,
45
+ min_len=0, no_repeat_ngram_size=3) # using the generator of fairseq version
 
46
  >>> data = {}
47
  >>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
48
  >>> gen_output = generator.generate([model], data)
49
  >>> gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))]
50
 
51
  >>> # using the generator of huggingface version
52
+ >>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
53
  >>> gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
54
 
55
  >>> print(tokenizer.batch_decode(gen, skip_special_tokens=True))