kimihailv committed
Commit
8ad0892
1 Parent(s): 3d61f46

Upload convert_model.py

Files changed (1):
  convert_model.py (+30, -10)
convert_model.py CHANGED
@@ -39,9 +39,9 @@ class ImageEncoder(torch.nn.Module):
 
 def convert_model(opts):
     src_model = uform.get_model(opts.model_name)
-    input_ids = torch.ones(1, 77, dtype=torch.int32)
-    attention_mask = torch.ones(1, 77, dtype=torch.int32)
-    image = torch.ones(1, 3, 224, 224, dtype=torch.float32)
+    input_ids = torch.ones(1, src_model.text_encoder.max_position_embeddings, dtype=torch.int32)
+    attention_mask = torch.ones(1, src_model.text_encoder.max_position_embeddings, dtype=torch.int32)
+    image = torch.ones(1, 3, src_model.image_encoder.image_size, src_model.image_encoder.image_size, dtype=torch.float32)
 
     print('Tracing models…')
     image_encoder = ImageEncoder(src_model.image_encoder).eval()
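The new dummy inputs read the sequence length and image resolution from the model itself instead of hard-coding 77 and 224, so the trace stays correct for UForm variants with other configurations. A minimal, self-contained sketch of the tracing step these tensors feed into; the wrapper class and encoder below are stand-ins, not the repo's code:

import torch

class ImageEncoderWrapper(torch.nn.Module):  # stand-in for the script's ImageEncoder
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, image):
        return self.encoder(image)

encoder = torch.nn.Conv2d(3, 8, kernel_size=3)  # stand-in for src_model.image_encoder
image = torch.ones(1, 3, 224, 224, dtype=torch.float32)

# torch.jit.trace records the ops executed on the example input; only its
# shape and dtype matter, which is why torch.ones is a sufficient dummy.
traced = torch.jit.trace(ImageEncoderWrapper(encoder).eval(), image)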
@@ -51,13 +51,18 @@ def convert_model(opts):
 
     print('Converting models…')
 
+    if opts.image_batchsize_lb == opts.image_batchsize_ub:
+        image_batch_dim_shape = opts.image_batchsize_lb
+    else:
+        image_batch_dim_shape = ct.RangeDim(lower_bound=opts.image_batchsize_lb, upper_bound=opts.image_batchsize_ub, default=1)
+
     image_encoder = ct.convert(
         image_encoder,
         convert_to='mlprogram',
         inputs=[
             ct.TensorType(
                 name='image',
-                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 3, 224, 224),
+                shape=(image_batch_dim_shape,) + image.shape[1:],
                 dtype=image.numpy().dtype
             )],
         outputs=[
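The new if/else skips RangeDim when the two bounds coincide: a plain int gives Core ML a fully static shape, while RangeDim declares a flexible batch axis. A hedged sketch of the same pattern in isolation, using a toy model (prediction requires macOS):

import coremltools as ct
import numpy as np
import torch

traced = torch.jit.trace(torch.nn.Linear(4, 2).eval(), torch.ones(1, 4))

mlmodel = ct.convert(
    traced,
    convert_to='mlprogram',
    inputs=[ct.TensorType(
        name='x',
        # any batch in [1, 64] is accepted at prediction time; 'default'
        # is the shape Core ML optimizes for when loading the model
        shape=(ct.RangeDim(lower_bound=1, upper_bound=64, default=1), 4),
        dtype=np.float32,
    )],
)

mlmodel.predict({'x': np.ones((1, 4), dtype=np.float32)})   # batch 1
mlmodel.predict({'x': np.ones((16, 4), dtype=np.float32)})  # batch 16, same model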
@@ -71,18 +76,23 @@
         compute_precision=ct.precision.FLOAT16 if opts.use_fp16 else ct.precision.FLOAT32
     )
 
+    if opts.text_batchsize_lb == opts.text_batchsize_ub:
+        text_batch_dim_shape = opts.text_batchsize_lb
+    else:
+        text_batch_dim_shape = ct.RangeDim(lower_bound=opts.text_batchsize_lb, upper_bound=opts.text_batchsize_ub, default=1)
+
     text_encoder = ct.convert(
         text_encoder,
         convert_to='mlprogram',
         inputs=[
             ct.TensorType(
                 name='input_ids',
-                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 77),
+                shape=(text_batch_dim_shape,) + input_ids.shape[1:],
                 dtype=input_ids.numpy().dtype
             ),
             ct.TensorType(
                 name='attention_mask',
-                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 77),
+                shape=(text_batch_dim_shape,) + attention_mask.shape[1:],
                 dtype=attention_mask.numpy().dtype
             )],
         outputs=[
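Both text inputs are built from the same text_batch_dim_shape, so a caller must supply input_ids and attention_mask with matching batch sizes. A sketch of such a call, assuming the converted model was saved as text_encoder.mlpackage (a hypothetical path the script does not define; prediction requires macOS):

import coremltools as ct
import numpy as np

text_encoder = ct.models.MLModel('text_encoder.mlpackage')  # hypothetical save path

batch, seq_len = 4, 77  # seq_len must match the traced max_position_embeddings
out = text_encoder.predict({
    'input_ids': np.ones((batch, seq_len), dtype=np.int32),
    'attention_mask': np.ones((batch, seq_len), dtype=np.int32),
})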
@@ -110,15 +120,25 @@ if __name__ == '__main__':
         type=str,
         help='UForm model name')
 
-    opts.add_argument('--batchsize_lb',
+    opts.add_argument('--text_batchsize_lb',
+        action='store',
+        type=int,
+        help='lower bound of batch size for text encoder')
+
+    opts.add_argument('--text_batchsize_ub',
+        action='store',
+        type=int,
+        help='upper bound of batch size for text encoder')
+
+    opts.add_argument('--image_batchsize_lb',
         action='store',
         type=int,
-        help='lower bound of batch size')
+        help='lower bound of batch size for image encoder')
 
-    opts.add_argument('--batchsize_ub',
+    opts.add_argument('--image_batchsize_ub',
         action='store',
         type=int,
-        help='upper bound of batch size')
+        help='upper bound of batch size for image encoder')
 
     opts.add_argument('-use_fp16',
         action='store_true',
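With the bounds split per encoder, a fixed and a flexible batch can be mixed in one run. A hypothetical invocation (the model name is a placeholder, not one the repo documents):

python convert_model.py \
    --model_name unum-cloud/uform-vl-english \
    --image_batchsize_lb 1 --image_batchsize_ub 1 \
    --text_batchsize_lb 1 --text_batchsize_ub 64 \
    -use_fp16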
 