import torch
def generate_text(model, image, tokenizer, image_transform, max_length=30):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    # model = model.to(device)
    temperature = 0.9
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    # Preprocess the image and add a batch dimension.
    image = image_transform(image)
    img_tensor = image.unsqueeze(0)  # .to(device)

    # Encode the image with CLIP and project the embedding into the GPT
    # embedding space to form the prefix that conditions generation.
    images_embedding = model.clip(img_tensor)
    images_projection = model.mapping_network(images_embedding).view(
        -1, model.max_length, model.gpt_embedding_size
    )
    input_state = images_projection

    with torch.no_grad():
        for _ in range(max_length):
            # Run GPT on the current embedding sequence and take the
            # next-token distribution from the last position.
            outputs = model.gpt(input_state, None).logits
            next_token_scores = outputs[0, -1, :].detach().div(temperature).softmax(dim=0)
            # next_token_id = np.random.choice(len(next_token_scores), p=next_token_scores.cpu().numpy())
            next_token_id = next_token_scores.argmax(dim=0).item()  # greedy decoding
            if next_token_id == stop_token_id:
                break
            output_ids.append(next_token_id)

            # Append the embedding of the chosen token and continue decoding.
            next_token_id = torch.tensor([next_token_id]).unsqueeze(0)  # .to(device)
            next_token_embed = model.gpt.base_network.transformer.wte(next_token_id)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)
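

# Usage sketch (assumptions, not part of the original file): a trained
# captioning model exposing the `clip`, `mapping_network`, and `gpt`
# submodules used above, a GPT-2 tokenizer with a pad token set (it serves
# as the stop token), and a CLIP-style image preprocessing transform.
# The file and variable names below are illustrative.
#
# from PIL import Image
# from transformers import GPT2Tokenizer
#
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# tokenizer.pad_token = tokenizer.eos_token  # ensure pad_token_id is defined
# image = Image.open("example.jpg").convert("RGB")
# caption = generate_text(model, image, tokenizer, image_transform, max_length=30)
# print(caption)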