|
import torch |
|
|
|
def generate_text(model, image, tokenizer, image_transfrom, max_length=30):
    """Greedily generate a caption for *image* with a CLIP-prefix GPT model.

    The image is encoded by ``model.clip``, projected by
    ``model.mapping_network`` into a prefix of ``model.max_length`` GPT
    embeddings, and tokens are then decoded greedily one at a time until
    the tokenizer's pad token is produced or *max_length* tokens have
    been emitted.

    Args:
        model: wrapper exposing ``clip``, ``mapping_network``, ``gpt``,
            ``max_length`` (the visual-prefix length, distinct from this
            function's *max_length*) and ``gpt_embedding_size``.
        image: raw input image accepted by *image_transfrom*.
        tokenizer: tokenizer providing ``pad_token_id`` and ``decode``.
        image_transfrom: preprocessing callable applied to *image*.
            (Name keeps its original misspelling for caller compatibility.)
        max_length: maximum number of tokens to generate.

    Returns:
        Whatever ``tokenizer.decode`` returns for the generated ids
        (typically the caption string).
    """
    model.eval()
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    # Run the WHOLE pipeline without gradient tracking: the original only
    # wrapped the GPT loop, so the CLIP / projection forward passes built
    # an unnecessary autograd graph.
    with torch.no_grad():
        img_tensor = image_transfrom(image).unsqueeze(0)  # add batch dim
        image_embedding = model.clip(img_tensor)
        # Project to a prefix of shape (batch, prefix_len, gpt_embedding_size).
        input_state = model.mapping_network(image_embedding).view(
            -1, model.max_length, model.gpt_embedding_size
        )

        for _ in range(max_length):
            logits = model.gpt(input_state, None).logits
            # Greedy decoding: argmax over the last position's logits.
            # (The original divided by a temperature and applied softmax
            # first — both are strictly monotonic, so the argmax, and
            # therefore the generated text, is unchanged.)
            next_token_id = logits[0, -1, :].argmax().item()

            if next_token_id == stop_token_id:
                break
            output_ids.append(next_token_id)

            # Embed the chosen token and append it to the running prefix.
            # Allocate the id tensor on the prefix's device — the original
            # always allocated on CPU, which breaks when the model runs
            # on GPU.
            next_token = torch.tensor(
                [[next_token_id]], device=input_state.device
            )
            next_token_embed = model.gpt.base_network.transformer.wte(next_token)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)