import torch


def generate_text(model, image, tokenizer, image_transfrom, max_length=30):
    """Greedily decode a caption for ``image`` with a CLIP-prefix GPT model.

    The image is preprocessed, embedded with ``model.clip``, and projected by
    ``model.mapping_network`` into a prefix of ``model.max_length`` GPT token
    embeddings. Tokens are then generated one at a time by taking the argmax
    of the GPT logits at the last position, stopping when the pad token is
    produced or ``max_length`` tokens have been emitted.

    Args:
        model: Wrapper exposing ``clip``, ``mapping_network``, ``gpt``,
            ``max_length`` (prefix length in tokens) and
            ``gpt_embedding_size``.
        image: Raw input image accepted by ``image_transfrom``.
        tokenizer: Tokenizer providing ``pad_token_id`` (used as the stop
            token) and ``decode``.
        image_transfrom: Preprocessing callable mapping ``image`` to a tensor.
            (Parameter name — typo included — kept for backward compatibility
            with existing callers.)
        max_length: Maximum number of tokens to generate.

    Returns:
        The decoded caption string.
    """
    model.eval()
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    # no_grad covers the CLIP and projection forward passes too; the original
    # only wrapped the decode loop, building a needless autograd graph.
    with torch.no_grad():
        img_tensor = image_transfrom(image).unsqueeze(0)  # add batch dim
        images_embedding = model.clip(img_tensor)
        # Project the image embedding into a sequence of GPT token embeddings.
        input_state = model.mapping_network(images_embedding).view(
            -1, model.max_length, model.gpt_embedding_size
        )

        for _ in range(max_length):
            logits = model.gpt(input_state, None).logits
            # Greedy decoding: argmax over the vocabulary at the last
            # position. Temperature scaling and softmax do not change the
            # argmax (both are strictly monotonic), so they are omitted.
            next_token_id = logits[0, -1, :].argmax(dim=0).item()
            if next_token_id == stop_token_id:
                break
            output_ids.append(next_token_id)
            # Append the new token's embedding so it conditions the next step.
            next_token = torch.tensor([[next_token_id]])
            next_token_embed = model.gpt.base_network.transformer.wte(next_token)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)