File size: 873 Bytes
f50d964
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B

# Demo: load Baichuan2-13B through lyra_baichuan, run generation on one
# prompt (optionally repeated to form a batch) and print the result.

# ---- Model / runtime configuration ----
model_path = "./models/Baichuan2-13B-lyra"
tokenizer_path = "./models/Baichuan2-13B-lyra"  # tokenizer lives beside the weights here
inference_dtype = 'fp16'
memopt_mode = 1  # NOTE(review): level semantics are defined by lyra_baichuan — confirm
arch = "Ampere"  # target GPU architecture: "Ampere" or "Volta"
cuda_version = 12  # supported CUDA major versions: 11 or 12

# ---- Generation configuration ----
# Pattern-completion prompt: one "poem title->poet" example, then a title
# whose poet the model is expected to fill in.
prompt = "登鹳雀楼->王之涣\n夜雨寄北->"
max_output_length = 64
bs = 1  # batch size: the single prompt is replicated this many times

model_kwargs = dict(
    tokenizer_path=tokenizer_path,
    dtype=inference_dtype,
    memopt_mode=memopt_mode,
    arch=arch,
    cuda_version=cuda_version,
)
model = lyraBaichuan13B(model_path, **model_kwargs)

prompts = [prompt] * bs

generation_kwargs = dict(
    output_length=max_output_length,
    top_k=30,
    top_p=0.85,
    temperature=1.0,
    repetition_penalty=1.0,
    # NOTE(review): with do_sample=False decoding is presumably greedy, so
    # top_k/top_p/temperature may have no effect — confirm in lyra_baichuan.
    do_sample=False,
)
output_texts = model.generate(prompts, **generation_kwargs)

print(output_texts)