# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import torch
from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
)

from configuration_llava import LlavaConfig
from modeling_llava import LlavaForConditionalGeneration

# Maps key prefixes in the original checkpoint to the names used by the
# HF-style LlavaForConditionalGeneration implementation. Note that order
# matters: the more specific "transformer.*" prefixes must be listed before
# the bare "transformer" prefix so they are matched first.
KEYS_TO_MODIFY_MAPPING = {
    "transformer.vision_tower.vision_tower": "vision_model",
    "transformer.mm_projector": "multi_modal_projector",
    "transformer": "language_model.transformer",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.transformer",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
}


def convert_state_dict_to_hf(state_dict):
    """Rename every checkpoint key according to KEYS_TO_MODIFY_MAPPING."""
    new_state_dict = {}
    for key, value in state_dict.items():
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict


def convert_llava_llama_to_hf(text_model_id, vision_model_id, projector_tokens_num, output_path, old_state_dict_path):
    """Convert an original LLaVA-style checkpoint into the HF LlavaForConditionalGeneration format."""
    torch.set_default_dtype(torch.float16)
    text_config = AutoConfig.from_pretrained(text_model_id, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    # Register the image placeholder token and a dedicated pad token.
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    config = LlavaConfig(
        text_config=text_config,
        vocab_size=51200,
        vision_tower_name=vision_model_id,
        projector_tokens_num=projector_tokens_num,
    )
    config.text_config.vocab_size = config.vocab_size

    with torch.device("cuda"):
        model = LlavaForConditionalGeneration(config)

    state_dict = torch.load(old_state_dict_path, map_location="cpu")
    state_dict = convert_state_dict_to_hf(state_dict)
    # strict=True catches any key left unmapped; assign=True reuses the loaded
    # tensors directly instead of copying them into the freshly built model.
    model.load_state_dict(state_dict, strict=True, assign=True)

    # The config already carries the target vocab size (51200), so no further
    # embedding resize is needed before saving.
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_model_id",
        help="Hub location of the text model",
    )
    parser.add_argument(
        "--vision_model_id",
        help="Hub location of the vision model",
    )
    parser.add_argument(
        "--output_path",
        help="Location to save the converted model",
    )
    parser.add_argument(
        "--old_state_dict_path",
        help="Local path to the raw state dict of the original model, loaded directly with torch.load (e.g. `model_state_dict.bin`)",
    )
    parser.add_argument(
        "--tokens_num",
        type=int,
        default=1,
        help="Number of tokens produced by the multimodal projector",
    )
    args = parser.parse_args()
    convert_llava_llama_to_hf(
        args.text_model_id, args.vision_model_id, args.tokens_num, args.output_path, args.old_state_dict_path
    )


if __name__ == "__main__":
    main()
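

# A minimal sketch of how this script might be invoked. The script filename and
# both model ids below are hypothetical placeholders: the hard-coded vocab_size
# of 51200 suggests a Phi-style text backbone with a CLIP vision tower, but the
# ids of the actual checkpoint being converted should be substituted.
#
#   python convert_llava_weights_to_hf.py \
#       --text_model_id microsoft/phi-1_5 \
#       --vision_model_id openai/clip-vit-large-patch14 \
#       --output_path ./llava-hf \
#       --old_state_dict_path ./model_state_dict.bin \
#       --tokens_num 1
#
# As a sanity check, the export can then be reloaded from `output_path`; this
# assumes the local modeling_llava implementation is importable from the same
# directory:
#
#   from modeling_llava import LlavaForConditionalGeneration
#   model = LlavaForConditionalGeneration.from_pretrained("./llava-hf")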