Mehdi Cherti committed on
Commit
19988d9
1 Parent(s): 428234f

generation example

Browse files
Files changed (2) hide show
  1. example.png +0 -0
  2. gen.py +39 -0
example.png ADDED
gen.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Example: image caption generation with a CoCa-BioGPT model via open_clip.

Loads a pretrained ``coca_biogpt_vitb16`` checkpoint, preprocesses
``example.png``, generates a caption with beam search, and prints the
decoded token sequence(s).
"""
import open_clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_biogpt_vitb16",
    pretrained="coca_biogpt_vitb16.pt",
)
model.to(device)
model.eval()

nb = 1  # number of copies of the image to caption in one batch
path = "example.png"
im = Image.open(path).convert("RGB")
im = transform(im).unsqueeze(0)
im = im.to(device)
im = im.repeat(nb, 1, 1, 1)
print(im.shape)

tokenizer = open_clip.get_tokenizer("coca_biogpt_vitb16")
# Not every open_clip tokenizer wraps an HF tokenizer; guard the debug print
# the same way the decode loop below does (the original accessed it
# unconditionally, which would crash for non-HF tokenizers).
if hasattr(tokenizer, "tokenizer"):
    print(tokenizer.tokenizer)

with torch.no_grad():
    generated = model.generate(
        im,
        pad_token_id=1,   # special-token ids of the BioGPT vocabulary
        eos_token_id=2,
        sot_token_id=0,
        max_seq_len=256,
        seq_len=60,
        generation_type='beam_search',
    )
print(generated)
for i in range(nb):
    # HF-backed tokenizers expose the underlying decoder; otherwise fall
    # back to open_clip's built-in decode.
    if hasattr(tokenizer, "tokenizer"):
        print(tokenizer.tokenizer.decode(generated[i]))
    else:
        print(open_clip.decode(generated[i]))