from __future__ import annotations

import configparser
import os
import pathlib
import typing

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence

from .config import BELLE_PARAM, LIB_SO_PATH
from .model import BelleModel


class LyraBelle:
    def __init__(self, model_path, model_name, dtype='fp16', int8_mode=0) -> None:
        self.model_path = model_path
        self.model_name = model_name
        self.dtype = dtype
        # int8 weight quantization only makes sense when the dtype is 'int8'.
        if dtype != 'int8':
            int8_mode = 0
        self.int8_mode = int8_mode

        print(f'Loading model and tokenizer from {self.model_path}')
        self.model, self.tokenizer = self.load_model_and_tokenizer()
        print("Got model and tokenizer")

    def load_model_and_tokenizer(self):
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)

        checkpoint_path = pathlib.Path(self.model_path)
        config_path = checkpoint_path / 'config.ini'

        if config_path.exists():
            # Read model params from the config file shipped with the checkpoint.
            cfg = configparser.ConfigParser()
            cfg.read(config_path)
            model_name = 'belle'
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = cfg.get(model_name, "weight_data_type")
            model_args = dict(
                head_num=cfg.getint(model_name, 'head_num'),
                size_per_head=cfg.getint(model_name, "size_per_head"),
                layer_num=cfg.getint(model_name, "num_layer"),
                tensor_para_size=cfg.getint(model_name, "tensor_para_size"),
                vocab_size=cfg.getint(model_name, "vocab_size"),
                start_id=cfg.getint(model_name, "start_id"),
                end_id=cfg.getint(model_name, "end_id"),
                weights_data_type=cfg.get(model_name, "weight_data_type"),
                layernorm_eps=cfg.getfloat(model_name, 'layernorm_eps'),
                inference_data_type=inference_data_type)
        else:
            # No config.ini: fall back to the library defaults in BELLE_PARAM.
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = BELLE_PARAM.weights_data_type
            model_args = dict(
                head_num=BELLE_PARAM.num_heads,
                size_per_head=BELLE_PARAM.size_per_head,
                vocab_size=BELLE_PARAM.vocab_size,
                start_id=BELLE_PARAM.start_id or tokenizer.bos_token_id,
                end_id=BELLE_PARAM.end_id or tokenizer.eos_token_id,
                layer_num=BELLE_PARAM.num_layers,
                tensor_para_size=BELLE_PARAM.tensor_para_size,
                weights_data_type=BELLE_PARAM.weights_data_type,
                inference_data_type=inference_data_type)

        # Update common parameters shared by both code paths.
        model_args.update(dict(
            lib_path=LIB_SO_PATH,
            pipeline_para_size=BELLE_PARAM.pipeline_para_size,
            shared_contexts_ratio=BELLE_PARAM.shared_contexts_ratio,
            int8_mode=self.int8_mode
        ))

        print('[FT][INFO] Load Our FT Highly Optimized BELLE model')
        for k, v in model_args.items():
            print(f' - {k.ljust(25, ".")}: {v}')

        # Check sanity and consistency between the model and tokenizer.
        checklist = ['head_num', 'size_per_head', 'vocab_size', 'layer_num',
                     'tensor_para_size', 'weights_data_type']
        if None in [model_args[k] for k in checklist]:
            none_params = [p for p in checklist if model_args[p] is None]
            print(f'[FT][WARNING] Found None parameters {none_params}. They must '
                  f'be provided either by config file or CLI arguments.')
        if model_args['start_id'] != tokenizer.bos_token_id:
            print('[FT][WARNING] Given start_id is not matched with the bos token '
                  'id of the pretrained tokenizer.')
        if model_args['end_id'] not in (tokenizer.pad_token_id, tokenizer.eos_token_id):
            print('[FT][WARNING] Given end_id is not matched with neither pad '
                  'token id nor eos token id of the pretrained tokenizer.')

        model = BelleModel(**model_args)
        if not model.load(ckpt_path=os.path.join(self.model_path, self.model_name)):
            print('[FT][WARNING] Skip model loading since no checkpoints are found')

        return model, tokenizer

    def generate(self, prompts: typing.List[str] | str,
                 output_length: int = 512,
                 beam_width: int = 1,
                 top_k: typing.Optional[torch.IntTensor] = 1,
                 top_p: typing.Optional[torch.FloatTensor] = 1.0,
                 beam_search_diversity_rate: typing.Optional[torch.FloatTensor] = 0.0,
                 temperature: typing.Optional[torch.FloatTensor] = 1.0,
                 len_penalty: typing.Optional[torch.FloatTensor] = 0.0,
                 repetition_penalty: typing.Optional[torch.FloatTensor] = 1.0,
                 presence_penalty: typing.Optional[torch.FloatTensor] = None,
                 min_length: typing.Optional[torch.IntTensor] = None,
                 bad_words_list: typing.Optional[torch.IntTensor] = None,
                 do_sample: bool = False,
                 return_output_length: bool = False,
                 return_cum_log_probs: int = 0):
        # Accept a single prompt string as well as a list of prompts.
        if isinstance(prompts, str):
            prompts = [prompts, ]

        # Wrap each prompt in the BELLE chat template.
        inputs = ['Human: ' + prompt.strip() + '\n\nAssistant:' for prompt in prompts]
        batch_size = len(inputs)
        ones_int = torch.ones(size=[batch_size], dtype=torch.int32)
        ones_float = torch.ones(size=[batch_size], dtype=torch.float32)

        # We must encode the raw prompt texts one by one in order to compute
        # the token length of each original text.
        input_token_ids = [self.tokenizer(text, return_tensors="pt").input_ids.int().squeeze()
                           for text in inputs]
        input_lengths = torch.IntTensor([len(ids) for ids in input_token_ids])
        # Once we have the token length of each input text, we can batch the
        # input list into one tensor, padding on the right.
        input_token_ids = pad_sequence(input_token_ids, batch_first=True,
                                       padding_value=self.tokenizer.eos_token_id)

        random_seed = None
        if do_sample:
            random_seed = torch.randint(0, 262144, (batch_size,), dtype=torch.long)

        outputs = self.model(start_ids=input_token_ids,
                             start_lengths=input_lengths,
                             output_len=output_length,
                             beam_width=beam_width,
                             top_k=top_k * ones_int,
                             top_p=top_p * ones_float,
                             beam_search_diversity_rate=beam_search_diversity_rate * ones_float,
                             temperature=temperature * ones_float,
                             len_penalty=len_penalty * ones_float,
                             repetition_penalty=repetition_penalty * ones_float,
                             presence_penalty=presence_penalty,
                             min_length=min_length,
                             random_seed=random_seed,
                             bad_words_list=bad_words_list,
                             return_output_length=return_output_length,
                             return_cum_log_probs=return_cum_log_probs)

        if return_cum_log_probs > 0:
            outputs = outputs[0]  # output_token_ids.

        # Slice out the generated token ids of the 1st beam result.
        # output = input tokens + generated tokens.
        output_token_ids = [out[0, length:].cpu()
                            for out, length in zip(outputs, input_lengths)]

        output_texts = self.tokenizer.batch_decode(
            output_token_ids, skip_special_tokens=True)

        return output_texts
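

# Example usage: a minimal sketch only, not part of the library API. The model
# directory and checkpoint name below are hypothetical placeholders; they depend
# on where the converted BELLE weights actually live on disk.
if __name__ == '__main__':
    # Assumed layout: `./models/belle` contains the tokenizer files plus the
    # converted FT checkpoint named `1-gpu-fp16.bin` (both names are made up here).
    lyra_model = LyraBelle(model_path='./models/belle',
                           model_name='1-gpu-fp16.bin',
                           dtype='fp16',
                           int8_mode=0)

    # generate() accepts either a single prompt string or a list of prompts and
    # returns the decoded completions (input tokens are stripped internally).
    completions = lyra_model.generate(
        ['Write a short poem about the sea.'],
        output_length=128,
        do_sample=True,
        top_k=30,
        temperature=0.7)
    for text in completions:
        print(text)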