Littlehongman's picture
First version success
4cea813
raw
history blame
No virus
1.43 kB
import torch
def generate_text(model, image, tokenizer, image_transfrom, max_length=30):
    """Greedily generate a caption for *image* with a CLIP-prefix GPT model.

    The image is encoded by ``model.clip``, projected into a prefix of
    ``model.max_length`` GPT embedding vectors by ``model.mapping_network``,
    and then decoded token-by-token (greedy argmax) with ``model.gpt`` until
    the tokenizer's pad token is produced or *max_length* tokens are emitted.

    Args:
        model: object exposing ``clip``, ``mapping_network``, ``gpt`` (with
            ``base_network.transformer.wte`` as its token-embedding table),
            plus attributes ``max_length`` (prefix length) and
            ``gpt_embedding_size``.
        image: raw image accepted by ``image_transfrom``.
        tokenizer: provides ``pad_token_id`` (used as the stop token) and
            ``decode``.
        image_transfrom: preprocessing transform applied to ``image``.
            NOTE(review): parameter name keeps its original misspelling for
            backward compatibility with keyword callers.
        max_length: maximum number of tokens to generate.

    Returns:
        The decoded caption string (empty if the first token is the stop token).
    """
    model.eval()
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    # Run the entire pipeline without autograd. The original encoded the image
    # outside torch.no_grad(), needlessly tracking gradients during inference.
    with torch.no_grad():
        img_tensor = image_transfrom(image).unsqueeze(0)  # add batch dim
        image_embedding = model.clip(img_tensor)
        # Project the CLIP embedding into a (batch, prefix_len, gpt_dim) prefix.
        input_state = model.mapping_network(image_embedding).view(
            -1, model.max_length, model.gpt_embedding_size
        )
        for _ in range(max_length):
            logits = model.gpt(input_state, None).logits
            # Greedy decoding: argmax over the last position's logits.
            # (The original divided by a temperature and applied softmax first;
            # both are monotonic and cannot change the argmax, so they were
            # dead computation left over from a commented-out sampling line.)
            next_token_id = logits[0, -1, :].argmax().item()
            if next_token_id == stop_token_id:
                break
            output_ids.append(next_token_id)
            # Feed the chosen token back in via GPT's own embedding table.
            next_token = torch.tensor([[next_token_id]])
            next_token_embed = model.gpt.base_network.transformer.wte(next_token)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)