"""Gradio app that trains Medusa heads for a Hub model and uploads them.

On import the app clones the Medusa training code and the ShareGPT/Vicuna
dataset next to this file (skipped if they are already present), then builds
a small UI with a single ``model_id`` input. Submitting validates that the
model loads with the transformers Auto classes, launches the Medusa training
script through ``torch.distributed.run`` in a separate process, and uploads
the resulting ``medusa_heads`` folder to ``<model_id>-medusa`` on the Hub.
"""
import multiprocessing as mp
import os

import gradio as gr
import torch
import torch.distributed.run as distributed_run
from git import Repo
from huggingface_hub import HfApi
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def _clone_if_missing(url: str, dest: str) -> None:
    """Clone ``url`` into ``dest`` unless ``dest`` already exists.

    ``Repo.clone_from`` fails when the destination directory is already
    populated, which would crash the app on every restart (or whenever this
    module is imported a second time, e.g. by a spawned worker process).
    """
    if os.path.isdir(dest):
        print(f"'{dest}' already exists, skipping clone")
        return
    Repo.clone_from(url, dest)


def _medusa_repo_id(model_id: str) -> str:
    """Return the Hub repo id the medusa heads for ``model_id`` are pushed to.

    Single source of truth shared by the uploader and the UI message, so the
    reported location always matches the repo that was actually written.
    """
    return f"{model_id}-medusa"


# Clone the medusa repo locally
print("Cloning the medusa repo locally...")
_clone_if_missing("https://github.com/FasterDecoding/Medusa.git", "medusa")
print("Cloning the Vicuna data locally...")
_clone_if_missing(
    "https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data"
)
print("Done")


def create_medusa_heads(model_id: str) -> str:
    """Train medusa heads for ``model_id`` and upload them to the Hub.

    Runs ``medusa/medusa/train/train.py`` via ``torch.distributed.run`` with
    2 processes per node, writing to the local ``medusa_heads`` directory,
    then creates (if needed) the ``<model_id>-medusa`` repo and uploads that
    directory to it.

    Args:
        model_id: Hub id of the base model passed as ``--model_name_or_path``.

    Returns:
        The repo id the heads were uploaded to.

    Any exception raised by the launcher or by the Hub client propagates to
    the caller; when this runs inside a ``multiprocessing.Process`` that
    shows up as a non-zero exit code.
    """
    parser = distributed_run.get_args_parser()
    args = parser.parse_args([
        "--nproc_per_node", "2",
        "medusa/medusa/train/train.py",
        "--model_name_or_path", model_id,
        "--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
        "--bf16", "True",
        "--output_dir", "medusa_heads",
        "--num_train_epochs", "1",
        "--per_device_train_batch_size", "4",
        "--per_device_eval_batch_size", "4",
        "--gradient_accumulation_steps", "4",
        "--evaluation_strategy", "no",
        "--save_strategy", "no",
        "--learning_rate", "1e-3",
        "--weight_decay", "0.0",
        "--warmup_ratio", "0.1",
        "--lr_scheduler_type", "cosine",
        "--logging_steps", "1",
        "--tf32", "True",
        "--model_max_length", "2048",
        "--lazy_preprocess", "True",
        "--medusa_num_heads", "3",
        "--medusa_num_layers", "1",
    ])
    distributed_run.run(args)

    # Upload the medusa heads to the Hub
    # NOTE(review): the target repo lives under the base model's namespace;
    # presumably the configured token needs write access there - confirm.
    repo_id = _medusa_repo_id(model_id)
    api = HfApi()
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
    )
    api.upload_folder(
        folder_path="medusa_heads",
        repo_id=repo_id,
    )
    return repo_id


def run(model_id: str) -> str:
    """Gradio callback: validate ``model_id``, train heads, report the result.

    Args:
        model_id: Hub model id typed by the user. Surrounding whitespace is
            ignored; an empty value is rejected.

    Returns:
        A Markdown string describing either the validation error, the
        training/upload failure, or the repo the heads were uploaded to.
    """
    print(f"\n\n\nNEW RUN: {model_id}")

    # Input validation (whitespace-only input is as useless as empty input)
    model_id = (model_id or "").strip()
    if not model_id:
        return """
        ### Invalid input 🐞

        Please fill a model_id.
        """
    print(f"Valid inputs ✅\nValidating model_id: {model_id}")

    # Attempt to load the base model
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        # Only loadability is being checked; drop the references right away.
        del config, tokenizer, model
    except Exception as e:
        return f"""
        ### {model_id} can't be loaded with AutoClasses 🐞

        {e}
        """
    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation
    try:
        proc = mp.Process(target=create_medusa_heads, args=(model_id,))
        proc.start()
        proc.join()
    except Exception as e:
        print("Error ❌\n", e)
        return f"""
        ### Error 😢😢😢

        {e}
        """

    # Exceptions raised inside the child never reach this process; the exit
    # code is the only signal that training or the upload failed.
    if proc.exitcode != 0:
        failure = f"Medusa heads creation process exited with code {proc.exitcode}"
        print("Error ❌\n", failure)
        return f"""
        ### Error 😢😢😢

        {failure}
        """

    repo_id = _medusa_repo_id(model_id)
    print("Success ✅\nMedusa heads uploaded to: ", repo_id)
    return f"""
    ### Success 🔥

    Yay! Medusa heads were successfully created and uploaded to, {repo_id}
    """


DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:

1. Input a public model id from the Hub
2. Click "Submit"
3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
"""

title = "Create LLM medusa heads in a new repo 🐍"

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"""# {title}""")
    description = gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
            with gr.Row():
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            status_box = gr.Markdown()

    # A ClearButton with no registered components does nothing; attach the
    # input and the status panel now that both exist.
    clean.add([model_id, status_box])
    # concurrency_limit=1: only one training job may run at a time.
    submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)

# Guarded so that a multiprocessing child using the "spawn"/"forkserver"
# start method (which re-imports this module) does not start a second server.
if __name__ == "__main__":
    demo.queue(max_size=10).launch(show_api=True)