import torch
import torchvision
import numpy as np
from PIL import Image
from transformers import BertTokenizer

import caption_model
from configuration import Config


def under_max(image):
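    """Convert an image to RGB and resize it so the longer side is 299 px, keeping the aspect ratio."""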
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # np.float was removed from NumPy; the builtin float is equivalent here.
    shape = np.array(image.size, dtype=float)
    long_dim = max(shape)
    scale = 299 / long_dim

    new_shape = (shape * scale).astype(int)
    image = image.resize(tuple(new_shape))

    return image


class Model:
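    """Loads the captioning model from ./checkpoint.pth and generates captions by greedy decoding."""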
    def __init__(self, gpu=None):
        config = Config()
        config.device = 'cpu' if gpu is None else 'cuda:{}'.format(gpu)

        # Build the captioning model and load the trained weights.
        model, _ = caption_model.build_model(config)
        checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model.to(config.device)

        # Use the tokenizer's public cls_token/sep_token attributes instead of
        # the private _cls_token/_sep_token fields.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        start_token = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
        self.end_token = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

        # Caption buffer starts with [CLS]; the mask marks positions that have
        # not been generated yet.
        self.caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long).to(config.device)
        self.cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool).to(config.device)
        self.caption[:, 0] = start_token
        self.cap_mask[:, 0] = False

        self.val_transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(under_max),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.model = model
        self.config = config
        self.tokenizer = tokenizer

    @torch.no_grad()
    def evaluate(self, im):
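        """Autoregressively fill the caption buffer with argmax tokens until [SEP] or the length limit is reached."""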
        self.model.eval()
        for i in range(self.config.max_position_embeddings - 1):
            predictions = self.model(im.to(self.config.device),
                                     self.caption,
                                     self.cap_mask)
            # Logits for the current position; pick the most likely token.
            predictions = predictions[:, i, :]
            predicted_id = torch.argmax(predictions, dim=-1)

            # Stop as soon as the [SEP] token is produced.
            if predicted_id[0] == self.end_token:
                return self.caption

            self.caption[:, i + 1] = predicted_id[0]
            self.cap_mask[:, i + 1] = False

        return self.caption

    def predict(self, image_path):
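        """Open an image from disk, apply the validation transform, and return the decoded caption string."""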
        image = Image.open(image_path)
        image = self.val_transform(image)
        image = image.unsqueeze(0)
        output = self.evaluate(image)
        return self.tokenizer.decode(output[0].tolist(), skip_special_tokens=True)


if __name__ == "__main__":
    model = Model()
    result = model.predict("./image.jpg")
    print(result)