Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. | |
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""PyTorch BERT model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import copy | |
import json | |
import logging | |
import tarfile | |
import tempfile | |
import shutil | |
import torch | |
from .file_utils import cached_path | |
# Logger named after this module (``__name__``); used by PretrainedConfig.get_config.
logger = logging.getLogger(name=__name__)
class PretrainedConfig(object):
    """Base class for configs that are loaded from pretrained archives.

    Subclasses set the three class attributes below. Their constructor must
    accept a ``vocab_size_or_config_json_file`` keyword, because
    :meth:`from_dict` instantiates ``cls(vocab_size_or_config_json_file=-1)``.
    """

    # Shortcut model name -> archive path/URL; consulted by get_config when the
    # name is not an existing path relative to this module's directory.
    pretrained_model_archive_map = {}
    # File name of the JSON config inside the resolved archive directory.
    config_name = ""
    # File name of the torch-saved weights inside the resolved archive directory.
    weights_name = ""

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        """Load the config (and, if needed, the weights) for ``pretrained_model_name``.

        The archive location is resolved in this order: a path joined onto this
        module's directory (if it exists), then ``cls.pretrained_model_archive_map``,
        then the name itself as a path or URL. The location is passed through
        ``cached_path``. A directory is used as-is; anything else is opened as a
        ``.tar.gz`` and extracted into a temporary directory, which is always
        removed before this method returns or raises.

        Args:
            pretrained_model_name: shortcut name, path, or URL of the archive.
            cache_dir: cache directory forwarded to ``cached_path``.
            type_vocab_size: value assigned to the loaded config's
                ``type_vocab_size`` attribute.
            state_dict: already-loaded weights. When ``None``, weights are loaded
                with ``torch.load`` from ``cls.weights_name`` if that file exists.
            task_config: optional object with a ``local_rank`` attribute.
                Messages are logged only when it is ``None`` or ``local_rank == 0``.

        Returns:
            ``(config, state_dict)``. ``state_dict`` is still ``None`` if none was
            given and no weights file was found. Returns plain ``None`` (not a
            tuple) when ``cached_path`` raises ``FileNotFoundError``.
        """
        # Only one process (rank 0, or a non-distributed run) should log.
        should_log = task_config is None or task_config.local_rank == 0

        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if not os.path.exists(archive_file):
            archive_file = cls.pretrained_model_archive_map.get(pretrained_model_name, pretrained_model_name)

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if should_log:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None

        if should_log:
            if resolved_archive_file == archive_file:
                logger.info("loading archive file {}".format(archive_file))
            else:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))

        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file):
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir
                tempdir = tempfile.mkdtemp()
                if should_log:
                    logger.info("extracting archive file {} to temp dir {}".format(
                        resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    if hasattr(tarfile, "data_filter"):
                        # The archive may come from an arbitrary URL. The PEP 706
                        # "data" filter rejects absolute paths, ".." escapes and
                        # special files instead of writing outside tempdir.
                        archive.extractall(tempdir, filter="data")
                    else:
                        archive.extractall(tempdir)
                serialization_dir = tempdir

            # Load config
            config_file = os.path.join(serialization_dir, cls.config_name)
            config = cls.from_json_file(config_file)
            config.type_vocab_size = type_vocab_size
            if should_log:
                logger.info("Model config {}".format(config))

            if state_dict is None:
                weights_path = os.path.join(serialization_dir, cls.weights_name)
                if os.path.exists(weights_path):
                    state_dict = torch.load(weights_path, map_location='cpu')
                elif should_log:
                    logger.info("Weight doesn't exist. {}".format(weights_path))
        finally:
            # Clean up the temp dir even if config/weight loading raised.
            if tempdir:
                shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def from_dict(cls, json_object):
        """Construct an instance of ``cls`` from a dict of parameters.

        The instance is created with ``vocab_size_or_config_json_file=-1``. Every
        key/value pair of ``json_object`` is then stored directly in the
        instance ``__dict__``, overriding whatever the constructor set.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Construct an instance of ``cls`` from a UTF-8 JSON file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to an indented, key-sorted JSON string ending in a newline."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"