microhum committed
Commit 94dff7f · 1 Parent(s): ccae047
Files changed (2)
  1. models/model_main.py +14 -13
  2. test_few_shot.py +7 -5
models/model_main.py CHANGED
@@ -5,6 +5,7 @@ from models.vgg_perceptual_loss import VGGPerceptualLoss
 from models.transformers import *
 from torch.autograd import Variable
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class ModelMain(nn.Module):
 
     def __init__(self, opts, mode='train'):
@@ -72,7 +73,7 @@ class ModelMain(nn.Module):
 
         if mode in {'train', 'val'}:
            # seq decoding (training or val mode)
-            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
+            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).to(device).float()
             command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
             command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
 
@@ -97,10 +98,10 @@ class ModelMain(nn.Module):
         else: # testing (inference)
 
             trg_len = trg_seq_shifted.size(0)
-            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
+            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).to(device)
 
             for t in range(0, trg_len):
-                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
+                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).to(device).float()
                 command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                 prob_comand = F.softmax(command_logits[:, -1, :], -1)
                 prob_args = F.softmax(args_logits[:, -1, :], -1)
@@ -151,29 +152,29 @@ class ModelMain(nn.Module):
 
 
         if mode == 'train':
-            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
+            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).to(device)
             if opts.ref_nshot == 52: # For ENG to TH
-                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
-                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
+                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
+                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
                 ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
         elif mode == 'val':
-            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
+            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).to(device).unsqueeze(0).expand(input_image.size(0), -1)
         else:
             ref_ids = self.opts.ref_char_ids.split(',')
             ref_ids = list(map(int, ref_ids))
             assert len(ref_ids) == self.opts.ref_nshot
-            ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
+            ref_cls = torch.tensor(ref_ids).to(device).unsqueeze(0).expand(self.opts.char_num, -1)
 
 
 
         if mode in {'train', 'val'}:
-            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
+            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
         else:
-            trg_cls = torch.arange(0, self.opts.char_num).cuda()
+            trg_cls = torch.arange(0, self.opts.char_num).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
             trg_cls = trg_cls.view(self.opts.char_num, 1)
             input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
             input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
@@ -205,7 +206,7 @@ class ModelMain(nn.Module):
         ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked
         for i in range(ref_seqlen_cat.size(0)):
             ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
-        ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
+        ref_pad_mask = ref_pad_mask.to(device).float().unsqueeze(1)
         trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
         trg_seqlen = trg_seqlen.squeeze()
 
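All of the model_main.py hunks apply one pattern: pick the device once at module level, then replace hard-coded .cuda() calls with .to(device) so the same code runs on CPU-only hosts. A minimal standalone sketch of that pattern (the make_causal_mask helper and the sizes below are illustrative, not taken from this repo):

import torch

# Pick the device once; fall back to CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_causal_mask(seq_len, batch_size):
    # Lower-triangular (causal) mask, broadcast across the batch and moved
    # to the selected device instead of being pinned to CUDA via .cuda().
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).expand(batch_size, -1, -1).to(device).float()

mask = make_causal_mask(seq_len=51, batch_size=4)
print(mask.device)  # cuda:0 on a GPU machine, cpu otherwise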
test_few_shot.py CHANGED
@@ -23,16 +23,18 @@ def test_main_model(opts):
     dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")
 
     test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("Inference With Device:", device)
     if opts.streamlit:
         st.write("Loading Model Weight...")
+        st.write("Inference With Device:", device)
+
     model_main = ModelMain(opts)
     path_ckpt = os.path.join(f"{opts.model_path}")
-    model_main.load_state_dict(torch.load(path_ckpt, map_location=torch.device('cpu'))['model'])
+    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
     model_main.to(device)
+
     model_main.eval()
     with torch.no_grad():
 
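The test_few_shot.py hunk pairs the same device selection with map_location=device in torch.load, so a checkpoint saved on GPU deserializes directly onto whatever device is available instead of failing on a CPU-only host. A minimal loading sketch under the same assumption as the diff (a checkpoint dict with a 'model' key); TinyNet and the checkpoint path are illustrative, not part of this repo:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
# map_location=device remaps GPU-saved tensors onto the current device,
# avoiding deserialization errors when CUDA is unavailable.
ckpt = torch.load("checkpoint.pth", map_location=device)
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()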