import argparse
import logging
import os

import torch

from hf_molmo.config_molmo import MolmoConfig
from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
from hf_molmo.modelling_molmo import MOLMoForCausalLM
from hf_molmo.preprocessing_molmo import MolmoProcessor
from olmo import ModelConfig
from olmo.mm_data.data_utils import build_tokenizer

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str, output_dir: str):
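    """Build an HF-compatible MolmoConfig and MolmoProcessor from the
    checkpoint's config.yaml and save both to output_dir."""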
    logger.info(f"Loading checkpoint from {checkpoint_dir}")

    config_path = os.path.join(checkpoint_dir, "config.yaml")
    model_config = ModelConfig.load(config_path, key="model")
    config_kwargs = model_config.asdict()
    config_kwargs["use_cache"] = True
    config_kwargs["vit_load_path"] = None
    config_kwargs["llm_load_path"] = None
    config = MolmoConfig(
        vocab_size=model_config.vocab_size,
        embedding_size=model_config.embedding_size,
        hidden_size=model_config.d_model,
        intermediate_size=model_config.mlp_hidden_size,
        num_hidden_layers=model_config.n_layers,
        num_attention_heads=model_config.n_heads,
        num_key_value_heads=model_config.n_kv_heads,
        max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
        initializer_range=model_config.initializer_range,
        use_cache=True,
        layer_norm_eps=model_config.layer_norm_eps,
        rope_theta=model_config.rope_theta,
        clip_qkv=model_config.clip_qkv,
        qkv_bias=model_config.qkv_bias,
        weight_tying=model_config.weight_tying,
        use_position_ids=True,
        tie_word_embeddings=False,
    )

    logger.info(f"Saving HF-compatible config to {os.path.join(output_dir, 'config.json')}")
    config.save_pretrained(output_dir)

    preprocessor = MolmoProcessor(
        MolmoImageProcessor(
            max_crops=model_config.max_crops
        ),
        build_tokenizer(model_config.tokenizer.identifier.split("m:")[1]).tokenizer
    )
    preprocessor.save_pretrained(output_dir)


def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
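    """Re-key the checkpoint's model.pt state dict under the HF base-model
    prefix and save it as pytorch_model.bin in output_dir."""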
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(output_dir, "pytorch_model.bin")

    # Load on CPU so the conversion also works on machines without a GPU.
    state_dict = torch.load(old_model_path, map_location="cpu")
    # Prefix every key with the HF base-model prefix so the weights line up
    # with the parameter names MOLMoForCausalLM expects.
    new_state_dict = {f"{MOLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
    torch.save(new_state_dict, new_model_path)


def convert_checkpoint(checkpoint_dir: str, output_dir: str):
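    """Write the HF-compatible config, processor, and weights into output_dir."""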
    os.makedirs(output_dir, exist_ok=True)
    write_config(checkpoint_dir, output_dir)
    write_model(checkpoint_dir, output_dir)


def main():
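    """Parse command-line arguments and run the conversion."""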
    parser = argparse.ArgumentParser(
        description="Writes an HF-compatible config.json and pytorch_model.bin to the output directory, "
                    "making it easier to load the checkpoint weights as an HF model."
    )
    parser.add_argument("checkpoint_dir", help="Directory containing the original config.yaml and model.pt")
    parser.add_argument("output_dir", help="Directory to write the HF-compatible config, processor, and weights to")
    args = parser.parse_args()
    # Make logger.info output visible when run as a script.
    logging.basicConfig(level=logging.INFO)
    convert_checkpoint(args.checkpoint_dir, args.output_dir)


if __name__ == "__main__":
    main()