iocuydi committed on
Commit 41f63b9
1 Parent(s): 418d66e

Create inference_demo.py

Files changed (1)
  1. inference_demo.py +101 -0
inference_demo.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ # Expects to be executed in folder: https://github.com/facebookresearch/llama-recipes/tree/main/src/llama_recipes/inference
+
+ import fire
+ import torch
+ import os
+ import sys
+ import time
+ import json
+ from typing import List
+
+ from transformers import LlamaTokenizer, LlamaForCausalLM
+ from safety_utils import get_safety_checker
+ from model_utils import load_model, load_peft_model
+
+ BASE_PROMPT = """Below is an interaction between a human and an AI fluent in English and Amharic, providing reliable and informative answers.
+ The AI is supposed to answer test questions from the human with short responses saying just the answer and nothing else.
+ Human: {}
+ Assistant [Amharic] : """
+
+ def main(
+     model_name: str="",
+     peft_model: str=None,
+     quantization: bool=False,
+     max_new_tokens=400, # The maximum number of tokens to generate
+     prompt_file: str=None,
+     seed: int=42, # Seed value for reproducibility
+     do_sample: bool=True, # Whether or not to use sampling; use greedy decoding otherwise.
+     min_length: int=None, # The minimum length of the sequence to be generated, input prompt + min_new_tokens
+     use_cache: bool=True, # [optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+     top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+     temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
+     top_k: int=1, # [optional] The number of highest probability vocabulary tokens to keep for top-k filtering.
+     repetition_penalty: float=1.0, # The parameter for repetition penalty. 1.0 means no penalty.
+     length_penalty: int=1, # [optional] Exponential penalty to the length that is used with beam-based generation.
+     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety API
+     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
+     enable_saleforce_content_safety: bool=False, # Enable safety check with Salesforce safety flan t5
+     **kwargs
+ ):
+
+     print("***Note: model is not set up for chat use case, history is reset after each response.")
+     print("***Ensure that you have replaced the default LLAMA2 tokenizer with the Amharic tokenizer")
+
+     # Set the seeds for reproducibility
+     torch.cuda.manual_seed(seed)
+     torch.manual_seed(seed)
+
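+     # NOTE: these hard-coded paths override the model_name and peft_model
+     # arguments above; point them at your local Llama 2 weights and the
+     # fine-tuned PEFT checkpoint before running.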
+     MAIN_PATH = '/path/to/llama2'
+     peft_model = '/path/to/checkpoint'
+     model_name = MAIN_PATH
+
+     model = load_model(model_name, quantization)
+
+     tokenizer = LlamaTokenizer.from_pretrained(model_name)
+     embedding_size = model.get_input_embeddings().weight.shape[0]
+
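+     # If the Amharic tokenizer's vocabulary size differs from the base model's
+     # embedding matrix, resize the embeddings to match before attaching PEFT weights.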
+     if len(tokenizer) != embedding_size:
+         print("resizing the model embeddings to match the tokenizer size")
+         model.resize_token_embeddings(len(tokenizer))
+
+     if peft_model:
+         model = load_peft_model(model, peft_model)
+
+     model.eval()
+
+     while True:
+
+         user_query = input('Type question in Amharic or English: ')
+         user_prompt = BASE_PROMPT.format(user_query)
+         batch = tokenizer(user_prompt, return_tensors="pt")
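+         # Move the tokenized prompt to the GPU; the demo assumes a CUDA device is available.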
+         batch = {k: v.to("cuda") for k, v in batch.items()}
+         start = time.perf_counter()
+         with torch.no_grad():
+             outputs = model.generate(
+                 **batch,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=do_sample,
+                 top_p=top_p,
+                 temperature=temperature,
+                 min_length=min_length,
+                 use_cache=use_cache,
+                 top_k=top_k,
+                 repetition_penalty=repetition_penalty,
+                 length_penalty=length_penalty,
+                 **kwargs
+             )
+         e2e_inference_time = (time.perf_counter()-start)*1000
+         print(f"the inference time is {e2e_inference_time} ms")
+
+         output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         print("MODEL_OUTPUT: {}".format(output_text))
+         #user_prompt += output_text
+
+ if __name__ == "__main__":
+     fire.Fire(main)
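
For reference, a minimal sketch of driving the demo programmatically instead of through the command line; fire.Fire(main) exposes the same keyword arguments as CLI flags (e.g. --quantization --max_new_tokens 200). The generation settings below are illustrative only, and model_name / peft_model are not passed because main() overrides them with the hard-coded MAIN_PATH and checkpoint path, which must first point at real weights.

# A minimal sketch: call the demo's main() directly with illustrative settings.
# Equivalent CLI call:  python inference_demo.py --quantization --max_new_tokens 200 --temperature 0.7 --top_p 0.9
from inference_demo import main

main(
    quantization=True,   # passed through to load_model
    max_new_tokens=200,  # cap on generated tokens per response
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)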