import EasyDel
import jax.lax
from EasyDel import JAXServer, get_mesh
from fjutils import get_float_dtype_by_name
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr
from fjutils.tracker import initialise_tracking, get_mem
import argparse
from fjutils import make_shard_and_gather_fns, match_partition_rules
import threading
import typing
import IPython
import logging
import jax.numpy as jnp
import time

logging.basicConfig(
    level=logging.INFO
)

instruct = 'Context:\n{context}\nQuestion:\nYes or No question, can you answer to ' \
           '"{question}?" only and only by using provided context?'

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer " \
                        "as helpfully as possible, while being safe. Your answers should not" \
                        " include any harmful, unethical, racist, sexist, toxic, dangerous, or " \
                        "illegal content. Please ensure that your responses are socially unbiased " \
                        "and positive in nature.\nIf a question does not make any sense, or is not " \
                        "factually coherent, explain why instead of answering something not correct. If " \
                        "you don't know the answer to a question, please don't share false information."


def get_prompt_llama2_format(message: str, chat_history, system_prompt: str) -> str:
    # Build a Llama-2 chat prompt: the system prompt sits inside <<SYS>> tags of the
    # first [INST] block, followed by alternating user/assistant turns.
    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


class InTimeDataFinderJaxServerLlama2Type(JAXServer):
    # JAXServer subclass that serves a Llama-2 chat model behind a Gradio UI.
    def __init__(self, config=None):
        super().__init__(config=config)

    @classmethod
    def load_from_torch(cls, repo_id, config=None):
        # Convert a PyTorch Llama-2 checkpoint from the Hub into Flax parameters on CPU,
        # then hand them to the JAXServer parameter loader.
        with jax.default_device(jax.devices('cpu')[0]):
            param, config_model = llama_from_pretrained(
                repo_id
            )
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load_from_params(
            config_model=config_model,
            model=model,
            config=config,
            params=param,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    @classmethod
    def load_from_jax(cls, repo_id, checkpoint_path, config_repo=None, config=None):
        # Load an already-converted Flax/EasyDel checkpoint file from the Hub.
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        config_model = EasyDel.LlamaConfig.from_pretrained(config_repo or repo_id)
        model = EasyDel.FlaxLlamaForCausalLM(
            config=config_model,
            dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            param_dtype=get_float_dtype_by_name(config['dtype'] if config is not None else 'fp16'),
            precision=jax.lax.Precision('fastest'),
            _do_init=False
        )
        return cls.load(
            path=path,
            config_model=config_model,
            model=model,
            config=config,
            tokenizer=tokenizer,
            add_param_field=True,
            do_memory_log=False
        )

    def process_gradio_chat(self, prompt, history, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(
            message=prompt,
            chat_history=history,
            system_prompt=DEFAULT_SYSTEM_PROMPT
        )
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            history.append([prompt, response])
            # Yield once so Gradio picks up the finished answer (this method is a generator).
            yield '', history
        else:
            history.append([prompt, ''])
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                history[-1][-1] = response
                yield '', history
        return '', history

    def process_gradio_instruct(self, prompt, system, max_new_tokens, greedy, pbar=gr.Progress()):
        string = get_prompt_llama2_format(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
        if not self.config.stream_tokens_for_gradio:
            response, _ = self.process(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
            )
            # Yield once so Gradio picks up the finished answer (this method is a generator).
            yield '', response
        else:
            response = ''
            for response, _ in self.process(
                    string=string,
                    greedy=greedy,
                    max_new_tokens=max_new_tokens,
                    stream=True
            ):
                yield '', response
        return '', response


if __name__ == "__main__":
    # Serving configuration passed through to JAXServer.
    configs = {
        "repo_id": "meta-llama/Llama-2-7b-chat-hf",
        "max_length": 4096,
        "max_new_tokens": 4096,
        "max_stream_tokens": 64,
        "dtype": 'fp16',
        "use_prefix_tokenizer": True
    }

    for key, value in configs.items():
        print('\033[1;36m{:<30}\033[1;0m : {:>30}'.format(key.replace('_', ' '), f"{value}"))

    server = InTimeDataFinderJaxServerLlama2Type.load_from_torch(
        repo_id=configs['repo_id'],
        config=configs
    )
    server.gradio_app_chat.launch(share=False)
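# --- Usage sketch (illustrative only, not part of the original script) ---
# The Gradio UI is optional; the loaded server can also be queried directly with the
# same `process` API the handlers above rely on, e.g.:
#
#     prompt = get_prompt_llama2_format(
#         message='Is the server responding?',
#         chat_history=[],
#         system_prompt=DEFAULT_SYSTEM_PROMPT
#     )
#     answer, _ = server.process(string=prompt, greedy=False, max_new_tokens=128)
#     print(answer)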