Tudohuang commited on
Commit
6a3d9c3
·
verified ·
1 Parent(s): dc9f48a

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +170 -0
  2. im2text_model_full.pt +3 -0
  3. vocab_full.json +0 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import transforms, models
6
+ from flask import Flask, request, jsonify , render_template
7
+ import imageio.v3 as imageio
8
+ import numpy as np
9
+ from io import BytesIO
10
+ from PIL import Image
11
+ # 建立詞彙表
12
+ class tokenizer():
13
+ def __init__(self, threshold=5):
14
+ self.word2idx = {}
15
+ self.idx2word = {}
16
+ self.threshold = threshold
17
+ self.word2count = {}
18
+
19
+ def build_vocab(self, corpus):
20
+ print('buiding vocab......')
21
+ tokens = corpus.lower().split()
22
+ for token in tokens:
23
+ self.word2count[token] = self.word2count.get(token, 0) + 1
24
+ idx = 0
25
+ for word, count in self.word2count.items():
26
+ if count >= self.threshold:
27
+ self.word2idx[word] = idx
28
+ self.idx2word[idx] = word
29
+ idx += 1
30
+ print(f'Vocab size: {len(self.idx2word)}')
31
+
32
+ def encode(self, sentence):
33
+ tokens = sentence.lower().split()
34
+ return [self.word2idx.get(token, self.word2idx['<unk>']) for token in tokens]
35
+
36
+ def decode(self, indices):
37
+ return ' '.join([self.idx2word.get(idx, '<unk>') for idx in indices])
38
+
39
+ def save_vocab(self, filepath):
40
+ with open(filepath, 'w') as f:
41
+ json.dump({'word2idx': self.word2idx, 'idx2word': self.idx2word}, f)
42
+
43
+ def load_vocab(self, filepath):
44
+ with open(filepath, 'r') as f:
45
+ data = json.load(f)
46
+ self.word2idx = data['word2idx']
47
+ self.idx2word = {int(k): v for k, v in data['idx2word'].items()}
48
+
49
+ # 定義CNN編碼器
50
+ class CNNEncoder(nn.Module):
51
+ def __init__(self, embed_size, num_groups=32):
52
+ super(CNNEncoder, self).__init__()
53
+ resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
54
+ for param in resnet.parameters():
55
+ param.requires_grad = False
56
+ self.resnet = nn.Sequential(*list(resnet.children())[:-1])
57
+ self.linear = nn.Linear(resnet.fc.in_features, embed_size)
58
+ self.gn = nn.GroupNorm(num_groups, embed_size)
59
+
60
+ def forward(self, images):
61
+ with torch.no_grad():
62
+ features = self.resnet(images)
63
+ features = features.view(features.size(0), -1)
64
+ features = self.gn(self.linear(features))
65
+ return features
66
+
67
+ # 定義RNN解碼器
68
+ class RNNDecoder(nn.Module):
69
+ def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
70
+ super(RNNDecoder, self).__init__()
71
+ self.embed = nn.Embedding(vocab_size, embed_size)
72
+ self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
73
+ self.linear = nn.Linear(hidden_size, vocab_size)
74
+ self.embed_size = embed_size
75
+ self.hidden_size = hidden_size
76
+ self.num_layers = num_layers
77
+
78
+ def forward(self, features, captions):
79
+ embeddings = self.embed(captions)
80
+ embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
81
+ hiddens, _ = self.lstm(embeddings)
82
+ outputs = self.linear(hiddens[:, 1:, :])
83
+ return outputs
84
+
85
+ def sample(self, features, states=None, max_len=20):
86
+ sampled_ids = [vocab.word2idx['<start>']]
87
+ inputs = features.unsqueeze(1)
88
+ start_token = torch.tensor([vocab.word2idx['<start>']]).to(device).unsqueeze(0)
89
+ inputs = torch.cat((features.unsqueeze(1), self.embed(start_token)), 1)
90
+ for i in range(max_len):
91
+ hiddens, states = self.lstm(inputs, states)
92
+ outputs = self.linear(hiddens[:, -1, :]) # take the output of the last time step
93
+ _, predicted = outputs.max(1)
94
+ sampled_ids.append(predicted.item())
95
+ if predicted.item() == vocab.word2idx['<end>']:
96
+ break
97
+ inputs = self.embed(predicted).unsqueeze(1)
98
+ return sampled_ids
99
+
100
+ # 定義ImageToText模型
101
+ class im2text_model(nn.Module):
102
+ def __init__(self, cnn_encoder, rnn_decoder):
103
+ super(im2text_model, self).__init__()
104
+ self.encoder = cnn_encoder
105
+ self.decoder = rnn_decoder
106
+
107
+ def forward(self, images, captions):
108
+ features = self.encoder(images)
109
+ outputs = self.decoder(features, captions)
110
+ return outputs
111
+
112
+ def sample(self, images, states=None):
113
+ features = self.encoder(images)
114
+ sampled_ids = self.decoder.sample(features, states)
115
+ return sampled_ids
116
+
117
+ # 初始化應用
118
+ app = Flask(__name__)
119
+
120
+ # 加載詞彙表
121
+ vocab = tokenizer()
122
+ vocab.load_vocab('vocab_full.json')
123
+
124
+ # 加載模型
125
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
126
+ model = torch.load('im2text_model_full.pt', map_location=torch.device('cpu'))
127
+ model.to(device)
128
+ model.eval()
129
+
130
+ transform = transforms.Compose([
131
+ transforms.Resize((224, 224), antialias=True),
132
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
133
+ ])
134
+ @app.route('/')
135
+ def index():
136
+ return render_template('index.html')
137
+
138
+ @app.route('/upload', methods=['POST'])
139
+ def upload_image():
140
+ if 'file' not in request.files:
141
+ return jsonify({'error': 'No file part'})
142
+ file = request.files['file']
143
+ if file.filename == '':
144
+ return jsonify({'error': 'No selected file'})
145
+ if file:
146
+ # Convert image to RGB format if necessary and process in memory
147
+ image = Image.open(file.stream)
148
+ if image.format in ['GIF', 'WebP', 'PNG']:
149
+ image = image.convert('RGB')
150
+
151
+ # Save image to a BytesIO object
152
+ byte_io = BytesIO()
153
+ image.save(byte_io, 'JPEG')
154
+ byte_io.seek(0)
155
+
156
+ image = imageio.imread(byte_io)
157
+ if len(image.shape) == 2:
158
+ image = np.stack([image] * 3, axis=0)
159
+ else:
160
+ image = np.transpose(image, (2, 0, 1))
161
+ image = torch.tensor(image / 255.0).float()
162
+ image = transform(image).unsqueeze(0).to(device)
163
+
164
+ with torch.no_grad():
165
+ generated_caption = model.sample(image)
166
+ generated_caption_text = vocab.decode(generated_caption)
167
+
168
+ return jsonify({'caption': generated_caption_text})
169
+ if __name__ == '__main__':
170
+ app.run(debug=True)
im2text_model_full.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6f3a66c5ee31f06ee6ab67860d01c6c3230b318bfa8e441db91e7f52e5768d2
3
+ size 145253389
vocab_full.json ADDED
The diff for this file is too large to render. See raw diff