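"""Serve meta-llama/Llama-2-7b-chat-hf through EasyDel's JAXServer with a Gradio chat UI.

The script loads the PyTorch checkpoint, converts it to flax parameters via
``llama_from_pretrained``, wraps the model in a ``JAXServer`` subclass that applies the
Llama-2 chat prompt format, and launches the Gradio chat app.
"""
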
import EasyDel
import jax
import jax.lax
from EasyDel import JAXServer
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
import logging

logging.basicConfig(
    level=logging.INFO
)


instruct = 'Context:\n{context}\nQuestion:\nThis is a yes-or-no question: can you answer ' \
           '"{question}" using only the provided context?'


DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."


def get_prompt_llama2_format(message: str, chat_history, system_prompt: str) -> str:
    """Build a Llama-2 chat prompt from the current message, the chat history and a system prompt."""
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)
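
# For an empty history the function produces the standard single-turn Llama-2 template, e.g.:
#   get_prompt_llama2_format("Hello", [], DEFAULT_SYSTEM_PROMPT)
#   -> "<s>[INST] <<SYS>>\n...system prompt...\n<</SYS>>\n\nHello [/INST]"
# Earlier turns are appended as "{user} [/INST] {response} </s><s>[INST] " pairs.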


class InTimeDataFinderJaxServerLlama2Type(JAXServer):
    """JAXServer subclass that serves a Llama-2 chat model and wires the Llama-2 prompt format into the Gradio apps."""

    def __init__(self, config=None):
        super().__init__(config=config)

    @classmethod
    def load_from_torch(cls, repo_id, config=None):
        """Load a PyTorch Llama-2 checkpoint from the Hugging Face Hub and convert it to flax parameters."""
        # Conversion runs on CPU so accelerator memory is not touched before the server shards the weights.
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
        """Load an already-converted EasyDel (flax) checkpoint file hosted on the Hugging Face Hub."""
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )
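
    # Illustrative call of load_from_jax; the repo id and checkpoint filename below are
    # placeholders, not real artifacts:
    #
    #     server = InTimeDataFinderJaxServerLlama2Type.load_from_jax(
    #         repo_id="some-user/llama2-7b-easydel",
    #         checkpoint_path="llama2_easydel_checkpoint",
    #         config={"dtype": "fp16"},
    #     )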

    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
        """Gradio chat callback: format the conversation, run generation and yield (textbox, history) updates."""
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
            # This method is a generator, so the non-streaming result must be yielded as well;
            # a plain return would never reach the Gradio UI.
            yield '', history
        else:
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history

    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        """Gradio instruct callback: single-turn generation with an optional custom system prompt."""
        string = get_prompt_llama2_format(
            system_prompt=system or DEFAULT_SYSTEM_PROMPT,
            message=prompt,
            chat_history=[]
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            # As above, the non-streaming result has to be yielded for Gradio to display it.
            yield '', response
        else:
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response


if __name__ == "__main__":
    # Serving configuration; dtype and the token limits are forwarded to the JAXServer config.
    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }
    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))

    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)
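    # launch(share=False) keeps the demo on this machine; gradio can expose a temporary public link with share=True.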