erfanzar committed on
Commit 99193a0
1 Parent(s): 373144e

Create app.py

Files changed (1)
  1. app.py +162 -0
app.py ADDED
import EasyDel
import jax.lax
from EasyDel import JAXServer, get_mesh
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
from fjutils.tracker import initialise_tracking, get_mem
import argparse
from fjutils import make_shard_and_gather_fns, match_partition_rules
import threading
import typing
import IPython
import logging
import jax.numpy as jnp
import time

logging.basicConfig(
    level=logging.INFO
)

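# Prompt templates: `instruct` is a yes/no QA template (not referenced elsewhere in this
# file), and DEFAULT_SYSTEM_PROMPT is the system prompt used for every Llama-2 chat turn.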
instruct = 'Context:\n{context}\nQuestion:\nYes or No question, can you answer to ' \
           '"{question}?" only by using the provided context?'


DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."

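# Build a Llama-2 chat prompt: system prompt inside the first [INST] block, followed by the
# previous user/assistant turns, ending with the new user message awaiting a reply.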
def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

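# JAXServer subclass that serves a Llama-2 chat model through EasyDel's Gradio apps.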
class InTimeDataFinderJaxServerLlama2Type(JAXServer):
    def __init__(self, config=None):
        super().__init__(config=config)

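    # Load PyTorch weights from the Hub on CPU, wrap them in a FlaxLlamaForCausalLM
    # (without initializing new parameters), and hand everything to JAXServer.load_from_params.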
    @classmethod
    def load_from_torch(cls, repo_id, config=None):
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

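    # Load an already-converted JAX/Flax checkpoint from the Hub via JAXServer.load.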
    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

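    # Gradio chat handler: formats the running conversation in Llama-2 chat style and
    # either returns the full response at once or streams it, depending on
    # config.stream_tokens_for_gradio.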
    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
        else:
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history
        return '', history

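    # Gradio instruct handler: single-turn prompt; note that the `system` argument is
    # currently ignored in favour of DEFAULT_SYSTEM_PROMPT.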
    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
        else:
            response = ''
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response
        return '', response

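# Entry point for the Space: convert the PyTorch chat checkpoint and launch the Gradio chat UI.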
if __name__ == "__main__":

    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }
    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))

    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)