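"""Sample diagnostic reports for chest X-ray images with a co-attention captioning model.

Loads a trained visual feature extractor, multi-label classifier (MLC),
co-attention module, sentence LSTM, and word LSTM from a checkpoint, then
decodes a report sentence by sentence for a single input image.
"""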
import os
import json
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image

from utils.models import *
from utils.dataset import *
from utils.loss import *
from utils.build_tag import *


class CaptionSampler(object):
    def __init__(self):
        self.args = {
            "model_dir": "/Users/jkottu/Desktop/image-captioning-chest-xrays/vit-coatten",
            "image_dir": "./data/images",
            "caption_json": "data/new_data/captions.json",
            "vocab_path": "/Users/jkottu/Desktop/image-captioning-chest-xrays/vit-coatten/vocab.pkl",
            "file_lists": "data/new_data/test_data.txt",
            "load_model_path": "train_best_loss.pth.tar",
            "resize": 224,
            "cam_size": 224,
            "generate_dir": "cam",
            "result_path": "results",
            "result_name": "debug",
            "momentum": 0.1,
            "visual_model_name": "densenet201",
            "pretrained": False,
            "classes": 210,
            "sementic_features_dim": 512,
            "k": 10,
            "attention_version": "v4",
            "embed_size": 512,
            "hidden_size": 512,
            "sent_version": "v1",
            "sentence_num_layers": 2,
            "dropout": 0.1,
            "word_num_layers": 1,
            "s_max": 10,
            "n_max": 30,
            "batch_size": 8,
            "lambda_tag": 10000,
            "lambda_stop": 10,
            "lambda_word": 1,
            "cuda": False
        }

        self.vocab = self.__init_vocab()
        self.tagger = self.__init_tagger()
        self.transform = self.__init_transform()
        self.model_state_dict = self.__load_model_state_dict()

        self.extractor = self.__init_visual_extractor()
        self.mlc = self.__init_mlc()
        self.co_attention = self.__init_co_attention()
        self.sentence_model = self.__init_sentence_model()
        self.word_model = self.__init_word_model()

        self.ce_criterion = self._init_ce_criterion()
        self.mse_criterion = self._init_mse_criterion()

    @staticmethod
    def _init_ce_criterion():
        # reduction='none' keeps per-element losses (replaces the deprecated
        # size_average=False, reduce=False flags).
        return nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _init_mse_criterion():
        return nn.MSELoss()

    def sample(self, image_file):
        self.extractor.eval()
        self.mlc.eval()
        self.co_attention.eval()
        self.sentence_model.eval()
        self.word_model.eval()

        # Load the image and apply the eval transform; add a batch dimension.
        imageData = Image.open(image_file).convert('RGB')
        imageData = self.transform(imageData)
        imageData = imageData.unsqueeze_(0)

        print(imageData.shape)

        image = self.__to_var(imageData, requires_grad=False)

        visual_features, avg_features = self.extractor.forward(image)

        tags, semantic_features = self.mlc(avg_features)
        sentence_states = None
        prev_hidden_states = self.__to_var(torch.zeros(image.shape[0], 1, self.args["hidden_size"]))

        pred_sentences = []

        # Decode at most s_max sentences: the sentence LSTM emits a topic
        # vector and a stop signal, and the word LSTM expands each topic
        # into a sentence.
        for i in range(self.args["s_max"]):
            ctx, alpha_v, alpha_a = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
            topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(ctx,
                                                                                       prev_hidden_states,
                                                                                       sentence_states)
            p_stop = p_stop.squeeze(1)
            p_stop = torch.max(p_stop, 1)[1].unsqueeze(1)

            start_tokens = np.zeros((topic.shape[0], 1))
            start_tokens[:, 0] = self.vocab('<start>')
            start_tokens = self.__to_var(torch.Tensor(start_tokens).long(), requires_grad=False)

            sampled_ids = self.word_model.sample(topic, start_tokens)
            prev_hidden_states = hidden_state

            # Zero out the sampled word ids once the stop classifier fires.
            sampled_ids = sampled_ids * p_stop.numpy()

            pred_sentences.append(self.__vec2sent(sampled_ids[0]))

        return pred_sentences

    def __init_cam_path(self, image_file):
        generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
        if not os.path.exists(generate_dir):
            os.makedirs(generate_dir)

        image_dir = os.path.join(generate_dir, image_file)
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        return image_dir

    def __save_json(self, result):
        result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
            json.dump(result, f)

    def __load_model_state_dict(self):
        try:
            model_state_dict = torch.load(
                os.path.join(self.args["model_dir"], self.args["load_model_path"]),
                map_location=torch.device('cpu'))
            print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
            print("Load From Epoch {}".format(model_state_dict['epoch']))
            return model_state_dict
        except Exception as err:
            print("[Load Model Failed] {}".format(err))
            raise err

    def __init_tagger(self):
        return Tag()

    def __vec2sent(self, array):
        sampled_caption = []
        for word_id in array:
            word = self.vocab.get_word_by_id(word_id)
            if word == '<start>':
                continue
            if word == '<end>' or word == '<pad>':
                break
            sampled_caption.append(word)
        return ' '.join(sampled_caption)

    def __init_vocab(self):
        with open(self.args["vocab_path"], 'rb') as f:
            vocab = pickle.load(f)
        return vocab

    def __init_data_loader(self, file_list):
        data_loader = get_loader(image_dir=self.args["image_dir"],
                                 caption_json=self.args["caption_json"],
                                 file_list=file_list,
                                 vocabulary=self.vocab,
                                 transform=self.transform,
                                 batch_size=self.args["batch_size"],
                                 s_max=self.args["s_max"],
                                 n_max=self.args["n_max"],
                                 shuffle=False)
        return data_loader

    def __init_transform(self):
        transform = transforms.Compose([
            transforms.Resize((self.args["resize"], self.args["resize"])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        return transform

    def __to_var(self, x, requires_grad=True):
        if self.args["cuda"]:
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)

    def __init_visual_extractor(self):
        model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
                                       pretrained=self.args["pretrained"])

        if self.model_state_dict is not None:
            print("Visual Extractor Loaded!")
            model.load_state_dict(self.model_state_dict['extractor'])

        if self.args["cuda"]:
            model = model.cuda()

        return model

    def __init_mlc(self):
        model = MLC(classes=self.args["classes"],
                    sementic_features_dim=self.args["sementic_features_dim"],
                    fc_in_features=self.extractor.out_features,
                    k=self.args["k"])

        if self.model_state_dict is not None:
            print("MLC Loaded!")
            model.load_state_dict(self.model_state_dict['mlc'])

        if self.args["cuda"]:
            model = model.cuda()

        return model

    def __init_co_attention(self):
        model = CoAttention(version=self.args["attention_version"],
                            embed_size=self.args["embed_size"],
                            hidden_size=self.args["hidden_size"],
                            visual_size=self.extractor.out_features,
                            k=self.args["k"],
                            momentum=self.args["momentum"])

        if self.model_state_dict is not None:
            print("Co-Attention Loaded!")
            model.load_state_dict(self.model_state_dict['co_attention'])

        if self.args["cuda"]:
            model = model.cuda()

        return model

    def __init_sentence_model(self):
        model = SentenceLSTM(version=self.args["sent_version"],
                             embed_size=self.args["embed_size"],
                             hidden_size=self.args["hidden_size"],
                             num_layers=self.args["sentence_num_layers"],
                             dropout=self.args["dropout"],
                             momentum=self.args["momentum"])

        if self.model_state_dict is not None:
            print("Sentence Model Loaded!")
            model.load_state_dict(self.model_state_dict['sentence_model'])

        if self.args["cuda"]:
            model = model.cuda()

        return model

    def __init_word_model(self):
        model = WordLSTM(vocab_size=len(self.vocab),
                         embed_size=self.args["embed_size"],
                         hidden_size=self.args["hidden_size"],
                         num_layers=self.args["word_num_layers"],
                         n_max=self.args["n_max"])

        if self.model_state_dict is not None:
            print("Word Model Loaded!")
            model.load_state_dict(self.model_state_dict['word_model'])

        if self.args["cuda"]:
            model = model.cuda()

        return model


def main(image):
    sampler = CaptionSampler()

    caption = sampler.sample(image)
    print(caption[0])

    return caption[0]
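
# Minimal command-line entry point: a sketch only, since the module itself
# exposes just main(image); the sys.argv handling here is an assumption.
if __name__ == '__main__':
    import sys
    if len(sys.argv) != 2:
        print("usage: python sampler.py <image_file>")
        sys.exit(1)
    main(sys.argv[1])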