cpu pls
- models/model_main.py +14 -13
- test_few_shot.py +7 -5
models/model_main.py
CHANGED
@@ -5,6 +5,7 @@ from models.vgg_perceptual_loss import VGGPerceptualLoss
 from models.transformers import *
 from torch.autograd import Variable

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class ModelMain(nn.Module):

     def __init__(self, opts, mode='train'):
@@ -72,7 +73,7 @@ class ModelMain(nn.Module):

         if mode in {'train', 'val'}:
             # seq decoding (training or val mode)
-            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
+            tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).to(device).float()
             command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
             command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)

@@ -97,10 +98,10 @@ class ModelMain(nn.Module):
         else: # testing (inference)

             trg_len = trg_seq_shifted.size(0)
-            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
+            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).to(device)

             for t in range(0, trg_len):
-                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
+                tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).to(device).float()
                 command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                 prob_comand = F.softmax(command_logits[:, -1, :], -1)
                 prob_args = F.softmax(args_logits[:, -1, :], -1)
@@ -151,29 +152,29 @@ class ModelMain(nn.Module):


         if mode == 'train':
-            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
+            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).to(device)
             if opts.ref_nshot == 52: # For ENG to TH
-                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
-                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
+                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
+                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).to(device)
                 ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
         elif mode == 'val':
-            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
+            ref_cls = torch.arange(0, self.opts.ref_nshot, 1).to(device).unsqueeze(0).expand(input_image.size(0), -1)
         else:
             ref_ids = self.opts.ref_char_ids.split(',')
             ref_ids = list(map(int, ref_ids))
             assert len(ref_ids) == self.opts.ref_nshot
-            ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
+            ref_cls = torch.tensor(ref_ids).to(device).unsqueeze(0).expand(self.opts.char_num, -1)



         if mode in {'train', 'val'}:
-            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
+            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
         else:
-            trg_cls = torch.arange(0, self.opts.char_num).cuda()
+            trg_cls = torch.arange(0, self.opts.char_num).to(device)
             if opts.ref_nshot == 52:
-                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).cuda()
+                trg_cls = torch.randint(52, opts.char_num, (input_image.size(0), 1)).to(device)
             trg_cls = trg_cls.view(self.opts.char_num, 1)
             input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
             input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
@@ -205,7 +206,7 @@ class ModelMain(nn.Module):
         ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked
         for i in range(ref_seqlen_cat.size(0)):
             ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
-        ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
+        ref_pad_mask = ref_pad_mask.to(device).float().unsqueeze(1)
         trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
         trg_seqlen = trg_seqlen.squeeze()
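The change above is the standard device-agnostic rewrite: choose a `device` once at module level, then replace each hard-coded `.cuda()` call with `.to(device)` so the same code also runs on CPU-only hosts like this Space. (`torch.autograd.Variable` is a no-op wrapper in modern PyTorch, so the wrapped masks behave like plain tensors either way.) A minimal standalone sketch of the pattern, with a hypothetical `make_causal_mask` helper standing in for the `subsequent_mask`-based code above:

import torch

# Choose the device once; every tensor created below follows it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_causal_mask(max_seq_len: int) -> torch.Tensor:
    # Lower-triangular mask: position i may attend only to positions <= i.
    mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
    # .to(device) instead of .cuda(): a no-op on CPU, a transfer on GPU.
    return mask.unsqueeze(0).to(device).float()

print(make_causal_mask(4))  # runs unchanged on CPU-only and CUDA hosts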
test_few_shot.py
CHANGED
@@ -23,16 +23,18 @@ def test_main_model(opts):
     dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")

     test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
-
-
-
-    device = torch.device("cpu")
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("Inference With Device:", device)
     if opts.streamlit:
         st.write("Loading Model Weight...")
+        st.write("Inference With Device:", device)
+
     model_main = ModelMain(opts)
     path_ckpt = os.path.join(f"{opts.model_path}")
-    model_main.load_state_dict(torch.load(path_ckpt, map_location=
+    model_main.load_state_dict(torch.load(path_ckpt, map_location=device)['model'])
     model_main.to(device)
+
     model_main.eval()
     with torch.no_grad():
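The `map_location=device` argument in the new load call is what lets a checkpoint saved during GPU training deserialize on a CPU host: tensors stored with CUDA storage are remapped onto the available device at load time. A short sketch of that load path as a hypothetical `load_for_inference` helper, assuming the checkpoint dict stores the weights under a 'model' key as above and that `ModelMain` is importable from models.model_main:

import torch
from models.model_main import ModelMain

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_for_inference(opts, path_ckpt):
    # map_location remaps CUDA-saved tensors onto the available device,
    # so a GPU-trained checkpoint still loads on a CPU-only machine.
    ckpt = torch.load(path_ckpt, map_location=device)
    model = ModelMain(opts)
    model.load_state_dict(ckpt['model'])  # weights live under the 'model' key
    model.to(device)
    model.eval()  # disable dropout, freeze batch-norm stats for inference
    return model

Calling `eval()` before entering the `torch.no_grad()` block, as test_few_shot.py does, is the usual pairing: `eval()` switches layer behavior to inference mode, while `no_grad()` skips gradient bookkeeping.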