import argparse
from functools import partial

import gradio as gr
import requests
from transformers import AutoConfig

from estimate_train_vram import vram_required
from vram_helpers import ModelConfig, TrainingConfig, filter_params_for_dataclass

ZERO_STAGES = [0, 1, 2, 3]
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64]
OPTIMIZERS = ["adam", "adamw", "sgd"]
HUGGINGFACE_URL_CONFIG = "https://huggingface.co/{}/resolve/main/config.json"


def parse_args():
    parser = argparse.ArgumentParser(description="Parser for VRAM estimator")
    parser.add_argument("--repo_id", type=str, default=None,
                        help="HuggingFace repo id used to automatically determine model settings")
    parser.add_argument("--model_size", type=float, default=7,
                        help="Model size (in billions of parameters)")
    parser.add_argument("--hidden_size", type=int, default=4096, help="Hidden size")
    parser.add_argument("--sequence_length", type=int, default=8192, help="Sequence length")
    parser.add_argument("--num_layers", type=int, default=32, help="Number of layers")
    parser.add_argument("--num_heads", type=int, default=32, help="Number of attention heads")
    # Note: these enable-flags use store_true so that passing the flag turns the
    # feature on (store_false would invert the documented behaviour).
    parser.add_argument("--mixed_precision", action="store_true",
                        help="Enable mixed precision for model training")
    parser.add_argument("--precision", type=str, default="bf16", help="Model precision for training")
    parser.add_argument("--micro_batch_size", type=int, default=4,
                        help="Micro batch size (batch size per device/GPU)")
    parser.add_argument("--zero_stage", type=int, default=0, choices=ZERO_STAGES,
                        help="ZeRO optimization stage")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="Enable gradient checkpointing")
    parser.add_argument("--optimizer", type=str, default="adamw", choices=OPTIMIZERS,
                        help="Type of optimizer")
    parser.add_argument("--num_gpus", type=int, default=4,
                        help="Number of GPUs. Necessary for estimating ZeRO stages")
    parser.add_argument("--cache_dir", type=str, default=None,
                        help="HuggingFace cache directory to download the config to")
    parser.add_argument("--qlora", action="store_true", help="Enable QLoRA when finetuning")
    parser.add_argument("--no-app", action="store_true",
                        help="Skip the Gradio app and print the estimate on the command line instead")
    return parser


def download_config_from_hub(repo_id: str, cache_dir: str):
    return AutoConfig.from_pretrained(pretrained_model_name_or_path=repo_id, cache_dir=cache_dir)


def scrape_config_from_hub(repo_id):
    url = HUGGINGFACE_URL_CONFIG.format(repo_id)
    config = None  # Ensure a defined return value even if the request fails
    try:
        print(f"Fetching config.json from the following URL: {url}...")
        response = requests.get(url)
        response.raise_for_status()  # Raises an HTTPError if the status is 4xx or 5xx
        config = response.json()
        print(f"Fetched the config for model {repo_id} successfully!")
    except requests.exceptions.HTTPError as errh:
        print(f"HTTP Error: {errh}")
    except requests.exceptions.ConnectionError as errc:
        print(f"Error Connecting: {errc}")
    except requests.exceptions.Timeout as errt:
        print(f"Timeout Error: {errt}")
    except requests.exceptions.RequestException as err:
        print(f"Something went wrong: {err}")
    except ValueError as e:
        print(f"Error decoding JSON: {e}")
    return config


def build_interface(estimate_vram_fn):
    with gr.Blocks() as app:
        option = gr.Radio(["Repo ID", "Model Parameters"], label="Select Input Type")
        repo_id = gr.Textbox(label="Repo ID", visible=False,
                             placeholder="mistralai/Mistral-7B-v0.1")

        with gr.Row(visible=False) as model_params_row:
            model_params = [
                gr.Slider(label="Model Size", minimum=0.1, maximum=400, step=0.1, value=7,
                          info="Model size (in billions of parameters)"),
                gr.Slider(label="Hidden size", minimum=256, maximum=8192, step=128, value=4096,
                          info="Hidden size"),
                gr.Slider(label="Sequence length", minimum=256, maximum=128_000, step=256,
                          value=8192, info="Sequence length"),
                gr.Slider(label="Num layers", minimum=8, maximum=64, step=1, value=32,
                          info="Number of layers"),
                gr.Slider(label="Num heads", minimum=8, maximum=64, step=1, value=32,
                          info="Number of attention heads"),
            ]

        def update_visibility(selected_option):
            # Toggle between the repo id textbox and the manual model parameters
            if selected_option == "Repo ID":
                return gr.update(visible=True), gr.update(visible=False)
            elif selected_option == "Model Parameters":
                return gr.update(visible=False), gr.update(visible=True)

        option.change(
            fn=update_visibility,
            inputs=[option],
            outputs=[repo_id, model_params_row]
        )

        with gr.Row(equal_height=True):
            training_params = [
                gr.Dropdown(label="Micro batch size", choices=BATCH_SIZES, value=4,
                            info="Micro batch size (batch size per device/GPU)"),
                gr.Dropdown(label="ZeRO stage", choices=ZERO_STAGES, value=0,
                            info="ZeRO optimization stage"),
                gr.Dropdown(label="Gradient checkpointing", choices=[True, False], value=True,
                            info="Enable gradient checkpointing"),
                gr.Dropdown(label="Mixed precision", choices=[False, True], value=False,
                            info="Enable mixed precision for model training"),
                gr.Dropdown(label="Optimizer", choices=OPTIMIZERS, value="adamw",
                            info="Type of optimizer"),
                gr.Dropdown(label="QLoRA", choices=[False, True], value=False,
                            info="Finetune with QLoRA enabled"),
                gr.Slider(label="Num GPUs", minimum=1, maximum=64, step=1, value=4,
                          info="Number of GPUs. Necessary for estimating ZeRO stages"),
                gr.Textbox(label="Cache dir", value=None, placeholder=".huggingface_configs",
                           info="HuggingFace cache directory to download the config to"),
            ]

        submit_btn = gr.Button("Estimate!")
        output = gr.Textbox(label="Total estimated VRAM per device/GPU (in GB)")

        def create_combined_params_dict(repo_id, *values):
            # Map each component's label (e.g. "Micro batch size") to a snake_case key
            # so the dict matches the dataclass field names.
            all_params = model_params + training_params
            combined_dict = {param.label.lower().replace(" ", "_"): value
                             for param, value in zip(all_params, values)}
            combined_dict["repo_id"] = repo_id
            return combined_dict

        submit_btn.click(
            fn=lambda repo_id, *values: estimate_vram_fn(create_combined_params_dict(repo_id, *values)),
            inputs=[repo_id] + model_params + training_params,
            outputs=[output]
        )
    return app


def estimate_vram(gradio_params):
    model_config = ModelConfig(**filter_params_for_dataclass(ModelConfig, gradio_params))
    training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, gradio_params))

    # Update the model config from the hub
    if not gradio_params["repo_id"]:
        return "No model selected!"
    # If a cache directory is set, download the config via AutoConfig
    if gradio_params["cache_dir"]:
        config = download_config_from_hub(gradio_params["repo_id"], gradio_params["cache_dir"])
        model_config.overwrite_with_hf_config(config.to_dict())
    # By default, scrape config.json from the hub
    else:
        config = scrape_config_from_hub(gradio_params["repo_id"])
        if config is None:
            return f"Could not fetch config.json for {gradio_params['repo_id']}!"
        model_config.overwrite_with_hf_config(config)

    if gradio_params["qlora"]:
        model_config.precision = "int4"

    total_vram_dict = vram_required(model_config, training_config)
    return (f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) "
            f"+ {total_vram_dict['gradients']}GB (gradients) "
            f"+ {total_vram_dict['optimizer']}GB (optimizer) "
            f"+ {total_vram_dict['activations']}GB (activations)")


if __name__ == "__main__":
    parser = parse_args()
    args = parser.parse_args()

    # Launch the Gradio interface
    if not args.no_app:
        estimate_vram_fn = partial(estimate_vram)
        interface = build_interface(estimate_vram_fn)
        interface.launch()
    # Command-line interface
    else:
        model_config = ModelConfig(**filter_params_for_dataclass(ModelConfig, vars(args)))
        training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, vars(args)))

        if args.repo_id:
            # If a cache directory is set, download the config via AutoConfig
            if args.cache_dir:
                config = download_config_from_hub(args.repo_id, args.cache_dir).to_dict()
            # By default, scrape config.json from the hub
            else:
                config = scrape_config_from_hub(args.repo_id)
            model_config.overwrite_with_hf_config(config)

        total_vram_dict = vram_required(model_config, training_config)
        print(f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) "
              f"+ {total_vram_dict['gradients']}GB (gradients) "
              f"+ {total_vram_dict['optimizer']}GB (optimizer) "
              f"+ {total_vram_dict['activations']}GB (activations)")
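
# Example invocations (a sketch: the filename "app.py" is an assumption; the
# flags are the ones defined in parse_args above):
#
#   python app.py                       # launch the Gradio app (default)
#   python app.py --no-app --repo_id mistralai/Mistral-7B-v0.1 \
#       --zero_stage 2 --num_gpus 8     # print a command-line estimate instead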