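# NOTE: the following six lines are Hugging Face file-viewer page residue, not part of the script.
# JoshuaChak's picture
# Upload folder using huggingface_hub
# 7c071a8 verified
# raw
# history blame
# 7.65 kB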
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================
import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Export only traces forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)
torch.set_num_threads(72)

# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(description='export onnx.')
# The model path is mandatory: without it ``from_pretrained`` below cannot run,
# so argparse reports the omission up front with a usage message.
parser.add_argument('--model_path', type=str, required=True, help='path to the torch model.')
parser.add_argument('--guess_len', type=int, default=8, help='guess length')
parser.add_argument('--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
# Defaults to "basic" so that omitting the flag no longer leaves the mode as
# ``None`` (which made ``convert_lm_head`` fail after all blocks were exported).
parser.add_argument('--generation_mode', type=str, choices=["basic", "sample"], default="basic",
                    help='mode to the generate token.')
args = parser.parse_args()

model_path = args.model_path
# Output directory for every exported artifact (ONNX blocks and TorchScript files).
folder = "./tmp/onnx"
# Device on which the dummy export inputs are created.
device = torch.device(args.device)
generation_mode = args.generation_mode

# ---- Model loading ----------------------------------------------------------
origin_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

for param in origin_model.parameters():
    param.requires_grad = False

# ---- Shapes and handles shared by the wrappers and converters ---------------
config = origin_model.config
transformer = origin_model.transformer
layers = transformer.h

SEQ_LENGTH = config.seq_length
GUESS_LEN = args.guess_len
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS

print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
class Embedding(torch.nn.Module):
    """Export wrapper around the token-embedding lookup ``transformer.wte``."""

    def forward(self, input_ids):
        """Look up embeddings for ``input_ids`` and return them as float32."""
        return transformer.wte(input_ids).float()
class QwenBlock(torch.nn.Module):
    """Export wrapper for one transformer layer, called without a KV cache.

    The rotary tables are computed once for ``SEQ_LENGTH`` positions and
    reshaped to ``(SEQ_LENGTH, HEAD_DIM)`` so the forward pass can gather rows
    by ``position_ids``.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        # Element 0 is used as the cosine table and element 1 as the sine table.
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        cos_table, sin_table = self.rotary_emb[0], self.rotary_emb[1]
        self.cos_emb = cos_table.view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = sin_table.view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask):
        """Run the layer and return ``(hidden_states, present_k, present_v)`` as float32."""
        # Gather the rotary rows for the requested positions and insert an
        # extra axis at dim 2 before handing them to the layer.
        rotary = [
            self.cos_emb[position_ids].unsqueeze(2),
            self.sin_emb[position_ids].unsqueeze(2),
        ]
        hidden_states, (present_k, present_v) = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb_list=[rotary],
            use_cache=True)
        return hidden_states.float(), present_k.float(), present_v.float()
class QwenBlockCache(torch.nn.Module):
    """Export wrapper for one transformer layer that also consumes a past K/V pair.

    Same rotary-table preparation as the cache-less wrapper: the tables are
    computed for ``SEQ_LENGTH`` positions and reshaped to
    ``(SEQ_LENGTH, HEAD_DIM)`` for row lookup by position id.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        # Element 0 is used as the cosine table and element 1 as the sine table.
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        cos_table, sin_table = self.rotary_emb[0], self.rotary_emb[1]
        self.cos_emb = cos_table.view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = sin_table.view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        """Run the layer with ``(past_k, past_v)`` as its cache.

        Returns ``(hidden_states, present_k, present_v)``, all cast to float32.
        """
        rotary = [
            self.cos_emb[position_ids].unsqueeze(2),
            self.sin_emb[position_ids].unsqueeze(2),
        ]
        hidden_states, (present_k, present_v) = self.layer(
            hidden_states,
            layer_past=(past_k, past_v),
            attention_mask=attention_mask,
            rotary_pos_emb_list=[rotary],
            use_cache=True)
        return hidden_states.float(), present_k.float(), present_v.float()
class LmHead(torch.nn.Module):
    """Export wrapper for final norm + LM head that returns the top-1 token index."""

    def forward(self, hidden_states):
        """Apply ``ln_f`` and ``lm_head``, then return the index of the largest logit."""
        normed = transformer.ln_f(hidden_states)
        logits = origin_model.lm_head(normed).float()
        # Only the indices are needed; the top-1 values are discarded.
        return torch.topk(logits, 1)[1]
# refs:https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
class LmHeadTopk(torch.nn.Module):
    """Export wrapper for final norm + LM head with top-k / top-p filtering.

    Args:
        top_k: number of highest-scoring candidates kept per row.
        top_p: cumulative-probability threshold; a candidate stays while the
            running softmax sum (including itself) is below this value.
        min_tokens_to_keep: number of leading candidates that are always kept,
            regardless of ``top_p``.
    """

    def __init__(self, top_k = 50, top_p = 0.8, min_tokens_to_keep = 5):
        super().__init__()
        self.top_k = top_k
        self.top_p = top_p
        self.min_tokens_to_keep = min_tokens_to_keep
        # Row mask that forces the first ``min_tokens_to_keep`` candidates on.
        self.keep_matrix = torch.zeros((1, self.top_k), dtype=torch.bool)
        self.keep_matrix[0, :self.min_tokens_to_keep] = True
        # Logit assigned to filtered-out candidates before the final softmax.
        self.filter_value = torch.tensor([-1000.], dtype=torch.float32)

    def forward(self, hidden_states):
        """Return ``(probs, token)`` for the ``top_k`` candidates of each row.

        ``token`` holds the candidate indices from ``torch.topk``; ``probs`` is
        the softmax over their logits after filtered candidates were replaced
        by ``-1000``.
        """
        hidden_states = transformer.ln_f(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        logits, token = torch.topk(m_logits.float(), self.top_k)

        # The constants are created on CPU in __init__; move them next to the
        # logits so the ops below do not mix devices when exporting on CUDA.
        keep_matrix = self.keep_matrix.to(logits.device)
        filter_value = self.filter_value.to(logits.device)

        # top_p
        cumulative_probs = logits.softmax(dim=1).cumsum(dim=1)
        mask = cumulative_probs < self.top_p
        # Boolean addition, used here as an element-wise OR with the keep mask.
        mask = mask + keep_matrix
        filtered_logits = torch.where(mask, logits, filter_value)
        probs = filtered_logits.softmax(dim=1)
        return probs, token
def convert_block(layer_id):
    """Export layer ``layer_id`` (cache-less wrapper) to ``block_<layer_id>.onnx``.

    Dummy inputs cover the full ``SEQ_LENGTH``: bfloat16 hidden states, int64
    position ids 0..SEQ_LENGTH-1 and a bfloat16 attention mask.
    """
    wrapper = QwenBlock(layer_id)

    dummy_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE)).bfloat16().to(device)
    dummy_positions = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long).to(device)
    dummy_mask = torch.randn((1, 1, SEQ_LENGTH, SEQ_LENGTH)).bfloat16().to(device)

    onnx_path = f'{folder}/block_{layer_id}.onnx'
    torch.onnx.export(
        wrapper,
        (dummy_states, dummy_positions, dummy_mask),
        onnx_path,
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)
def convert_block_cache(layer_id):
    """Export layer ``layer_id`` (cache-consuming wrapper) to ``block_cache_<layer_id>.onnx``.

    Dummy inputs describe ``GUESS_LEN`` new positions on top of a key/value
    history of ``SEQ_LENGTH`` entries; the attention mask therefore spans
    ``SEQ_LENGTH + GUESS_LEN`` columns.
    """
    wrapper = QwenBlockCache(layer_id)

    kv_shape = (1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)
    dummy_states = torch.randn((1, GUESS_LEN, HIDDEN_SIZE)).bfloat16().to(device)
    dummy_positions = torch.tensor([range(GUESS_LEN)], dtype=torch.long).to(device)
    dummy_mask = torch.ones(
        (1, 1, GUESS_LEN, SEQ_LENGTH + GUESS_LEN)).bfloat16().to(device)
    history_k = torch.randn(kv_shape).bfloat16().to(device)
    history_v = torch.randn(kv_shape).bfloat16().to(device)

    onnx_path = f'{folder}/block_cache_{layer_id}.onnx'
    torch.onnx.export(
        wrapper,
        (dummy_states, dummy_positions, dummy_mask, history_k, history_v),
        onnx_path,
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)
def convert_embedding():
    """Trace the embedding wrapper on ids 0..SEQ_LENGTH-1 and save it as ``embedding.pt``."""
    wrapper = Embedding()
    sample_ids = torch.tensor([range(SEQ_LENGTH)]).to(device)
    traced = torch.jit.trace(wrapper.forward, sample_ids)
    torch.jit.save(traced, f'{folder}/embedding.pt')
def convert_lm_head():
    """Trace the LM head selected by ``generation_mode`` and save it as ``lm_head.pt``.

    ``"basic"`` exports ``LmHead`` and ``"sample"`` exports ``LmHeadTopk``.
    The head is traced on a random bfloat16 input of shape
    ``(GUESS_LEN, HIDDEN_SIZE)``.

    Raises:
        ValueError: if ``generation_mode`` is neither ``"basic"`` nor
            ``"sample"`` (for example ``None`` when the flag was omitted).
    """
    if generation_mode == "basic":
        model = LmHead()
    elif generation_mode == "sample":
        model = LmHeadTopk()
    else:
        # Without this branch ``model`` would be unbound and the failure would
        # surface as an opaque UnboundLocalError further down.
        raise ValueError(
            f"unsupported generation_mode {generation_mode!r}; "
            "expected 'basic' or 'sample'")

    # Named ``dummy_hidden`` rather than ``input`` to avoid shadowing the builtin.
    dummy_hidden = torch.randn(GUESS_LEN, HIDDEN_SIZE).bfloat16().to(device)
    module = torch.jit.trace(model.forward, dummy_hidden)
    torch.jit.save(module, f'{folder}/lm_head.pt')
# create folder to store onnx
# ``exist_ok=True`` replaces the separate existence check, so there is no gap
# between checking for the directory and creating it.
os.makedirs(folder, exist_ok=True)

# export models
# Each layer is exported twice: once without and once with a KV cache input.
print('Convert block & block_cache')
for i in tqdm(range(NUM_LAYERS)):
    convert_block(i)
    convert_block_cache(i)

print('Convert embedding')
convert_embedding()

print('Convert lm_head')
convert_lm_head()