File size: 759 Bytes
2ce2d73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from lyra_xverse import lyraXVERSE

# ---------------------------------------------------------------------------
# Demo: load a lyraXVERSE model and generate a completion for one prompt.
# ---------------------------------------------------------------------------

# Locations of the model weights and the tokenizer (same directory here).
model_path = "./models/"
tokenizer_path = "./models/"

# Precision passed to the model constructor as `dtype`.
inference_dtype = "fp16"

# Chinese-language prompt (roughly: "Tell a story:").
prompt = "讲个故事:"

# Runtime options forwarded to lyraXVERSE.
# NOTE(review): the meaning of memopt_mode values is not visible in this
# script -- confirm against the lyraXVERSE documentation.
memopt_mode = 1
max_output_length = 512
arch = "Ampere"  # Ampere or Volta
cuda_version = 12  # CUDA major version; 11 and 12 are currently supported

# Build the model wrapper from the settings above.
model = lyraXVERSE(
    model_path,
    tokenizer_path=tokenizer_path,
    dtype=inference_dtype,
    memopt_mode=memopt_mode,
    arch=arch,
    cuda_version=cuda_version,
)

# Batch size: the single prompt is repeated `bs` times to form the batch.
bs = 1
prompts = [prompt] * bs

# NOTE(review): do_sample is False while top_k/top_p/temperature are still
# supplied -- presumably they are unused in that mode; verify against
# lyraXVERSE.generate.
output_texts = model.generate(
    prompts,
    output_length=max_output_length,
    top_k=30,
    top_p=0.85,
    temperature=1.0,
    repetition_penalty=1.0,
    do_sample=False,
)

print(output_texts)