"""Utility for loading a pretrained PL-BERT (phoneme-level BERT) encoder
from a training checkpoint directory."""

import os
from collections import OrderedDict

import torch
import yaml
from transformers import AlbertConfig, AlbertModel

class CustomAlbert(AlbertModel):
    """ALBERT wrapper whose forward pass returns a plain tensor."""

    def forward(self, *args, **kwargs):
        # Run the standard ALBERT forward pass.
        outputs = super().forward(*args, **kwargs)

        # Return only the final hidden states ([batch, seq_len, hidden_size])
        # rather than the full BaseModelOutput object.
        return outputs.last_hidden_state

def load_plbert(log_dir):
    """Build a CustomAlbert from the config and latest checkpoint in log_dir."""
    # Read the model hyperparameters saved alongside the checkpoints.
    config_path = os.path.join(log_dir, "config.yml")
    with open(config_path) as f:
        plbert_config = yaml.safe_load(f)

    albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
    bert = CustomAlbert(albert_base_configuration)
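
    # Checkpoints are assumed to follow a step_<iteration>.t7 naming
    # convention (e.g. step_100000.t7, an illustrative filename); the
    # highest iteration number identifies the most recent checkpoint.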
    ckpts = [f for f in os.listdir(log_dir) if f.startswith("step_")]
    iters = [
        int(f.split("_")[-1].split(".")[0])
        for f in ckpts
        if os.path.isfile(os.path.join(log_dir, f))
    ]
    latest_iter = max(iters)
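
    # Note: torch.load on pickled checkpoints can execute arbitrary code for
    # untrusted files; on PyTorch versions that support it, passing
    # weights_only=True is a safer choice for tensor-only checkpoints.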
    checkpoint = torch.load(
        os.path.join(log_dir, f"step_{latest_iter}.t7"), map_location="cpu"
    )
    state_dict = checkpoint["net"]

    # The checkpoint was saved from a wrapped model (e.g. nn.DataParallel),
    # so keys look like "module.encoder.<param>". Strip both prefixes and
    # keep only the encoder weights.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # drop the "module." prefix
        if name.startswith("encoder."):
            name = name[8:]  # drop the "encoder." prefix
            new_state_dict[name] = v

    # Newer transformers releases no longer register position_ids as a
    # buffer; pop tolerates its absence from the checkpoint.
    new_state_dict.pop("embeddings.position_ids", None)
    bert.load_state_dict(new_state_dict, strict=False)

    return bert
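
# Minimal usage sketch, assuming log_dir contains config.yml plus
# step_*.t7 checkpoints; the path below is hypothetical.
if __name__ == "__main__":
    plbert = load_plbert("path/to/plbert_checkpoints")  # hypothetical directory
    plbert.eval()
    print(plbert.config)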