# File size: 1,425 Bytes
# 4cea813
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch

def _infer_model_device(model):
    """Return the device of the model's first parameter, or None if unknown."""
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return None
    first_param = next(iter(parameters()), None)
    return None if first_param is None else first_param.device


def generate_text(model, image, tokenizer, image_transfrom, max_length=30):
    """Greedily decode text for ``image`` and return the decoded string.

    The image is transformed, encoded with ``model.clip``, projected by
    ``model.mapping_network`` into a prefix of embeddings, and that prefix is
    fed to ``model.gpt``. At each step the highest-scoring token is appended
    to the output and its embedding is concatenated onto the input sequence.

    Args:
        model: Object exposing ``eval()``, ``clip``, ``mapping_network``,
            ``gpt`` (called as ``gpt(embeddings, None)`` and returning an
            object with ``.logits``), ``gpt.base_network.transformer.wte``,
            ``max_length`` and ``gpt_embedding_size``. The model is switched
            to eval mode as a side effect.
        image: Raw image accepted by ``image_transfrom``.
        tokenizer: Object exposing ``pad_token_id`` and ``decode(ids)``.
        image_transfrom: Callable mapping ``image`` to a tensor without a
            batch dimension (one is added here with ``unsqueeze(0)``).
        max_length: Maximum number of tokens to generate.

    Returns:
        ``tokenizer.decode`` applied to the generated token ids. The stop
        token itself is not included.
    """
    model.eval()

    # With greedy (argmax) selection the temperature does not change which
    # token wins, because scaling + softmax is monotonic; it is kept so the
    # scores stay meaningful if sampling is introduced later.
    temperature = 0.9
    # NOTE(review): generation stops on the tokenizer's pad token. If
    # ``pad_token_id`` is None the loop always runs for ``max_length`` steps
    # -- confirm this is the intended stop token.
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    img_tensor = image_transfrom(image).unsqueeze(0)
    # Run on whatever device the model lives on instead of assuming CPU.
    device = _infer_model_device(model)
    if device is not None:
        img_tensor = img_tensor.to(device)

    # The whole forward path is inference-only, so no autograd graph should
    # be recorded for the image encoder or the mapping network either.
    with torch.no_grad():
        images_embedding = model.clip(img_tensor)
        input_state = model.mapping_network(images_embedding).view(
            -1, model.max_length, model.gpt_embedding_size
        )

        for _ in range(max_length):
            logits = model.gpt(input_state, None).logits

            # Scores for the last position of the first (only) batch item.
            next_token_scores = logits[0, -1, :].div(temperature).softmax(dim=0)
            next_token_id = next_token_scores.max(dim=0).indices.item()

            if next_token_id == stop_token_id:
                break

            output_ids.append(next_token_id)

            # Extend the input with the chosen token's embedding; the id
            # tensor is created on the same device as the running sequence.
            next_token = torch.tensor([[next_token_id]], device=input_state.device)
            next_token_embed = model.gpt.base_network.transformer.wte(next_token)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)