# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
from math import ceil


def assert_device_map(device_map, num_blocks):
    """Validate that ``device_map`` assigns every block to exactly one device.

    Args:
        device_map: Mapping whose values are iterables of block indices
            (one iterable per device key).
        num_blocks: Number of blocks in the model; the valid indices are
            ``0 .. num_blocks - 1``.

    Raises:
        ValueError: If a block index appears more than once, if a valid
            index is absent from the map, or if the map contains an index
            outside ``range(num_blocks)``. The checks run in that order and
            only the first failing one is reported.
    """
    expected_blocks = set(range(num_blocks))
    device_map_blocks = [item for sublist in device_map.values() for item in sublist]

    # Counter keeps first-occurrence order, so duplicates are reported in the
    # order they first appear in the map, without an O(n^2) list.count() scan.
    block_counts = Counter(device_map_blocks)
    duplicate_blocks = [block for block, count in block_counts.items() if count > 1]

    # Membership is tested against hashed containers instead of lists.
    missing_blocks = [i for i in range(num_blocks) if i not in block_counts]
    extra_blocks = [i for i in device_map_blocks if i not in expected_blocks]

    if duplicate_blocks:
        raise ValueError(
            "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device."
            " These attention blocks were specified more than once: " + str(duplicate_blocks)
        )
    if missing_blocks:
        raise ValueError(
            "There are attention blocks for this model that are not specified in the device_map. Add these attention "
            "blocks to a device on the device_map: " + str(missing_blocks)
        )
    if extra_blocks:
        raise ValueError(
            "The device_map contains more attention blocks than this model has. Remove these from the device_map:"
            + str(extra_blocks)
        )


def get_device_map(n_layers, devices):
    """Returns a dictionary of layers distributed evenly across all devices.

    Layers ``0 .. n_layers - 1`` are split into contiguous chunks of
    ``ceil(n_layers / len(devices))`` layers, and the chunks are paired with
    ``devices`` in order. When there are fewer chunks than devices, the
    trailing devices are left out of the result.

    Args:
        n_layers: Number of layers to distribute.
        devices: Sized iterable of device identifiers used as dictionary keys.

    Returns:
        A dict mapping each used device to its list of layer indices. An
        empty dict is returned when there are no layers to distribute
        (``n_layers <= 0``).

    Raises:
        ZeroDivisionError: If ``devices`` is empty while ``n_layers`` is
            positive.
    """
    layers = list(range(n_layers))
    if not layers:
        # Without this guard the chunk size would be 0 and range() would
        # fail with "arg 3 must not be zero".
        return {}
    n_blocks = int(ceil(n_layers / len(devices)))
    layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)]
    return dict(zip(devices, layers_list))