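# HuggingFace Space entry point: converts meta-llama/Llama-2-7b-chat-hf to JAX/Flax with
# EasyDel, wraps it in a JAXServer subclass, and serves a Gradio chat UI.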
import EasyDel
import jax.lax
from EasyDel import JAXServer, get_mesh
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
from fjutils.tracker import initialise_tracking, get_mem
import argparse
from fjutils import make_shard_and_gather_fns, match_partition_rules
import threading
import typing
import IPython
import logging
import jax.numpy as jnp
import time

logging.basicConfig(
    level=logging.INFO
)
instruct = 'Context:\n{context}\nQuestion:\nYes or No question: can you answer ' \
           '"{question}" only by using the provided context?'
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."
def get_prompt_llama2_format(message: str, chat_history,
                             system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)
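
# For illustration (not part of the original file): with an empty chat history, the helper
# above yields a prompt of the form
#   <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]
# and each earlier (user, assistant) turn is appended as
#   {user} [/INST] {assistant} </s><s>[INST]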
class InTimeDataFinderJaxServerLlama2Type(JAXServer):
    def __init__(self, config=None):
        super().__init__(config=config)

    # Convert the PyTorch checkpoint to Flax parameters on CPU, then hand them to JAXServer.
    @classmethod
    def load_from_torch(cls, repo_id, config=None):
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    # Load an already-converted Flax checkpoint file from the Hub instead of converting weights.
    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )
    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
        else:
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history
        return '', history
    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
        else:
            response = ''
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response
        return '', response
if __name__ == "__main__":
    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }
    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))
    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)
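
    # Alternative (sketch, not exercised here): serve from a pre-converted Flax checkpoint on
    # the Hub via load_from_jax. The checkpoint file name below is a hypothetical placeholder.
    # server = InTimeDataFinderJaxServerLlama2Type.load_from_jax(
    #     repo_id=configs['repo_id'],
    #     checkpoint_path='flax_params.msgpack',  # hypothetical artifact name
    #     config=configs
    # )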