Littlehongman committed on
Commit
4cea813
•
1 Parent(s): 20dcad7

First version success

Browse files
Files changed (5)
  1. .gitignore +2 -0
  2. app.py +10 -0
  3. model.py +137 -0
  4. predict.py +42 -0
  5. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ artifacts/
+ wandb/
app.py CHANGED
@@ -2,6 +2,9 @@ import streamlit as st
  import streamlit.components.v1 as components
  from PIL import Image

+ from predict import generate_text
+ from model import load_clip_model, load_gpt_model, load_model
+

  # Configure Streamlit page
  st.set_page_config(page_title="Caption Machine", page_icon="💥")
@@ -40,4 +43,11 @@ if upload_file is not None:
  st.image(img)
  st.write("Image Uploaded Successfully")

+ # gpt_model, tokenizer = load_gpt_model()
+
+ model, image_transform, tokenizer = load_model()
+ caption = generate_text(model, img, tokenizer, image_transform)
+
+ st.write(caption)
+
model.py CHANGED
@@ -0,0 +1,137 @@
+ import torch
+ import torch.nn as nn
+ import wandb
+ import streamlit as st
+
+ import clip
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+
+ class ImageEncoder(nn.Module):  # frozen CLIP image encoder
+
+     def __init__(self, base_network):
+         super(ImageEncoder, self).__init__()
+         self.base_network = base_network
+         self.embedding_size = self.base_network.token_embedding.weight.shape[1]
+
+     def forward(self, images):
+         with torch.no_grad():
+             x = self.base_network.encode_image(images)
+             x = x / x.norm(dim=1, keepdim=True)
+             x = x.float()
+
+         return x
+
+ class Mapping(nn.Module):
+     # Map the CLIP image embedding to a sequence of GPT-2 token embeddings
+     def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):  # length: number of prefix tokens
+         super(Mapping, self).__init__()
+
+         self.clip_embedding_size = clip_embedding_size
+         self.gpt_embedding_size = gpt_embedding_size
+         self.length = length
+
+         self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)
+
+     def forward(self, x):
+         x = self.fc1(x)
+
+         return x.view(-1, self.length, self.gpt_embedding_size)
+
+
+ class TextDecoder(nn.Module):  # GPT-2 decoder driven by input embeddings
+     def __init__(self, base_network):
+         super(TextDecoder, self).__init__()
+         self.base_network = base_network
+         self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
+         self.vocab_size = self.base_network.transformer.wte.weight.shape[0]
+
+     def forward(self, concat_embedding, mask=None):
+         return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)
+
+
+     def get_embedding(self, texts):
+         return self.base_network.transformer.wte(texts)
+
+
+ import pytorch_lightning as pl
+
+
+ class ImageCaptioner(pl.LightningModule):
+     def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
+         super(ImageCaptioner, self).__init__()
+
+         self.padding_token_id = tokenizer.pad_token_id
+         # self.stop_token_id = tokenizer.encode('.')[0]
+
+         # Define networks
+         self.clip = ImageEncoder(clip_model)
+         self.gpt = TextDecoder(gpt_model)
+         self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)
+
+         # Define variables
+         self.total_steps = total_steps
+         self.max_length = max_length
+         self.clip_embedding_size = self.clip.embedding_size
+         self.gpt_embedding_size = self.gpt.embedding_size
+         self.gpt_vocab_size = self.gpt.vocab_size
+
+
+     def forward(self, images, texts, masks):
+         texts_embedding = self.gpt.get_embedding(texts)
+         images_embedding = self.clip(images)
+
+         images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
+         embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)
+
+         out = self.gpt(embedding_concat, masks)
+
+         return out
+
+ @st.cache_resource
+ def download_trained_model():
+     wandb.init(anonymous="must")  # anonymous run; no wandb account required
+
+     api = wandb.Api()
+     artifact = api.artifact('hungchiehwu/CLIP-L14_GPT/model-ql03493w:v3')
+     artifact_dir = artifact.download()
+
+     wandb.finish()
+
+     return artifact_dir
+
+ @st.cache_resource
+ def load_clip_model():
+
+     clip_model, image_transform = clip.load("ViT-L/14", device="cpu")
+
+     return clip_model, image_transform
+
+ @st.cache_resource
+ def load_gpt_model():
+     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+     gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
+
+     tokenizer.pad_token = tokenizer.eos_token
+
+     return gpt_model, tokenizer
+
+ @st.cache_resource
+ def load_model():
+
+     # Load the fine-tuned checkpoint from wandb
+     artifact_dir = download_trained_model()
+     PATH = f"{artifact_dir[2:]}/model.ckpt"  # drop the leading "./" from the artifact path
+
+     # Load pretrained CLIP and GPT-2 models
+     clip_model, image_transform = load_clip_model()
+     gpt_model, tokenizer = load_gpt_model()
+
+
+
+     # Load weights
+     model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
+     checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
+     model.load_state_dict(checkpoint["state_dict"])
+
+     return model, image_transform, tokenizer
predict.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+
+ def generate_text(model, image, tokenizer, image_transform, max_length=30):
+     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model.eval()
+     # model = model.to(device)
+
+     temperature = 0.9
+     stop_token_id = tokenizer.pad_token_id
+     output_ids = []
+
+
+     image = image_transform(image)
+     img_tensor = image.unsqueeze(0)  # .to(device)
+     images_embedding = model.clip(img_tensor)
+
+     images_projection = model.mapping_network(images_embedding).view(-1, model.max_length, model.gpt_embedding_size)
+
+     input_state = images_projection
+
+     with torch.no_grad():
+         for i in range(max_length):
+             outputs = model.gpt(input_state, None).logits
+
+             next_token_scores = outputs[0, -1, :].detach().div(temperature).softmax(dim=0)
+
+             # next_token_id = np.random.choice(len(next_token_scores), p = next_token_scores.cpu().numpy())
+             next_token_id = next_token_scores.max(dim=0).indices.item()  # greedy: take the most likely token
+
+             if next_token_id == stop_token_id:
+                 break
+
+             output_ids.append(next_token_id)
+
+
+             # Update state: append the new token's embedding to the input
+             next_token_id = torch.tensor([next_token_id]).unsqueeze(0)  # .to(device)
+             next_token_embed = model.gpt.base_network.transformer.wte(next_token_id)
+             input_state = torch.cat((input_state, next_token_embed), dim=1)
+
+     return tokenizer.decode(output_ids)
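For reference, a minimal sketch of how model.py and predict.py fit together outside the Streamlit app; the image path "example.jpg" is a hypothetical placeholder:

from PIL import Image

from model import load_model
from predict import generate_text

model, image_transform, tokenizer = load_model()  # downloads the wandb checkpoint on first call
img = Image.open("example.jpg")  # placeholder path; CLIP's preprocess handles RGB conversion
caption = generate_text(model, img, tokenizer, image_transform)
print(caption)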
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ ftfy
+ regex
+ tqdm
+ git+https://github.com/openai/CLIP.git
+ transformers
+ pytorch-lightning==1.9.0
+ wandb