fffiloni commited on
Commit
f6b8ea3
1 Parent(s): b0f5b8e

Create BaseModel.py

Browse files
Files changed (1) hide show
  1. xdecoder/BaseModel.py +37 -0
xdecoder/BaseModel.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import logging
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from utils.model_loading import align_and_update_state_dicts
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class BaseModel(nn.Module):
20
+ def __init__(self, opt, module: nn.Module):
21
+ super(BaseModel, self).__init__()
22
+ self.opt = opt
23
+ self.model = module
24
+
25
+ def forward(self, *inputs, **kwargs):
26
+ outputs = self.model(*inputs, **kwargs)
27
+ return outputs
28
+
29
+ def save_pretrained(self, save_dir):
30
+ save_path = os.path.join(save_dir, 'model_state_dict.pt')
31
+ torch.save(self.model.state_dict(), save_path)
32
+
33
+ def from_pretrained(self, load_path):
34
+ state_dict = torch.load(load_path, map_location=self.opt['device'])
35
+ state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
36
+ self.model.load_state_dict(state_dict, strict=False)
37
+ return self