HMMC_t2v_search / modules /until_config.py
cheetah003's picture
first commit
29c5a57
raw
history blame
5.04 kB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import logging
import tarfile
import tempfile
import shutil
import torch
from .file_utils import cached_path
logger = logging.getLogger(__name__)
class PretrainedConfig(object):
    """Base class for configurations that can be loaded from a model archive.

    Subclasses are expected to override the three class attributes below and
    to provide an ``__init__`` that accepts the keyword argument
    ``vocab_size_or_config_json_file`` (``from_dict`` instantiates ``cls``
    with ``vocab_size_or_config_json_file=-1``).
    """

    # Maps a model name to the archive location handed to ``cached_path``.
    pretrained_model_archive_map = {}
    # File name of the JSON config inside the serialization directory.
    config_name = ""
    # File name of the weights file inside the serialization directory.
    weights_name = ""

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        """Load a config (and, if needed, a state dict) for a pretrained model.

        The archive location is resolved in this order: a path named
        ``pretrained_model_name`` next to this module, then an entry in
        ``cls.pretrained_model_archive_map``, then ``pretrained_model_name``
        itself. The result is passed through ``cached_path``. If the resolved
        location is not a directory it is treated as a gzipped tar archive and
        extracted into a temporary directory, which is always removed before
        this method returns or raises.

        Args:
            pretrained_model_name: Model name, path, or URL to resolve.
            cache_dir: Cache directory forwarded to ``cached_path``.
            type_vocab_size: Value stored on the loaded config as
                ``type_vocab_size``.
            state_dict: Already-loaded weights. When ``None``, the weights
                file ``cls.weights_name`` is loaded from the serialization
                directory if it exists.
            task_config: Optional object exposing ``local_rank``. Messages are
                logged only when it is ``None`` or its ``local_rank`` is 0
                (presumably so only one process logs in distributed runs).

        Returns:
            ``None`` if ``cached_path`` raises ``FileNotFoundError``;
            otherwise a tuple ``(config, state_dict)`` where ``state_dict``
            may still be ``None`` when no weights file was found.
        """
        should_log = task_config is None or task_config.local_rank == 0

        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if not os.path.exists(archive_file):
            # Fall back to the known-archive map, then to the raw name.
            archive_file = cls.pretrained_model_archive_map.get(
                pretrained_model_name, pretrained_model_name)

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if should_log:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None

        if should_log:
            if resolved_archive_file == archive_file:
                logger.info("loading archive file {}".format(archive_file))
            else:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))

        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file):
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir
                tempdir = tempfile.mkdtemp()
                if should_log:
                    logger.info("extracting archive file {} to temp dir {}".format(
                        resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    # NOTE(security): extractall() trusts member paths; only
                    # use archives from trusted sources.
                    archive.extractall(tempdir)
                serialization_dir = tempdir

            # Load config
            config_file = os.path.join(serialization_dir, cls.config_name)
            config = cls.from_json_file(config_file)
            config.type_vocab_size = type_vocab_size
            if should_log:
                logger.info("Model config {}".format(config))

            if state_dict is None:
                weights_path = os.path.join(serialization_dir, cls.weights_name)
                if os.path.exists(weights_path):
                    # NOTE(security): torch.load unpickles the file; only
                    # load weights from trusted sources.
                    state_dict = torch.load(weights_path, map_location='cpu')
                elif should_log:
                    logger.info("Weight doesn't exsits. {}".format(weights_path))
        finally:
            # Clean up temp dir even when extraction or loading fails, so a
            # broken archive does not leave extracted files behind.
            if tempdir is not None:
                shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def from_dict(cls, json_object):
        """Construct a config instance from a Python dictionary of parameters.

        ``cls`` is instantiated with ``vocab_size_or_config_json_file=-1`` and
        every key/value pair of ``json_object`` is then written directly into
        the instance ``__dict__``, overriding any value set by ``__init__``.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Construct a config instance from a UTF-8 JSON file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            return cls.from_dict(json.load(reader))

    def __repr__(self):
        """Return the JSON string form of this config."""
        return str(self.to_json_string())

    def to_dict(self):
        """Serialize this instance to a Python dictionary (a deep copy)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serialize this instance to an indented, key-sorted JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"