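"""Gradio demo for a LoLCATS-linearized Llama 3.1 8B model.

Loads the distilled linear-attention checkpoint and the LoRA finetune
checkpoint (from the Hugging Face Hub, or from local .pt files), then
serves greedy generation through a simple Gradio text interface.
"""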
import sys

sys.path.append("../")

import torch
import gradio as gr
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

from src.utils.setup import seed_everything
from src.utils.logging import print_header
from src.model.pretrained import get_pretrained_loader
from src.model.load_model import load_and_convert_attns, load_and_convert_finetune


def load_model_from_checkpoint(
    attn_mlp_checkpoint_path: str = None,
    finetune_checkpoint_path: str = None,
    model_config_path: str = None,
    distill_config_path: str = None,
    finetune_config_path: str = None,
    config_dir: str = 'configs',
    print_model: bool = False,
    debug: bool = False,
    huggingface_token: str = None,
    use_cuda_kernels: bool = False,
    use_attention: bool = False,
):
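    """Load the linearized model and tokenizer from distill + finetune checkpoints.

    Checkpoint paths ending in ".pt" are treated as local files; otherwise they
    are interpreted as Hugging Face Hub repo IDs. Set `use_attention=True` to
    return the original softmax-attention model instead.
    """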
    is_local = attn_mlp_checkpoint_path.endswith(".pt")

    model_config = OmegaConf.load(model_config_path)
    distill_config = OmegaConf.load(distill_config_path)
    finetune_config = OmegaConf.load(finetune_config_path)

    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if use_attention:
        model = model_loader.load('softmax')
        return model, model_config, tokenizer

    model = model_loader.load(model_config['attention']['attention_type'])
    if use_cuda_kernels:
        print('*** Using TK CUDA kernels ***')
        model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen'

    # Swap the softmax attention layers for the trained linear attentions.
    checkpoint_path = attn_mlp_checkpoint_path if is_local else None
    model, distill_peft_config = load_and_convert_attns(
        model, model_config,
        attention_type=None,
        checkpoint_path=checkpoint_path,
        print_model=debug,
        merge_loras=False,
        peft_gradient_checkpointing=False,
        train_attention=False)

    # Apply the LoRA finetuning weights (read from the finetune checkpoint when local).
    checkpoint_path = finetune_checkpoint_path if is_local else None
    model, ft_peft_config = load_and_convert_finetune(
        model, finetune_config,
        checkpoint_path=checkpoint_path,
        print_model=debug,
        merge_loras=False,
        peft_gradient_checkpointing=False)

    if not is_local:
        model = load_hf_weights(
            model,
            attn_mlp_checkpoint_path, finetune_checkpoint_path,
            filename="model.pt",
        )
    if use_cuda_kernels:
        print('*** Using TK CUDA kernels ***')

    if print_model:
        print('*** Model after checkpoint load ***')
        print(model)

    return model, model_config, tokenizer


def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"):
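    """Download the distill and finetune checkpoints from the Hugging Face Hub
    and load their weights into `model`, remapping key prefixes if needed.
    """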
    for repo_id in [distill_repo_id, ft_repo_id]:
        if repo_id is None:
            continue

        print(f"Loading weights from {repo_id}")
        local_file_path = hf_hub_download(repo_id=repo_id, filename=filename)
        state_dict = torch.load(local_file_path)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        # Checkpoint keys may carry different module prefixes depending on how
        # the model was wrapped when saved; retry with remapped prefixes.
        _keys = model.load_state_dict(state_dict, strict=False)
        if len(_keys.unexpected_keys) > 0:
            new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()}
            _keys = model.load_state_dict(new_state_dict, strict=False)
        if len(_keys.unexpected_keys) > 0:
            new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()}
            _keys = model.load_state_dict(new_state_dict, strict=False)

        if len(_keys.unexpected_keys) == 0:
            print('*** All expected keys matched successfully ***')
        else:
            print('*** Error: unexpected keys in checkpoint - please fix ***')
            print('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)
            sys.exit(1)

    return model


def load_model_and_tokenizer():
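    """Build the LoLCATS Llama 3.1 8B model from its Hub checkpoints and move it to GPU."""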
    CONFIG_DIR = 'configs'

    model_config_path = f"{CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml"
    distill_config_path = f"{CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml"
    finetune_config_path = f"{CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml"
    attn_mlp_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-distill'
    finetune_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-ft-lora'

    model, model_config, tokenizer = load_model_from_checkpoint(
        attn_mlp_checkpoint_path=attn_mlp_checkpoint_path,
        finetune_checkpoint_path=finetune_checkpoint_path,
        model_config_path=model_config_path,
        distill_config_path=distill_config_path,
        finetune_config_path=finetune_config_path,
        config_dir=CONFIG_DIR,
        print_model=False,
        debug=False,
        huggingface_token=None,
        use_cuda_kernels=False,
        use_attention=False,
    )
    model = model.to('cuda')
    model.eval()
    return model, tokenizer


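# Load the model and tokenizer once at startup so every Gradio request reuses them.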
model, tokenizer = load_model_and_tokenizer()


def generate_response(prompt):
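    """Greedily generate up to 50 new tokens for a single prompt and return the decoded completion."""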
    all_prompts = [prompt]

    with torch.no_grad():
        model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)
        model_output = model.generate(
            **model_input, use_cache=True,
            max_new_tokens=50,
            do_sample=False,
            top_k=1,
            top_p=1.0,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id)
        generated_tokens = model_output[0]
        input_len = model_input['input_ids'].shape[1]
        generated_tokens = generated_tokens[input_len:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text


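# Minimal Gradio UI: one text box in, the generated completion out.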
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="LOLcats Model Demo")

iface.launch()