import EasyDel
import jax.lax
from EasyDel import JAXServer, get_mesh
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
from fjutils.tracker import initialise_tracking, get_mem
import argparse
from fjutils import make_shard_and_gather_fns, match_partition_rules
import threading
import typing
import IPython
import logging
import jax.numpy as jnp
import time
logging.basicConfig(
level=logging.INFO
)
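
# Prompt template for yes/no retrieval-style questions (defined here but not
# used further in this script) and the standard Llama-2 chat system prompt.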
instruct = 'Context:\n{context}\nQuestion:\nThis is a Yes or No question: can you answer ' \
           '"{question}" using only the provided context?'
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
"as helpfully as possible, while being safe. Your answers should not" \
" include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
"illegal content. Please ensure that your responses are socially unbiased " \
"and positive in nature.\nIf a question does not make any sense, or is not " \
"factually coherent, explain why instead of answering something not correct. If " \
"you don't know the answer to a question, please don't share false information."
def get_prompt_llama2_format(message: str, chat_history,
system_prompt: str) -> str:
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
message = message.strip() if do_strip else message
texts.append(f'{message} [/INST]')
return ''.join(texts)
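
# Llama-2 chat server built on EasyDel's JAXServer: adds loaders for PyTorch
# and JAX checkpoints plus the Gradio chat / instruct callbacks used below.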
class InTimeDataFinderJaxServerLlama2Type(JAXServer):
def __init__(self, config=None):
super().__init__(config=config)
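
    # Download a PyTorch Llama-2 checkpoint from the Hugging Face Hub, convert
    # the weights to Flax on CPU, and wrap them in a ready-to-serve JAXServer.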
@classmethod
def load_from_torch(cls, repo_id, config=None):
with jax.default_device(jax.devices('cpu')[0]):
param, config_model = llama_from_pretrained(
repo_id
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = EasyDel.FlaxLlamaForCausalLM(
config=config_model,
dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
precision=jax.lax.Precision('fastest'),
_do_init=False
)
return cls.load_from_params(
config_model=config_model,
model=model,
config=config,
params=param,
tokenizer=tokenizer,
add_param_field=True,
do_memory_log=False
)
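
    # Load an already-converted JAX/Flax checkpoint file from the Hugging Face
    # Hub instead of converting from PyTorch.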
@classmethod
def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id, checkpoint_path)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
model = EasyDel.FlaxLlamaForCausalLM(
config=config_model,
dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
precision=jax.lax.Precision('fastest'),
_do_init=False
)
return cls.load(
path=path,
config_model=config_model,
model=model,
config=config,
tokenizer=tokenizer,
add_param_field=True,
do_memory_log=False
)
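
    # Gradio chat callback: format the running conversation as a Llama-2
    # prompt, then either return the full response at once or stream it.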
def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
string = get_prompt_llama2_format(
message=prompt,
chat_history=history,
system_prompt=DEFAULT_SYSTEM_PROMPT
)
if not self.config.stream_tokens_for_gradio:
response, _ = self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
)
history.append([prompt, response])
else:
history.append([prompt, ''])
for response, _ in self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
stream=True
):
history[-1][-1] = response
yield '', history
        # yield the final state so the non-streaming branch also produces
        # output (a bare return inside a generator emits nothing to Gradio)
        yield '', history
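
    # Gradio instruct callback: single-turn prompt without chat history, with
    # the same streaming / non-streaming behaviour as the chat callback.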
def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(
            system_prompt=system or DEFAULT_SYSTEM_PROMPT,
            message=prompt,
            chat_history=[]
        )
if not self.config.stream_tokens_for_gradio:
response, _ = self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
)
else:
response = ''
for response, _ in self.process(
string=string,
greedy=greedy,
max_new_tokens=max_new_tokens,
stream=True
):
yield '', response
        # yield the final response so the non-streaming branch also produces
        # output for Gradio
        yield '', response
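
# Serving configuration: model repo, context length, generation limits,
# streaming chunk size, and compute dtype.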
if __name__ == "__main__":
configs = {
"repo_id": "meta-llama/Llama-2-7b-chat-hf",
"max_length": 4096,
"max_new_tokens": 4096,
"max_stream_tokens": 64,
"dtype": 'fp16',
"use_prefix_tokenizer": True
}
for key, value in configs.items():
print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))
server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
repo_id=configs['repo_id'],
config=configs
)
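    # If a pre-converted Flax checkpoint were already on the Hub, it could be
    # loaded without the PyTorch conversion step; a sketch with a hypothetical
    # checkpoint file name:
    # server = InTimeDataFinderJaxServerLlama2Type.load_from_jax(
    #     repo_id=configs['repo_id'],
    #     checkpoint_path='easydel_llama2_checkpoint',  # hypothetical file name
    #     config=configs
    # )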
    server.gradio_app_chat.launch(share=False)