# lolcats / app.py
import sys
sys.path.append("../")
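# Assumes the app runs one directory below the lolcats repo root, so the
# `src` package imported below resolves from the parent directory.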
import torch
import gradio as gr
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from src.utils.setup import seed_everything
from src.utils.logging import print_header
from src.model.pretrained import get_pretrained_loader
from src.model.load_model import load_and_convert_attns, load_and_convert_finetune
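
# The model is assembled in three steps: load the pretrained base model and tokenizer,
# convert its attention layers using the distilled attention/MLP checkpoint
# (load_and_convert_attns), then attach the LoRA fine-tuning weights
# (load_and_convert_finetune). Checkpoints may be local .pt files or HF Hub repo ids.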
def load_model_from_checkpoint(
    attn_mlp_checkpoint_path: str = None,
    finetune_checkpoint_path: str = None,
    model_config_path: str = None,
    distill_config_path: str = None,
    finetune_config_path: str = None,
    config_dir: str = 'configs',
    print_model: bool = False,
    debug: bool = False,
    huggingface_token: str = None,
    use_cuda_kernels: bool = False,
    use_attention: bool = False,
):
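    """
    Load a LoLCATs model from an attention/MLP distillation checkpoint and a
    LoRA fine-tuning checkpoint.

    Paths ending in ".pt" are treated as local files; anything else is treated
    as a Hugging Face Hub repo id and fetched via load_hf_weights().
    Returns (model, model_config, tokenizer).
    """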
    is_local = attn_mlp_checkpoint_path.endswith(".pt")

    model_config = OmegaConf.load(model_config_path)
    distill_config = OmegaConf.load(distill_config_path)
    finetune_config = OmegaConf.load(finetune_config_path)

    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if use_attention:
        model = model_loader.load('softmax')
        return model, model_config, tokenizer
    # Apply the TK CUDA-kernel attention override before loading, so the chosen
    # attention_type is the one actually instantiated.
    if use_cuda_kernels:
        print('*** Using TK CUDA kernels ***')
        model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen'
    model = model_loader.load(model_config['attention']['attention_type'])
    if is_local:
        checkpoint_path = attn_mlp_checkpoint_path
    else:
        checkpoint_path = None
    model, distill_peft_config = load_and_convert_attns(
        model, model_config,
        attention_type=None,
        checkpoint_path=checkpoint_path,
        print_model=debug,
        merge_loras=False,
        peft_gradient_checkpointing=False,
        train_attention=False)
    if is_local:
        # The fine-tuning (LoRA) weights come from the fine-tune checkpoint,
        # not the attention-distillation checkpoint.
        checkpoint_path = finetune_checkpoint_path
    else:
        checkpoint_path = None
    model, ft_peft_config = load_and_convert_finetune(
        model, finetune_config,
        checkpoint_path=checkpoint_path,
        print_model=debug,
        merge_loras=False,
        peft_gradient_checkpointing=False)
    if not is_local:
        model = load_hf_weights(
            model,
            attn_mlp_checkpoint_path, finetune_checkpoint_path,
            filename="model.pt"
        )
    if use_cuda_kernels:
        print('*** Using TK CUDA kernels ***')
    if print_model:
        print('*** Model after checkpoint load ***')
        print(model)
    return model, model_config, tokenizer
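
# Hub checkpoints: download `model.pt` from each repo and merge the weights into the model.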
def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"):
    for repo_id in [distill_repo_id, ft_repo_id]:
        if repo_id is None:
            continue
        print(f"Loading weights from {repo_id}")
        local_file_path = hf_hub_download(repo_id=repo_id, filename=filename)
        state_dict = torch.load(local_file_path)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        _keys = model.load_state_dict(state_dict, strict=False)
        # Checkpoints saved from wrapped (e.g., PEFT) models use different key prefixes;
        # retry with remapped prefixes if the non-strict load left unexpected keys.
        if len(_keys.unexpected_keys) > 0:
            new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()}
            _keys = model.load_state_dict(new_state_dict, strict=False)
        if len(_keys.unexpected_keys) > 0:
            new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()}
            _keys = model.load_state_dict(new_state_dict, strict=False)
        try:
            assert len(_keys.unexpected_keys) == 0
            print('*** All expected keys matched successfully ***')
        except Exception as e:
            print(e)
            print('*** Error: unexpected keys in checkpoint - please fix ***')
            print('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)
            sys.exit(1)
    return model
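
# Demo setup: Llama 3.1 8B LoLCATs checkpoints published under hazyresearch on the Hub,
# with the corresponding config YAMLs expected under ./configs.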
def load_model_and_tokenizer():
    CONFIG_DIR = 'configs'  # Update to your path

    model_config_path = f"{CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml"
    distill_config_path = f"{CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml"
    finetune_config_path = f"{CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml"
    attn_mlp_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-distill'
    finetune_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-ft-lora'

    model, model_config, tokenizer = load_model_from_checkpoint(
        attn_mlp_checkpoint_path=attn_mlp_checkpoint_path,
        finetune_checkpoint_path=finetune_checkpoint_path,
        model_config_path=model_config_path,
        distill_config_path=distill_config_path,
        finetune_config_path=finetune_config_path,
        config_dir=CONFIG_DIR,
        print_model=False,
        debug=False,
        huggingface_token=None,
        use_cuda_kernels=False,
        use_attention=False
    )
    model = model.to('cuda')
    model.eval()
    return model, tokenizer
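
# Load the model once at import time so every Gradio request reuses it (a CUDA GPU is required).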
model, tokenizer = load_model_and_tokenizer()
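
# Greedy decoding of up to 50 new tokens; prompt tokens are stripped before decoding
# so only the generated continuation is returned.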
def generate_response(prompt):
    all_prompts = [prompt]
    with torch.no_grad():
        model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)
        model_output = model.generate(
            **model_input, use_cache=True,
            max_new_tokens=50,
            do_sample=False,
            top_k=1,
            top_p=1.0,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id)
        generated_tokens = model_output[0]
        input_len = model_input['input_ids'].shape[1]
        generated_tokens = generated_tokens[input_len:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text
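
# Illustrative sanity check (hypothetical prompt; the Gradio interface below is the
# Space's actual entry point):
#   print(generate_response("Explain linear attention in one sentence."))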
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="LOLcats Model Demo")
iface.launch()