Commit: cpu pls

Files changed:
- models/model_main.py +14 -13
- test_few_shot.py +7 -5
models/model_main.py CHANGED

@@ -5,6 +5,7 @@ from models.vgg_perceptual_loss import VGGPerceptualLoss
 from models.transformers import *
 from torch.autograd import Variable
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class ModelMain(nn.Module):
 
     def __init__(self, opts, mode='train'):

@@ -72,7 +73,7 @@ class ModelMain(nn.Module):
 
         if mode in {'train', 'val'}:
             # seq decoding (training or val mode)
-            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
+            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).to(device).float()
             command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
             command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
 

@@ -97,10 +98,10 @@ class ModelMain(nn.Module):
         else:  # testing (inference)
 
             trg_len = trg_seq_shifted.size(0)
-            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
+            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).to(device)
 
             for t in range(0, trg_len):
-                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
+                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).to(device).float()
                 command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                 prob_comand = F.softmax(command_logits[:, -1, :], -1)
                 prob_args = F.softmax(args_logits[:, -1, :], -1)

@@ -151,29 +152,29 @@ class ModelMain(nn.Module):
 
 
         if mode == 'train':
-            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
+            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).to(device)
             if opts.ref_nshot == 52:  # For ENG to TH
-                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
-                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
+                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
+                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
                 ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
         elif mode == 'val':
-            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
+            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).to(device).unsqueeze(0).expand(input_image.size(0), -1)
         else:
             ref_ids = self.opts.ref_char_ids.split(',')
             ref_ids = list(map(int, ref_ids))
             assert len(ref_ids) == self.opts.ref_nshot
-            ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
+            ref_cls = torch.tensor(ref_ids).to(device).unsqueeze(0).expand(self.opts.char_num, -1)
 
 
 
         if mode in {'train', 'val'}:
-            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
+            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
         else:
-            trg_cls = torch.arange(0, self.opts.char_num).cuda()
+            trg_cls = torch.arange(0, self.opts.char_num).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
             trg_cls = trg_cls.view(self.opts.char_num, 1)
             input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
             input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)

@@ -205,7 +206,7 @@ class ModelMain(nn.Module):
         ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len)  # value = 1 means pos to be masked
         for i in range(ref_seqlen_cat.size(0)):
             ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
-        ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
+        ref_pad_mask = ref_pad_mask.to(device).float().unsqueeze(1)
         trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
         trg_seqlen = trg_seqlen.squeeze()
 
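The change throughout model_main.py is the standard CUDA-with-CPU-fallback idiom: resolve a torch.device once at module scope, then route every tensor through .to(device) instead of .cuda(). A minimal, self-contained sketch of the same idiom follows; TinyDecoder and its mask helper are illustrative stand-ins, not the repo's ModelMain or transformer classes:

import torch
import torch.nn as nn

# Resolve the device once at module scope, as the commit does.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def subsequent_mask(size):
    # Causal mask: position i may only attend to positions <= i.
    return torch.tril(torch.ones(1, size, size))

class TinyDecoder(nn.Module):
    # Illustrative stand-in; not the repo's ModelMain.
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # .to(device) replaces .cuda(): it is a no-op when the tensor is
        # already on `device`, and a safe fallback on CPU-only machines.
        tgt_mask = subsequent_mask(x.size(1)).expand(x.size(0), -1, -1).to(device).float()
        return self.proj(x), tgt_mask

model = TinyDecoder().to(device)
out, mask = model(torch.randn(2, 8, 16, device=device))

One note on the retained code: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so the Variable(...) calls in these hunks could be dropped without changing behavior.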
test_few_shot.py CHANGED

@@ -23,16 +23,18 @@ def test_main_model(opts):
     dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
 
     test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
-
-
-
-    device = torch.device("cpu")
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("Inference With Device:", device)
     if opts.streamlit:
         st.write("Loading Model Weight...")
+        st.write("Inference With Device:", device)
+
     model_main = ModelMain(opts)
     path_ckpt = os.path.join(f"{opts.model_path}")
-    model_main.load_state_dict(torch.load(path_ckpt, map_location=
+    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
     model_main.to(device)
+
     model_main.eval()
     with torch.no_grad():
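The essential fix in test_few_shot.py is map_location=device on torch.load: without it, a checkpoint saved from a CUDA run tries to deserialize its tensors back onto CUDA and fails on a CPU-only host. A self-contained sketch of the same load-then-move sequence, under the assumption that checkpoints store their weights under a 'model' key as the diff shows (the Linear stand-in and ckpt.pth path are illustrative):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = nn.Linear(4, 2)  # stand-in for ModelMain(opts)

# Save in the shape the repo's checkpoints appear to use:
# a dict holding the weights under the 'model' key.
torch.save({"model": net.state_dict()}, "ckpt.pth")

# map_location retargets every stored tensor onto `device` during
# deserialization, so a CUDA-saved checkpoint also opens on CPU.
state = torch.load("ckpt.pth", map_location=device)
net.load_state_dict(state["model"])
net.to(device)  # move the module's parameters to the resolved device
net.eval()      # inference mode: fixes dropout/batchnorm behavior

The separate model_main.to(device) in the diff is still needed: map_location only places the deserialized checkpoint tensors, not the freshly constructed module's parameters.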