VinayHajare commited on
Commit
990c0eb
1 Parent(s): 574a477

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +40 -0
utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def load_model_weights(model, weights, multi_gpus, train=True):
2
+ """
3
+ Load the model weights from the given checkpoint file
4
+ """
5
+ # If model was originally trained on a single GPU but needs to be loaded onto multiple ones,
6
+ # it removes the "module" prefix from the weight keys
7
+ if list(weights.keys())[0].find('module') == -1:
8
+ pretrained_with_multi_gpu = False
9
+ else:
10
+ pretrained_with_multi_gpu = True
11
+
12
+ if (multi_gpus is False) or (train is False):
13
+ if pretrained_with_multi_gpu:
14
+ state_dict = {
15
+ key[7:]: value
16
+ for key, value in weights.items()
17
+ }
18
+ else:
19
+ state_dict = weights
20
+ else:
21
+ state_dict = weights
22
+
23
+ # load the model from the state_dict
24
+ model.load_state_dict(state_dict)
25
+ return model
26
+
27
+
28
+ # Class to work with if mixed precision is failing
29
+ class dummy_context_mgr():
30
+ def __enter__(self):
31
+ return None
32
+
33
+ def __exit__(self, exc_type, exc_value, traceback):
34
+ return False
35
+
36
+
37
+ # Function to read CSS from file
38
+ def read_css_from_file(filename):
39
+ with open(filename, 'r') as file:
40
+ return file.read()