# lyraBaichuan / demo.py
# Author: carsonhxsu — commit b318031 ("fix model path")
from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B
# Demo: generate text with a lyra-accelerated Baichuan2-13B model.
# Both the converted model weights and the tokenizer live in the same
# local directory.
model_path = "./models/Baichuan2-13B-Base-lyra"
tokenizer_path = "./models/Baichuan2-13B-Base-lyra"

inference_dtype = 'fp16'   # inference precision
prompt = "登鹳雀楼->王之涣\n夜雨寄北->"
memopt_mode = 1            # memory-optimization mode flag
max_output_length = 64     # number of tokens to generate
arch = "Ampere"            # Ampere or Volta
cuda_version = 12          # cuda version, we currently support 11 and 12

# Build the 13B model wrapper (loads weights from model_path).
model = lyraBaichuan13B(
    model_path,
    tokenizer_path=tokenizer_path,
    dtype=inference_dtype,
    memopt_mode=memopt_mode,
    arch=arch,
    cuda_version=cuda_version,
)

# Duplicate the single prompt to form a batch of size bs.
bs = 1
prompts = [prompt] * bs

# Greedy decoding (do_sample=False); sampling knobs are passed through
# but inert with sampling disabled.
output_texts = model.generate(
    prompts,
    output_length=max_output_length,
    top_k=30,
    top_p=0.85,
    temperature=1.0,
    repetition_penalty=1.0,
    do_sample=False,
)
print(output_texts)