diff --git "a/src/gen.py" "b/src/gen.py" --- "a/src/gen.py" +++ "b/src/gen.py" @@ -44,9 +44,11 @@ import numpy as np from evaluate_params import eval_func_param_names, no_default_param_names, input_args_list from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, \ LangChainAction, LangChainAgent, DocumentChoice, LangChainTypes, super_source_prefix, \ - super_source_postfix, t5_type, get_langchain_prompts, gr_to_lg, invalid_key_msg + super_source_postfix, t5_type, get_langchain_prompts, gr_to_lg, invalid_key_msg, docs_joiner_default, \ + docs_ordering_types_default, docs_token_handling_default from loaders import get_loaders from utils import set_seed, clear_torch_cache, NullContext, wrapped_partial, EThread, get_githash, \ import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, \ have_langchain, set_openai, cuda_vis_check, H2O_Fire, lg_to_gr, str_to_list, str_to_dict, get_token_count @@ -75,6 +77,7 @@ def main( low_bit_mode: int = 1, load_half: bool = None, load_gptq: str = '', + load_awq: str = '', load_exllama: bool = False, use_safetensors: bool = False, revision: str = None, @@ -92,11 +95,16 @@ def main( # llama and gpt4all settings llamacpp_dict: typing.Dict = dict(n_gpu_layers=100, use_mlock=True, n_batch=1024, n_gqa=0), - model_path_llama: str = 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q8_0.bin', - # 'llama-2-7b-chat.ggmlv3.q8_0.bin', + model_path_llama: str = 'https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q6_K.gguf', model_name_gptj: str = 'ggml-gpt4all-j-v1.3-groovy.bin', model_name_gpt4all_llama: str = 'ggml-wizardLM-7B.q4_2.bin', model_name_exllama_if_no_config: str = 'TheBloke/Nous-Hermes-Llama2-GPTQ', + exllama_dict: typing.Dict = dict(), + gptq_dict: typing.Dict = dict(), + attention_sinks: bool = False, + sink_dict: typing.Dict = dict(), + truncation_generation: bool = False, + hf_model_dict: typing.Dict = dict(), model_lock: typing.List[typing.Dict[str, str]] = None, model_lock_columns: int = None, @@ -106,6 +114,7 @@ def main( temperature: float = None, top_p: float = None, top_k: int = None, + penalty_alpha: float = None, num_beams: int = None, repetition_penalty: float = None, num_return_sequences: int = None, @@ -118,7 +127,6 @@ def main( memory_restriction_level: int = None, debug: bool = False, save_dir: str = None, - share: bool = False, local_files_only: bool = False, resume_download: bool = True, use_auth_token: Union[str, bool] = False, @@ -136,7 +144,14 @@ def main( gradio: bool = True, gradio_offline_level: int = 0, server_name: str = "0.0.0.0", + share: bool = False, + open_browser: bool = False, root_path: str = "", + ssl_verify: bool = True, + ssl_keyfile: str | None = None, + ssl_certfile: str | None = None, + ssl_keyfile_password: str | None = None, + chat: bool = True, chat_conversation: typing.List[typing.Tuple[str, str]] = None, text_context_list: typing.List[str] = None, @@ -148,6 +163,7 @@ def main( h2ocolors: bool = True, dark: bool = False, # light tends to be best height: int = 600, + render_markdown: bool = True, show_lora: bool = True, show_llama: bool = True, show_gpt4all: bool = False, @@ -169,6 +185,7 @@ def main( auth_message: str = None, guest_name: str = "guest", enforce_h2ogpt_api_key: bool = None, + enforce_h2ogpt_ui_key: bool = None,
h2ogpt_api_keys: Union[list, str] = [], h2ogpt_key: str = None, @@ -220,7 +237,8 @@ def main( langchain_agents: list = [], force_langchain_evaluate: bool = False, - visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value], + visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value, + LangChainAction.EXTRACT.value], visible_langchain_agents: list = langchain_agents_list.copy(), document_subset: str = DocumentSubset.Relevant.name, @@ -258,8 +276,14 @@ def main( chunk: bool = True, chunk_size: int = 512, top_k_docs: int = None, - docs_ordering_type: str = 'reverse_ucurve_sort', + docs_ordering_type: str = docs_ordering_types_default, min_max_new_tokens=256, + max_input_tokens=-1, + docs_token_handling: str = docs_token_handling_default, + docs_joiner: str = docs_joiner_default, + hyde_level: int = 0, + hyde_template: str = None, + auto_reduce_chunks: bool = True, max_chunks: int = 100, headsize: int = 50, @@ -280,14 +304,16 @@ def main( # images enable_ocr=False, - enable_doctr=False, + enable_doctr=True, enable_pix2struct=False, enable_captions=True, pre_load_caption_model: bool = False, caption_gpu: bool = True, + caption_gpu_id: Union[int, str] = 'auto', captions_model: str = "Salesforce/blip-image-captioning-base", doctr_gpu: bool = True, + doctr_gpu_id: Union[int, str] = 'auto', # json jq_schema='.[]', @@ -307,6 +333,7 @@ def main( :param load_half: load model in float16 (None means auto, which means True unless t5 based model) otherwise specify bool :param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g + :param load_awq: load model with AWQ, often 'model' for TheBloke models :param load_exllama: whether to use exllama (only applicable to LLaMa1/2 models with 16-bit or GPTQ :param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version) :param revision: Which HF revision to use @@ -335,7 +362,13 @@ def main( Or Address can be for vLLM: Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint - Note: vllm_chat not supported by vLLM project. + Use: "vllm_chat:IP:port" for OpenAI-Chat-compliant vLLM endpoint + + Use: "vllm:http://IP:port/v1" for OpenAI-compliant vLLM endpoint + Use: "vllm_chat:http://IP:port/v1" for OpenAI-Chat-compliant vLLM endpoint + + Use: "vllm:https://IP/v1" for OpenAI-compliant vLLM endpoint + Use: "vllm_chat:https://IP/v1" for OpenAI-Chat-compliant vLLM endpoint Or Address can be replicate: Use: @@ -366,6 +399,36 @@ def main( :param model_name_gptj: model path or URL (for auto-download) :param model_name_gpt4all_llama: model path or URL (for auto-download) :param model_name_exllama_if_no_config: exllama model's full path for model, tokenizer, generator for use when no HuggingFace config + :param exllama_dict for setting various things for Exllama class + E.g. compress_pos_emb, + set_auto_map, + gpu_peer_fix, + alpha_value, + matmul_recons_thd, + fused_mlp_thd + sdp_thd + fused_attn + matmul_fused_remap + rmsnorm_no_half2 + rope_no_half2 + matmul_no_half2 + silu_no_half2 + concurrent_streams + E.g. to set memory to be split across 2 GPUs, use --exllama_dict="{'set_auto_map':20,20}" + :param gptq_dict: Choices for AutoGPTQ, e.g. one can change defaults to these non-defaults: + inject_fused_attention=False + disable_exllama=True + use_triton=True + :param attention_sinks: Whether to enable attention sinks. 
Requires in local repo: + git clone https://github.com/tomaarsen/attention_sinks.git + :param sink_dict: dict of options for attention sinks + :param hf_model_dict: dict of options for HF models using transformers + + :param truncation_generation: Whether (for torch) to terminate generation once reach context length of model. + For some models, perplexity becomes critically large beyond context + For other models like Mistral, one can generate beyond max_seq_len set to 4096 or 8192 without issue, since based upon 32k embeddings + codellama can also generate beyond its 16k context length + So default is off, but for simpler/older models True may be wise to avoid bad generations :param model_lock: Lock models to specific combinations, for ease of use and extending to many models Only used if gradio = True @@ -387,6 +450,7 @@ def main( :param temperature: generation temperature :param top_p: generation top_p :param top_k: generation top_k + :param penalty_alpha: penalty_alpha>0 and top_k>1 enables contrastive search (not all models support) :param num_beams: generation number of beams :param repetition_penalty: generation repetition penalty :param num_return_sequences: generation number of sequences (1 forced for chat) @@ -398,13 +462,16 @@ def main( :param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on token 2 = HF like restriction 3 = very low memory case :param debug: enable debug mode :param save_dir: directory chat data is saved to - :param share: whether to share the gradio app with sharable URL :param local_files_only: whether to only use local files instead of doing to HF for models :param resume_download: whether to resume downloads from HF for models :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before) :param trust_remote_code: whether to use trust any code needed for HF model :param rope_scaling: - For HF transformers model: scaling for rope-based models, e.g. --rope_scaling="{'type':'dynamic', 'factor':4}" + For HF transformers model: scaling for rope-based models. + For long context models that have been tuned for a specific size, you have to only use that specific size by setting the `--rope_scaling` exactly correctly + e.g. --rope_scaling="{'type':'dynamic', 'factor':4}" + e.g. --rope_scaling="{'type':'linear', 'factor':4}" + e.g. python generate.py --rope_scaling="{'type':'linear','factor':4}" --base_model=lmsys/vicuna-13b-v1.5-16k --hf_embedding_model=sentence-transformers/all-MiniLM-L6-v2 --load_8bit=True --langchain_mode=UserData --user_path=user_path --prompt_type=vicuna11 --h2ocolors=False For exllama model: --rope_scaling="{'alpha_value':4}" . This automatically scales max_seq_len for exllama :param max_seq_len: Manually set maximum sequence length for the LLM :param offload_folder: path for spilling model onto disk @@ -428,10 +495,17 @@ def main( Also set --share=False to avoid sharing a gradio live link. :param server_name: IP to use. In linux 0.0.0.0 is good choice so exposed to outside host, else for only local use 127.0.0.1. For windows/MAC 0.0.0.0 or 127.0.0.1 will work, but may need to specify actual LAN IP address for other LAN clients to see. + :param share: whether to share the gradio app with sharable URL + :param open_browser: whether to automatically open browser tab with gradio UI :param root_path: The root path (or "mount point") of the application, if it's not served from the root ("/") of the domain. 
Often used when the application is behind a reverse proxy that forwards requests to the application. For example, if the application is served at "https://example.com/myapp", the `root_path` should be set to "/myapp". + :param ssl_verify: passed go gradio launch + :param ssl_keyfile: passed go gradio launch + :param ssl_certfile: passed go gradio launch + :param ssl_keyfile_password: passed go gradio launch + :param chat: whether to enable chat mode with chat history :param chat_conversation: list of tuples of (human, bot) conversation pre-appended to existing chat when using instruct/chat models Requires also add_chat_history_to_context = True @@ -450,6 +524,8 @@ def main( :param h2ocolors: whether to use H2O.ai theme :param dark: whether to use dark mode for UI by default (still controlled in UI) :param height: height of chat window + :param render_markdown: Whether to render markdown in chatbot UI. In some cases this distorts the rendering. + https://github.com/gradio-app/gradio/issues/4344#issuecomment-1771963021 :param show_lora: whether to show LORA options in UI (expert so can be hard to understand) :param show_llama: whether to show LLaMa.cpp/GPT4All options in UI (only likely useful if have weak GPUs) :param show_gpt4all: whether to show GPT4All models in UI (not often useful, llama.cpp models best) @@ -481,19 +557,38 @@ def main( :param guest_name: guess name if using auth and have open access. If '', then no guest allowed even if open access, then all databases for each user always persisted :param enforce_h2ogpt_api_key: Whether to enforce h2oGPT token usage for API + :param enforce_h2ogpt_ui_key: Whether to enforce h2oGPT token usage for UI (same keys as API assumed) :param h2ogpt_api_keys: list of tokens allowed for API access or file accessed on demand for json of list of keys :param h2ogpt_key: E.g. can be set when accessing gradio h2oGPT server from local gradio h2oGPT server that acts as client to that inference server :param max_max_time: Maximum max_time for gradio slider :param max_max_new_tokens: Maximum max_new_tokens for gradio slider :param min_max_new_tokens: Minimum of max_new_tokens, when auto-scaling down to handle more docs/prompt, but still let generation have some tokens - + :param max_input_tokens: Max input tokens to place into model context for each LLM call + -1 means auto, fully fill context for query, and fill by original document chunk for summarization + >=0 means use that to limit context filling to that many tokens + :param docs_token_handling: 'chunk' means fill context with top_k_docs (limited by max_input_tokens or model_max_len) chunks for query + or top_k_docs original document chunks summarization + None or 'split_or_merge' means same as 'chunk' for query, while for summarization merges documents to fill up to max_input_tokens or model_max_len tokens + + :param docs_joiner: string to join lists of text when doing split_or_merge. 
None means '\n\n' + :param hyde_level: HYDE level for HYDE approach (https://arxiv.org/abs/2212.10496) + 0: No HYDE + 1: Use non-document-based LLM response and original query for embedding query + 2: Use document-based LLM response and original query for embedding query + 3+: Continue iterations of embedding prior answer and getting new response + :param hyde_template: + None, 'None', 'auto' uses internal value and enable + '{query}' is minimal template one can pass :param visible_models: Which models in model_lock list to show by default Takes integers of position in model_lock (model_states) list or strings of base_model names Ignored if model_lock not used For nochat API, this is single item within a list for model by name or by index in model_lock If None, then just use first model in model_lock list If model_lock not set, use model selected by CLI --base_model etc. + Note that unlike h2ogpt_key, this visible_models only applies to this running h2oGPT server, + and the value is not used to access the inference server. + If need a visible_models for an inference server, then use --model_lock and group together. :param visible_visible_models: Whether visible models drop-down is visible in UI :param visible_submit_buttons: whether submit buttons are visible when UI first comes up @@ -507,7 +602,7 @@ def main( :param visible_models_tab: "" for models tab :param visible_system_tab: "" for system tab :param visible_tos_tab: "" for ToS tab - :param visible_login_tab: "" for Login tab + :param visible_login_tab: "" for Login tab (needed for persistence or to enter key for UI access to models and ingestion) :param visible_hosts_tab: "" for hosts tab :param chat_tables: Just show Chat as block without tab (useful if want only chat view) :param visible_h2ogpt_header: Whether github stars, URL, logo, and QR code are visible @@ -556,6 +651,7 @@ def main( Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce Summarize_all: Summarize document(s) using entire document at once Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary + Extract: Extract information from document(s) via map (no reduce) :param langchain_agents: Which agents to use 'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing. 
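The hyde_level / hyde_template parameters documented above refer to the HyDE approach (hypothetical document embeddings, arXiv:2212.10496): rather than embedding the bare query, an LLM-drafted answer is embedded together with the query so retrieval matches answer-like passages. Below is a minimal illustrative sketch of that iteration; `hyde_query_embedding`, `llm`, and `embed` are hypothetical stand-ins, not h2oGPT internals, and levels 2+ are simplified to plain iteration (in h2oGPT, level 2 uses a document-grounded draft).

```python
# Illustrative sketch only: hypothetical helpers, not h2oGPT's langchain path.
from typing import Callable, List


def hyde_query_embedding(query: str,
                         llm: Callable[[str], str],
                         embed: Callable[[str], List[float]],
                         hyde_level: int = 0,
                         hyde_template: str = '{query}') -> List[float]:
    if hyde_level <= 0:
        # level 0: embed the raw user query only
        return embed(query)
    # level 1: draft an answer without documents, embed draft + original query
    answer = llm(hyde_template.format(query=query))
    for _ in range(hyde_level - 1):
        # levels 2+: feed the prior answer back in and re-draft before embedding
        answer = llm(hyde_template.format(query=query + '\n' + answer))
    return embed(query + '\n' + answer)
```

A template of just '{query}' is the minimal form noted in the docstring; richer templates would wrap the query in instructions for the draft answer.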
@@ -596,9 +692,9 @@ def main( :param show_link_in_sources: Whether to show URL link to source document in references :param pre_prompt_query: prompt before documents to query, if None then use internal defaults :param prompt_query: prompt after documents to query, if None then use internal defaults - :param pre_prompt_summary: prompt before documents to summarize, if None then use internal defaults - :param prompt_summary: prompt after documents to summarize, if None then use internal defaults - For summarize, normal to have empty query (nothing added in ask anything in UI or empty string in API) + :param pre_prompt_summary: prompt before documents to summarize/extract from, if None then use internal defaults + :param prompt_summary: prompt after documents to summarize/extract from, if None then use internal defaults + For summarize/extract, normal to have empty query (nothing added in ask anything in UI or empty string in API) If pass query, template is "Focusing on %s, %s" % (query, prompt_summary) If pass query and iinput, template is "Focusing on %s, %s, %s" % (query, iinput, prompt_summary) :param add_chat_history_to_context: Include chat context when performing action @@ -619,7 +715,7 @@ def main( :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length :param top_k_docs: For langchain_action query: number of chunks to give LLM -1 : auto-fills context up to max_seq_len - For langchain_action summarize: number of document parts, like pages for PDF. + For langchain_action summarize/extract: number of document parts, like pages for PDF. There's no such thing as chunks for summarization. -1 : auto-fills context up to max_seq_len :param docs_ordering_type: @@ -655,9 +751,10 @@ def main( :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model if pre_load_caption_model=True - :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader + :param pre_load_caption_model: Whether to preload caption model (True), or load after forking parallel doc loader (False) parallel loading disabled if preload and have images, to prevent deadlocking on cuda context - Recommended if using larger caption model + Recommended if using larger caption model or doing production serving with many users to avoid GPU OOM if many would use model at same time + Also applies to DocTR :param captions_model: Which model to use for captions. 
captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state @@ -665,8 +762,10 @@ def main( Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions Disabled for CPU since BLIP requires CUDA :param caption_gpu: If support caption, then use GPU if exists + :param caption_gpu_id: Which GPU id to use, if 'auto' then select 0 :param doctr_gpu: If support doctr, then use GPU if exists + :param doctr_gpu_id: Which GPU id to use, if 'auto' then select 0 :param jq_schema: control json loader By default '.[]' ingests everything in brute-force way, but better to match your schema @@ -711,6 +810,11 @@ def main( if 'n_gqa' not in llamacpp_dict: llamacpp_dict['n_gqa'] = 0 + exllama_dict = str_to_dict(exllama_dict) + gptq_dict = str_to_dict(gptq_dict) + sink_dict = str_to_dict(sink_dict) + hf_model_dict = str_to_dict(hf_model_dict) + if os.environ.get('SERPAPI_API_KEY') is None and LangChainAgent.SEARCH.value in visible_langchain_agents: visible_langchain_agents.remove(LangChainAgent.SEARCH.value) @@ -729,6 +833,9 @@ def main( is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0'))) is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0'))) is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer + if enforce_h2ogpt_ui_key is None: + # nominally allow UI access public or not + enforce_h2ogpt_ui_key = False if is_public: visible_tos_tab = visible_hosts_tab = True if enforce_h2ogpt_api_key is None: @@ -855,6 +962,7 @@ def main( temperature = 0.2 if temperature is None else temperature top_p = 0.85 if top_p is None else top_p top_k = 70 if top_k is None else top_k + penalty_alpha = 0.0 if penalty_alpha is None else penalty_alpha if is_hf: do_sample = True if do_sample is None else do_sample top_k_docs = 3 if top_k_docs is None else top_k_docs @@ -932,6 +1040,7 @@ def main( # wouldn't work if specified True, but respect load_half = False load_gptq = '' + load_awq = '' load_exllama = False use_gpu_id = False if get_device() == "cuda": @@ -966,7 +1075,7 @@ def main( model_lower = base_model.lower() elif model_lock: # have 0th model be thought of as normal model - assert len(model_lock) > 0 and model_lock[0]['base_model'] + assert len(model_lock) > 0 and model_lock[0]['base_model'], "model_lock: %s" % model_lock model_lower = model_lock[0]['base_model'].lower() else: model_lower = '' @@ -1003,7 +1112,7 @@ def main( placeholder_instruction, placeholder_input, \ stream_output, show_examples, \ prompt_type, prompt_dict, \ - temperature, top_p, top_k, num_beams, \ + temperature, top_p, top_k, penalty_alpha, num_beams, \ max_new_tokens, min_new_tokens, early_stopping, max_time, \ repetition_penalty, num_return_sequences, \ do_sample, \ @@ -1017,7 +1126,7 @@ def main( system_prompt, pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary, - temperature, top_p, top_k, num_beams, + temperature, top_p, top_k, penalty_alpha, num_beams, max_new_tokens, min_new_tokens, early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample, @@ -1030,6 +1139,11 @@ def main( jq_schema, docs_ordering_type, min_max_new_tokens, + max_input_tokens, + docs_token_handling, + docs_joiner, + hyde_level, + hyde_template, verbose, ) @@ -1079,9 +1193,18 @@ def main( assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + if 
attention_sinks: + if use_cache is False: + raise ValueError("attention sinks requires use_cache=True") + else: + use_cache = True + # never truncate if using attention sinks + truncation_generation = truncation_generation and not attention_sinks + other_model_state_defaults = dict(load_8bit=load_8bit, load_4bit=load_4bit, low_bit_mode=low_bit_mode, load_half=load_half, - load_gptq=load_gptq, load_exllama=load_exllama, use_safetensors=use_safetensors, + load_gptq=load_gptq, load_awq=load_awq, load_exllama=load_exllama, + use_safetensors=use_safetensors, revision=revision, use_gpu_id=use_gpu_id, gpu_id=gpu_id, compile_model=compile_model, use_cache=use_cache, @@ -1089,6 +1212,14 @@ def main( model_name_gptj=model_name_gptj, model_name_gpt4all_llama=model_name_gpt4all_llama, model_name_exllama_if_no_config=model_name_exllama_if_no_config, + rope_scaling=rope_scaling, + max_seq_len=max_seq_len, + exllama_dict=exllama_dict, + gptq_dict=gptq_dict, + attention_sinks=attention_sinks, + sink_dict=sink_dict, + truncation_generation=truncation_generation, + hf_model_dict=hf_model_dict, ) model_state_none = dict(model=None, tokenizer=None, device=None, base_model=None, tokenizer_base_model=None, lora_weights=None, @@ -1193,6 +1324,9 @@ def main( model0, tokenizer0, device = get_model(reward_type=False, **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs)) + # update model state + if hasattr(tokenizer0, 'model_max_length'): + model_dict['max_seq_len'] = tokenizer0.model_max_length else: # if empty model, then don't load anything, just get gradio up model0, tokenizer0, device = None, None, None @@ -1229,6 +1363,10 @@ def main( # This is just so UI shows reasonable correct value, not 2048 dummy value if len(model_states) >= 1: max_seq_len = model_states[0]['tokenizer'].model_max_length + elif model_state0 is not None and \ + 'tokenizer' in model_state0 and \ + hasattr(model_state0['tokenizer'], 'model_max_length'): + max_seq_len = model_state0['tokenizer'].model_max_length # get score model all_kwargs = locals().copy() @@ -1243,7 +1381,7 @@ def main( if enable_captions: if pre_load_caption_model: from image_captions import H2OImageCaptionLoader - caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model() + caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu, gpu_id=caption_gpu_id).load_model() else: caption_loader = 'gpu' if n_gpus > 0 and caption_gpu else 'cpu' else: @@ -1256,8 +1394,13 @@ def main( hf_embedding_model = dict(name=hf_embedding_model, model=get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model, preload=True)) + if enable_doctr or enable_pdf_ocr in [True, 'auto', 'on']: - doctr_loader = 'gpu' if n_gpus > 0 and doctr_gpu else 'cpu' + if pre_load_caption_model: + from image_doctr import H2OOCRLoader + doctr_loader = H2OOCRLoader(layout_aware=True, gpu_id=doctr_gpu_id) + else: + doctr_loader = 'gpu' if n_gpus > 0 and caption_gpu else 'cpu' else: doctr_loader = False @@ -1282,11 +1425,15 @@ def get_config(base_model, with init_empty_weights(): from transformers import AutoConfig try: - config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token, + if rope_scaling: + rope_kwargs = dict(rope_scaling=rope_scaling) + else: + rope_kwargs = {} + config = AutoConfig.from_pretrained(base_model, token=use_auth_token, trust_remote_code=trust_remote_code, offload_folder=offload_folder, revision=revision, - rope_scaling=rope_scaling if rope_scaling else None) + **rope_kwargs) except OSError as e: if raise_exception: raise 
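The get_config() hunk above only forwards rope_scaling to AutoConfig.from_pretrained when the user actually supplied one (so models that reject the kwarg are untouched), switches to the token= argument, and later scales max_seq_len by the RoPE factor. A condensed, hedged sketch of that pattern follows, with an example model name and factor rather than project defaults:

```python
# Condensed sketch of the conditional rope_scaling handling; the helper name and
# example arguments are illustrative, not h2oGPT's API.
from transformers import AutoConfig


def config_with_rope(base_model: str, rope_scaling: dict = None, token=None):
    rope_kwargs = dict(rope_scaling=rope_scaling) if rope_scaling else {}
    config = AutoConfig.from_pretrained(base_model,
                                        token=token,  # newer transformers arg replacing use_auth_token
                                        **rope_kwargs)
    max_seq_len = getattr(config, 'max_position_embeddings', 2048)
    # listen to the model config if the user passed nothing, as the diff does
    rope_scaling = rope_scaling or getattr(config, 'rope_scaling', None)
    if rope_scaling:
        if rope_scaling.get('factor'):           # HF transformers style
            max_seq_len = int(max_seq_len * rope_scaling['factor'])
        elif rope_scaling.get('alpha_value'):    # exllama style
            max_seq_len = int(max_seq_len * rope_scaling['alpha_value'])
    return config, max_seq_len


# e.g. config_with_rope('lmsys/vicuna-13b-v1.5-16k', rope_scaling={'type': 'linear', 'factor': 4})
```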
@@ -1328,6 +1475,9 @@ def get_config(base_model, else: if hasattr(config, 'max_seq_len'): max_seq_len = int(config.max_seq_len) + # Note https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json has below, but here just want base size before rope + # elif hasattr(config, 'max_sequence_length'): + # max_seq_len = int(config.max_sequence_length) elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int): # help automatically limit inputs to generate max_seq_len = config.max_position_embeddings @@ -1345,6 +1495,10 @@ def get_config(base_model, # raise RuntimeError("Could not determine max_seq_len," # " please pass --max_seq_len and set to some value, e.g. 2048.") + # listen to model if sets this and user passed nothing + if not rope_scaling and hasattr(config, 'rope_scaling'): + rope_scaling = config.rope_scaling + if rope_scaling: if rope_scaling.get('factor'): # HF transformers @@ -1353,13 +1507,16 @@ def get_config(base_model, # exllama # Note: exllama's own tokenizer has this set correctly in loaders.py, this config will be unused max_seq_len *= rope_scaling.get('alpha_value') - print("Automatically setting max_seq_len=%d for RoPE scaling" % max_seq_len, flush=True) + max_seq_len = int(max_seq_len) + print("Automatically setting max_seq_len=%d for RoPE scaling for %s" % (max_seq_len, base_model), + flush=True) return config, model, max_seq_len def get_non_lora_model(base_model, model_loader, load_half, load_gptq, + load_awq, load_exllama, use_safetensors, revision, @@ -1418,8 +1575,6 @@ def get_non_lora_model(base_model, model_loader, load_half, if load_exllama: model = model_loader elif load_gptq: - if 'Llama-2-70B-chat-GPTQ' in base_model: - model_kwargs.update(dict(inject_fused_attention=False)) model_kwargs.pop('torch_dtype', None) model_kwargs.pop('device_map') model = model_loader( @@ -1427,6 +1582,23 @@ def get_non_lora_model(base_model, model_loader, load_half, model_basename=load_gptq, **model_kwargs, ) + elif load_awq: + allowed_dict = dict(max_new_tokens=None, + trust_remote_code=True, fuse_layers=True, + batch_size=1, safetensors=False, + max_memory=None, offload_folder=None) + for k in model_kwargs.copy(): + if k not in allowed_dict: + model_kwargs.pop(k) + if load_awq.endswith('.pt'): + args = tuple([base_model, load_awq]) + else: + args = tuple([base_model]) + model = model_loader( + *args, + safetensors=use_safetensors, + **model_kwargs, + ) elif load_in_8bit or load_in_4bit or not load_half: model = model_loader( base_model, @@ -1434,7 +1606,6 @@ def get_non_lora_model(base_model, model_loader, load_half, **model_kwargs, ) else: - model = model_loader( base_model, config=config, @@ -1456,7 +1627,7 @@ def get_client_from_inference_server(inference_server, base_model=None, raise_co print("GR Client Begin: %s %s" % (inference_server, base_model), flush=True) # first do sanity check if alive, else gradio client takes too long by default requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30'))) - gr_client = GradioClient(inference_server) + gr_client = GradioClient(inference_server).setup() print("GR Client End: %s" % inference_server, flush=True) except (OSError, ValueError) as e: # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF @@ -1497,6 +1668,7 @@ def get_model( low_bit_mode: int = 1, load_half: bool = True, load_gptq: str = '', + load_awq: str = '', load_exllama: bool = False, use_safetensors: bool = False, revision: str = None, @@ -1518,6 +1690,12 
@@ def get_model( max_seq_len: int = None, compile_model: bool = True, llamacpp_dict=None, + exllama_dict=None, + gptq_dict=None, + attention_sinks=None, + sink_dict=None, + truncation_generation=None, + hf_model_dict={}, verbose: bool = False, ): @@ -1528,6 +1706,7 @@ def get_model( :param low_bit_mode: See gen.py :param load_half: load model in 16-bit :param load_gptq: GPTQ model_basename + :param load_awq: AWQ model_basename :param load_exllama: whether to use exllama :param use_safetensors: use safetensors file :param revision: @@ -1551,6 +1730,12 @@ def get_model( :param max_seq_len: if set, use as max_seq_len for model :param compile_model: whether to compile torch model :param llamacpp_dict: dict of llama.cpp and GPT4All model options + :param exllama_dict: dict of exllama options + :param gptq_dict: dict of AutoGPTQ options + :param attention_sinks: whether to use attention_sinks package + :param sink_dict: dict of attention sinks options + :param truncation_generation: whether to truncate generation in torch case to max_seq_len + :param hf_model_dict :param verbose: :return: """ @@ -1586,13 +1771,18 @@ def get_model( '') model_loader, tokenizer_loader, conditional_type = ( get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type, - load_gptq=load_gptq, load_exllama=load_exllama, config=config, + load_gptq=load_gptq, load_awq=load_awq, load_exllama=load_exllama, + config=config, rope_scaling=rope_scaling, max_seq_len=max_seq_len, - model_name_exllama_if_no_config=model_name_exllama_if_no_config)) + model_name_exllama_if_no_config=model_name_exllama_if_no_config, + exllama_dict=exllama_dict, gptq_dict=gptq_dict, + attention_sinks=attention_sinks, sink_dict=sink_dict, + truncation_generation=truncation_generation, + hf_model_dict=hf_model_dict)) tokenizer_kwargs = dict(local_files_only=local_files_only, resume_download=resume_download, - use_auth_token=use_auth_token, + token=use_auth_token, trust_remote_code=trust_remote_code, offload_folder=offload_folder, revision=revision, @@ -1615,7 +1805,7 @@ def get_model( set_model_max_len(max_seq_len, tokenizer, verbose=False) # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get: # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. 
Given: 2233 - tokenizer.model_max_length = tokenizer.model_max_length - 50 + tokenizer.model_max_length = int(tokenizer.model_max_length - 50) else: tokenizer = None @@ -1662,7 +1852,8 @@ def get_model( # include small token cushion if inference_server.startswith('openai') or tokenizer is None: # don't use fake (tiktoken) tokenizer for vLLM//replicate if know actual model with actual tokenizer - tokenizer = FakeTokenizer(model_max_length=max_seq_len - 50) + assert max_seq_len is not None, "Please pass --max_seq_len= for unknown or non-HF model %s" % base_model + tokenizer = FakeTokenizer(model_max_length=max_seq_len - 50, is_openai=True) return inference_server, tokenizer, inference_server assert not inference_server, "Malformed inference_server=%s" % inference_server if base_model in non_hf_types: @@ -1680,6 +1871,7 @@ def get_model( low_bit_mode=low_bit_mode, load_half=load_half, load_gptq=load_gptq, + load_awq=load_awq, use_safetensors=use_safetensors, revision=revision, use_gpu_id=use_gpu_id, @@ -1700,6 +1892,11 @@ def get_model( llama_type=llama_type, config_kwargs=config_kwargs, tokenizer_kwargs=tokenizer_kwargs, + gptq_dict=gptq_dict, + attention_sinks=attention_sinks, + sink_dict=sink_dict, + truncation_generation=truncation_generation, + hf_model_dict=hf_model_dict, verbose=verbose) @@ -1709,6 +1906,7 @@ def get_hf_model(load_8bit: bool = False, low_bit_mode: int = 1, load_half: bool = True, load_gptq: str = '', + load_awq: str = '', use_safetensors: bool = False, revision: str = None, use_gpu_id: bool = True, @@ -1729,6 +1927,11 @@ def get_hf_model(load_8bit: bool = False, llama_type: bool = False, config_kwargs=None, tokenizer_kwargs=None, + gptq_dict=None, + attention_sinks=None, + sink_dict=None, + truncation_generation=None, + hf_model_dict=None, verbose: bool = False, ): @@ -1736,6 +1939,7 @@ def get_hf_model(load_8bit: bool = False, assert tokenizer_kwargs is not None load_exllama = False # Never should be in HF code for exllama + exllama_dict = {} if lora_weights is not None and lora_weights.strip(): if verbose: @@ -1753,7 +1957,11 @@ def get_hf_model(load_8bit: bool = False, model_loader, tokenizer_loader, conditional_type = ( get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type, - load_gptq=load_gptq, load_exllama=load_exllama)) + load_gptq=load_gptq, load_awq=load_awq, load_exllama=load_exllama, + exllama_dict=exllama_dict, gptq_dict=gptq_dict, + attention_sinks=attention_sinks, sink_dict=sink_dict, + truncation_generation=truncation_generation, + hf_model_dict=hf_model_dict)) config, _, max_seq_len = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs) @@ -1777,7 +1985,7 @@ def get_hf_model(load_8bit: bool = False, model_kwargs = dict(local_files_only=local_files_only, torch_dtype=torch.float16 if device == 'cuda' else torch.float32, resume_download=resume_download, - use_auth_token=use_auth_token, + token=use_auth_token, trust_remote_code=trust_remote_code, offload_folder=offload_folder, revision=revision, @@ -1832,13 +2040,14 @@ def get_hf_model(load_8bit: bool = False, if not lora_weights: # torch.device context uses twice memory for AutoGPTQ - context = NullContext if load_gptq else torch.device + context = NullContext if (load_gptq or load_awq) else torch.device with context(device): if use_gpu_id: config, model, max_seq_len = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs) - model = get_non_lora_model(base_model, model_loader, load_half, load_gptq, + model = 
get_non_lora_model(base_model, model_loader, load_half, + load_gptq, load_awq, load_exllama, use_safetensors, revision, @@ -1847,8 +2056,10 @@ def get_hf_model(load_8bit: bool = False, gpu_id=gpu_id, ) else: + model_kwargs['use_safetensors'] = use_safetensors + model_kwargs['revision'] = revision config, _, max_seq_len = get_config(base_model, **config_kwargs) - if load_half and not (load_8bit or load_4bit or load_gptq): + if load_half and not (load_8bit or load_4bit or load_gptq or load_awq): model = model_loader( base_model, config=config, @@ -1856,10 +2067,36 @@ def get_hf_model(load_8bit: bool = False, if not getattr(model, "is_quantized", False): model = model.half() else: - model = model_loader( - base_model, - config=config, - **model_kwargs) + if load_gptq: + model_kwargs.pop('torch_dtype', None) + model_kwargs.pop('device_map') + model = model_loader( + model_name_or_path=base_model, + model_basename=load_gptq, + **model_kwargs, + ) + elif load_awq: + allowed_dict = dict(max_new_tokens=None, + trust_remote_code=True, fuse_layers=True, + batch_size=1, safetensors=False, + max_memory=None, offload_folder=None) + for k in model_kwargs.copy(): + if k not in allowed_dict: + model_kwargs.pop(k) + if load_awq.endswith('.pt'): + args = tuple([base_model, load_awq]) + else: + args = tuple([base_model]) + model = model_loader( + *args, + safetensors=use_safetensors, + **model_kwargs, + ) + else: + model = model_loader( + base_model, + config=config, + **model_kwargs) elif load_8bit or load_4bit: config, _, max_seq_len = get_config(base_model, **config_kwargs) model = model_loader( @@ -1874,7 +2111,7 @@ def get_hf_model(load_8bit: bool = False, torch_dtype=torch.float16 if device == 'cuda' else torch.float32, local_files_only=local_files_only, resume_download=resume_download, - use_auth_token=use_auth_token, + token=use_auth_token, trust_remote_code=trust_remote_code, offload_folder=offload_folder, rope_scaling=rope_scaling, @@ -1896,13 +2133,13 @@ def get_hf_model(load_8bit: bool = False, torch_dtype=torch.float16 if device == 'cuda' else torch.float32, local_files_only=local_files_only, resume_download=resume_download, - use_auth_token=use_auth_token, + token=use_auth_token, trust_remote_code=trust_remote_code, offload_folder=offload_folder, rope_scaling=rope_scaling, device_map="auto", ) - if load_half and not load_gptq: + if load_half and not (load_gptq or load_awq): if not getattr(model, "is_quantized", False): model = model.half() @@ -1964,6 +2201,7 @@ def get_score_model(score_model: str = None, low_bit_mode=1, load_half: bool = True, load_gptq: str = '', + load_awq: str = '', load_exllama: bool = False, use_gpu_id: bool = True, base_model: str = '', @@ -1982,6 +2220,12 @@ def get_score_model(score_model: str = None, rope_scaling: dict = None, compile_model: bool = True, llamacpp_dict: typing.Dict = None, + exllama_dict: typing.Dict = None, + gptq_dict: typing.Dict = None, + attention_sinks: bool = False, + sink_dict: typing.Dict = None, + truncation_generation: bool = False, + hf_model_dict: typing.Dict = None, verbose: bool = False, ): @@ -1991,6 +2235,7 @@ def get_score_model(score_model: str = None, low_bit_mode = 1 load_half = False load_gptq = '' + load_awq = '' load_exllama = False use_safetensors = False revision = None @@ -2000,8 +2245,15 @@ def get_score_model(score_model: str = None, inference_server = '' llama_type = False max_seq_len = None + rope_scaling = {} compile_model = False llamacpp_dict = {} + exllama_dict = {} + gptq_dict = {} + attention_sinks = False + 
sink_dict = {} + truncation_generation = False + hf_model_dict = {} smodel, stokenizer, sdevice = get_model(reward_type=True, **get_kwargs(get_model, exclude_names=['reward_type'], **locals())) else: @@ -2010,7 +2262,7 @@ def get_score_model(score_model: str = None, def evaluate_fake(*args, **kwargs): - yield dict(response=invalid_key_msg, sources='') + yield dict(response=invalid_key_msg, sources='', save_dict=dict(), llm_answers={}) return @@ -2029,6 +2281,7 @@ def evaluate( temperature, top_p, top_k, + penalty_alpha, num_beams, max_new_tokens, min_new_tokens, @@ -2066,6 +2319,11 @@ def evaluate( text_context_list, docs_ordering_type, min_max_new_tokens, + max_input_tokens, + docs_token_handling, + docs_joiner, + hyde_level, + hyde_template, # END NOTE: Examples must have same order of parameters captions_model=None, @@ -2105,6 +2363,7 @@ def evaluate( top_k_docs_max_show=None, show_link_in_sources=None, verbose=False, + gradio=True, cli=False, use_cache=None, auto_reduce_chunks=None, @@ -2113,6 +2372,14 @@ def evaluate( model_lock=None, force_langchain_evaluate=None, model_state_none=None, + llamacpp_dict=None, + exllama_dict=None, + gptq_dict=None, + attention_sinks=None, + sink_dict=None, + truncation_generation=None, + hf_model_dict=None, + load_exllama=None, answer_with_sources=None, append_sources_to_answer=None, @@ -2236,6 +2503,18 @@ def evaluate( instruction = instruction_nochat iinput = iinput_nochat + # avoid instruction in chat_conversation itself, since always used as additional context to prompt in what follows + if isinstance(chat_conversation, list) and \ + len(chat_conversation) > 0 and \ + len(chat_conversation[-1]) == 2 and \ + chat_conversation[-1][0] == instruction and \ + chat_conversation[-1][1] in [None, '']: + chat_conversation = chat_conversation[:-1] + if not add_chat_history_to_context: + # make it easy to ignore without needing add_chat_history_to_context + # some langchain or unit test may need to then handle more general case + chat_conversation = [] + # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice model_lower = base_model.lower() if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': @@ -2250,18 +2529,30 @@ def evaluate( # limits are chosen similar to gradio_runner.py sliders/numbers top_p = min(max(1e-3, top_p), 1.0 - 1e-3) top_k = min(max(1, int(top_k)), 100) + penalty_alpha = min(2.0, max(0.0, penalty_alpha)) temperature = min(max(0.01, temperature), 2.0) # FIXME: https://github.com/h2oai/h2ogpt/issues/106 num_beams = 1 if stream_output else num_beams # See max_beams in gradio_runner + if model_lower == 'distilgpt2': + # always truncate for certain models that totally fail otherwise + truncation_generation = True max_max_new_tokens = get_max_max_new_tokens(chosen_model_state, memory_restriction_level=memory_restriction_level, max_new_tokens=max_new_tokens, - max_max_new_tokens=max_max_new_tokens) + attention_sinks=attention_sinks, + max_max_new_tokens=max_max_new_tokens, + truncation_generation=truncation_generation) if min_max_new_tokens is None: # default for nochat api min_max_new_tokens = 256 + if max_input_tokens is None: + max_input_tokens = -1 if docs_ordering_type is None: - docs_ordering_type = 'reverse_ucurve_sort' + docs_ordering_type = docs_ordering_types_default + if docs_token_handling is None: + docs_token_handling = docs_token_handling_default + if docs_joiner is None: + docs_joiner = docs_joiner_default model_max_length = 
get_model_max_length(chosen_model_state) max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens) min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens) @@ -2279,6 +2570,12 @@ def evaluate( if not context: context = '' + # NOTE!!!!!!!!!! Choice of developer. But only possible to force stream if num_beams=1 + # stream if can, so can control task iteration and time of iteration + # not required, but helpful for max_time control etc. + stream_output0 = stream_output + stream_output = gradio and num_beams == 1 + # get prompter prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output, system_prompt=system_prompt) @@ -2335,8 +2632,9 @@ def evaluate( gen_hyper_langchain = dict(do_sample=do_sample, temperature=temperature, repetition_penalty=repetition_penalty, - top_k=top_k, top_p=top_p, + top_k=top_k, + penalty_alpha=penalty_alpha, num_beams=num_beams, min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens, @@ -2360,6 +2658,7 @@ def evaluate( prompt_basic = prompter.generate_prompt(data_point, context_from_history=False) prompt = prompt_basic num_prompt_tokens = 0 + llm_answers = {} for r in run_qa_db( inference_server=inference_server, model_name=base_model, model=model, tokenizer=tokenizer, @@ -2396,6 +2695,7 @@ def evaluate( query=instruction, iinput=iinput, context=context, + stream_output0=stream_output0, stream_output=stream_output, chunk=chunk, chunk_size=chunk_size, @@ -2420,6 +2720,11 @@ def evaluate( h2ogpt_key=h2ogpt_key, docs_ordering_type=docs_ordering_type, min_max_new_tokens=min_max_new_tokens, + max_input_tokens=max_input_tokens, + docs_token_handling=docs_token_handling, + docs_joiner=docs_joiner, + hyde_level=hyde_level, + hyde_template=hyde_template, **gen_hyper_langchain, @@ -2430,6 +2735,13 @@ def evaluate( sanitize_bot_response=sanitize_bot_response, lora_weights=lora_weights, + llamacpp_dict=llamacpp_dict, + exllama_dict=exllama_dict, + gptq_dict=gptq_dict, + attention_sinks=attention_sinks, + sink_dict=sink_dict, + truncation_generation=truncation_generation, + hf_model_dict=hf_model_dict, auto_reduce_chunks=auto_reduce_chunks, max_chunks=max_chunks, @@ -2441,7 +2753,8 @@ def evaluate( sources = r['sources'] prompt = r['prompt'] num_prompt_tokens = r['num_prompt_tokens'] - yield dict(response=response, sources=sources, save_dict=dict()) + llm_answers = r['llm_answers'] + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers=llm_answers) if save_dir: # estimate using tiktoken extra_dict = gen_hyper_langchain.copy() @@ -2466,7 +2779,7 @@ def evaluate( output=response, base_model=base_model, save_dir=save_dir, where_from='run_qa_db', extra_dict=extra_dict) - yield dict(response=response, sources=sources, save_dict=save_dict) + yield dict(response=response, sources=sources, save_dict=save_dict, llm_answers=llm_answers) if verbose: print( 'Post-Generate Langchain: %s decoded_output: %s' % @@ -2481,10 +2794,14 @@ def evaluate( # NOT LANGCHAIN PATH, raw LLM # restrict instruction + , typically what has large input + from gradio_utils.grclient import GradioClient + gradio_server = inference_server.startswith('http') and isinstance(model, GradioClient) + prompt, \ instruction, iinput, context, \ num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \ - chat_index, top_k_docs_trial, one_doc_size = \ + chat_index, external_handle_chat_conversation, \ + top_k_docs_trial, one_doc_size, truncation_generation = \ get_limited_prompt(instruction, iinput, tokenizer, @@ 
-2503,6 +2820,9 @@ def evaluate( langchain_mode=langchain_mode, add_chat_history_to_context=add_chat_history_to_context, min_max_new_tokens=min_max_new_tokens, + max_input_tokens=max_input_tokens, + truncation_generation=truncation_generation, + gradio_server=gradio_server, ) if inference_server.startswith('vllm') or \ @@ -2511,7 +2831,7 @@ def evaluate( if inference_server.startswith('vllm') or inference_server.startswith('openai'): assert not inference_server.startswith('openai_azure_chat'), "Not for Azure, use langchain path" assert not inference_server.startswith('openai_azure'), "Not for Azure, use langchain path" - openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server) + openai, inf_type, deployment_name, base_url, api_version, api_key = set_openai(inference_server) where_from = inf_type terminate_response = prompter.terminate_response or [] @@ -2541,19 +2861,22 @@ def evaluate( text = responses['choices'][0]['text'] response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) else: collected_events = [] + tgen0 = time.time() for event in responses: collected_events.append(event) # save the event response event_text = event['choices'][0]['text'] # extract the text text += event_text # append the text response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) + if time.time() - tgen0 > max_time: + if verbose: + print("Took too long for OpenAI or VLLM: %s" % (time.time() - tgen0), flush=True) + break elif inf_type == 'vllm_chat' or inference_server == 'openai_chat': - if inf_type == 'vllm_chat': - raise NotImplementedError('%s not supported by vLLM' % inf_type) if system_prompt in [None, 'None', 'auto']: openai_system_prompt = "You are a helpful assistant."
else: @@ -2561,7 +2884,16 @@ def evaluate( messages0 = [] if openai_system_prompt: messages0.append({"role": "system", "content": openai_system_prompt}) - messages0.append({'role': 'user', 'content': prompt}) + if chat_conversation and add_chat_history_to_context: + assert external_handle_chat_conversation, "Should be handling only externally" + # chat_index handles token counting issues + for message1 in chat_conversation[chat_index:]: + if len(message1) == 2: + messages0.append( + {'role': 'user', 'content': message1[0] if message1[0] is not None else ''}) + messages0.append( + {'role': 'assistant', 'content': message1[1] if message1[1] is not None else ''}) + messages0.append({'role': 'user', 'content': prompt if prompt is not None else ''}) responses = openai.ChatCompletion.create( model=base_model, messages=messages0, @@ -2575,23 +2907,27 @@ def evaluate( text = responses["choices"][0]["message"]["content"] response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) else: + tgen0 = time.time() for chunk in responses: delta = chunk["choices"][0]["delta"] if 'content' in delta: text += delta['content'] response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) + if time.time() - tgen0 > max_time: + if verbose: + print("Took too long for OpenAI or VLLM Chat: %s" % (time.time() - tgen0), flush=True) + break else: raise RuntimeError("No such OpenAI mode: %s" % inference_server) elif inference_server.startswith('http'): inference_server, headers = get_hf_server(inference_server) - from gradio_utils.grclient import GradioClient from text_generation import Client as HFClient if isinstance(model, GradioClient): - gr_client = model + gr_client = model.clone() hf_client = None elif isinstance(model, HFClient): gr_client = None @@ -2617,6 +2953,7 @@ def evaluate( gen_server_kwargs = dict(temperature=temperature, top_p=top_p, top_k=top_k, + penalty_alpha=penalty_alpha, num_beams=num_beams, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, @@ -2685,13 +3022,19 @@ def evaluate( visible_models=visible_models, h2ogpt_key=h2ogpt_key, add_search_to_context=client_add_search_to_context, - docs_ordering_type=None, + docs_ordering_type=docs_ordering_type, min_max_new_tokens=min_max_new_tokens, + max_input_tokens=max_input_tokens, + docs_token_handling=docs_token_handling, + docs_joiner=docs_joiner, + hyde_level=hyde_level, + hyde_template=hyde_template, ) api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing response = '' text = '' sources = '' + strex = '' if not stream_output: res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) res_dict = ast.literal_eval(res) @@ -2699,15 +3042,17 @@ def evaluate( sources = res_dict['sources'] response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) else: + from gradio_utils.grclient import check_job job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) - res_dict = dict(response=text, 
sources=sources, save_dict=dict()) + res_dict = dict(response=text, sources=sources, save_dict=dict(), llm_answers={}) text0 = '' + tgen0 = time.time() while not job.done(): if job.communicator.job.latest_status.code.name == 'FINISHED': break - e = job.future._exception + e = check_job(job, timeout=0, raise_exception=False) if e is not None: break outputs_list = job.communicator.job.outputs @@ -2725,21 +3070,34 @@ def evaluate( sanitize_bot_response=sanitize_bot_response) text_chunk = response[len(text0):] if not text_chunk: + # just need some sleep for threads to switch + time.sleep(0.001) continue # save old text0 = response - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) + if time.time() - tgen0 > max_time: + if verbose: + print("Took too long for Gradio: %s" % (time.time() - tgen0), flush=True) + break time.sleep(0.01) # ensure get last output to avoid race res_all = job.outputs() if len(res_all) > 0: + # don't raise unless nochat API for now + e = check_job(job, timeout=0.02, raise_exception=not chat) + if e is not None: + strex = ''.join(traceback.format_tb(e.__traceback__)) + res = res_all[-1] res_dict = ast.literal_eval(res) text = res_dict['response'] sources = res_dict['sources'] else: + # if got no answer at all, probably something bad, always raise exception + # UI will still put exception in Chat History under chat exceptions + e = check_job(job, timeout=0.3, raise_exception=True) # go with old text if last call didn't work - e = job.future._exception if e is not None: stre = str(e) strex = ''.join(traceback.format_tb(e.__traceback__)) @@ -2757,7 +3115,7 @@ def evaluate( prompt_and_text = prompt + text response = prompter.get_response(prompt_and_text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), error=strex, llm_answers={}) elif hf_client: # HF inference server needs control over input tokens where_from = "hf_client" @@ -2793,8 +3151,9 @@ def evaluate( text = hf_client.generate(prompt, **gen_server_kwargs).generated_text response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) else: + tgen0 = time.time() text = "" for responses in hf_client.generate_stream(prompt, **gen_server_kwargs): if not responses.token.special: @@ -2804,7 +3163,11 @@ def evaluate( response = prompter.get_response(prompt + text, prompt=prompt, sanitize_bot_response=sanitize_bot_response) sources = '' - yield dict(response=response, sources=sources, save_dict=dict()) + yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) + if time.time() - tgen0 > max_time: + if verbose: + print("Took too long for TGI: %s" % (time.time() - tgen0), flush=True) + break else: raise RuntimeError("Failed to get client: %s" % inference_server) else: @@ -2820,7 +3183,7 @@ def evaluate( )) save_dict = dict(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir, where_from=where_from, extra_dict=extra_dict) - yield dict(response=response, sources=sources, save_dict=save_dict) + yield dict(response=response, sources=sources, save_dict=save_dict, llm_answers={}) return else: assert not inference_server, "inference_server=%s not supported" % 
inference_server @@ -2833,7 +3196,8 @@ def evaluate( raise RuntimeError("No such task type %s" % tokenizer) # NOTE: uses max_length only sources = '' - yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources=sources, save_dict=dict()) + yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources=sources, save_dict=dict(), + llm_answers={}) if 'mbart-' in base_model.lower(): assert src_lang is not None @@ -2841,15 +3205,20 @@ def evaluate( stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model, model_max_length=model_max_length, - prompter=prompter) + prompter=prompter, + truncation_generation=truncation_generation) inputs = tokenizer(prompt, return_tensors="pt") if debug and len(inputs["input_ids"]) > 0: print('input_ids length', len(inputs["input_ids"][0]), flush=True) input_ids = inputs["input_ids"].to(device) # CRITICAL LIMIT else will fail - max_max_tokens = tokenizer.model_max_length - max_input_tokens = max(0, int(max_max_tokens - min_new_tokens)) + max_max_tokens = int(tokenizer.model_max_length) + max_input_tokens_default = max(0, int(max_max_tokens - min_new_tokens)) + if max_input_tokens >= 0: + max_input_tokens = min(max_input_tokens_default, max_input_tokens) + else: + max_input_tokens = max_input_tokens_default # NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py assert isinstance(max_input_tokens, int), "Bad type for max_input_tokens=%s %s" % ( max_input_tokens, type(max_input_tokens)) @@ -2857,6 +3226,9 @@ def evaluate( # required for falcon if multiple threads or asyncio accesses to model during generation if use_cache is None: use_cache = False if 'falcon' in base_model else True + if attention_sinks: + assert use_cache, "attention sinks requires use_cache=True" + bad_word_ids = [tokenizer.eos_token_id] gen_config_kwargs = dict(num_beams=num_beams, do_sample=do_sample, repetition_penalty=float(repetition_penalty), @@ -2864,11 +3236,14 @@ def evaluate( renormalize_logits=True, remove_invalid_values=True, use_cache=use_cache, + max_new_tokens=max_new_tokens, # unsure if required here ) if do_sample: gen_config_kwargs.update(dict(temperature=float(temperature), top_p=float(top_p), top_k=top_k)) + if penalty_alpha > 0: + gen_config_kwargs.update(dict(penalty_alpha=penalty_alpha)) if True: # unclear impact, some odd things going on inside # leads to: @@ -2944,9 +3319,10 @@ def evaluate( bucket = queue.Queue() thread = EThread(target=target, streamer=streamer, bucket=bucket) thread.start() - ret = dict(response='', sources='', save_dict=dict()) + ret = dict(response='', sources='', save_dict=dict(), llm_answers={}) outputs = "" sources = '' + tgen0 = time.time() try: for new_text in streamer: if bucket.qsize() > 0 or thread.exc: @@ -2955,11 +3331,15 @@ def evaluate( response = prompter.get_response(outputs, prompt=None, only_new_text=True, sanitize_bot_response=sanitize_bot_response) - ret = dict(response=response, sources=sources, save_dict=dict()) + ret = dict(response=response, sources=sources, save_dict=dict(), llm_answers={}) if stream_output: yield ret - if not stream_output: - yield ret + if time.time() - tgen0 > max_time: + if verbose: + print("Took too long for Torch: %s" % (time.time() - tgen0), flush=True) + break + # yield if anything left over as can happen (FIXME: Understand better) + yield ret except BaseException: # if any exception, raise that exception if was from thread, first if thread.exc: @@ -2990,7 +3370,7 @@ def evaluate( 
@@ -2990,7 +3370,7 @@ def evaluate(
                 response = prompter.get_response(outputs, prompt=None,
                                                  only_new_text=True,
                                                  sanitize_bot_response=sanitize_bot_response)
-                yield dict(response=response, sources=sources, save_dict=dict())
+                yield dict(response=response, sources=sources, save_dict=dict(), llm_answers={})
             if outputs and len(outputs) >= 1:
                 decoded_output = prompt + outputs[0]
                 if save_dir and decoded_output:
@@ -3003,7 +3383,7 @@ def evaluate(
                     save_dict = dict(prompt=prompt, output=decoded_output, base_model=base_model,
                                      save_dir=save_dir,
                                      where_from="evaluate_%s" % str(stream_output),
                                      extra_dict=extra_dict)
-                    yield dict(response=response, sources=sources, save_dict=save_dict)
+                    yield dict(response=response, sources=sources, save_dict=save_dict, llm_answers={})
             if verbose:
                 print('Post-Generate: %s decoded_output: %s' % (
                     str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
@@ -3014,7 +3394,7 @@ state_names = input_args_list.copy()  # doesn't have to be the same, but state_n
 inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
-def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
+def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048, min_max_new_tokens=256):
     # help to avoid errors like:
     # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
     # RuntimeError: expected scalar type Half but found Float
@@ -3023,7 +3403,7 @@ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=20
         max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
     else:
         # at least give room for 1 paragraph output
-        max_length_tokenize = model_max_length - 256
+        max_length_tokenize = model_max_length - min_max_new_tokens
     cutoff_len = max_length_tokenize * 4  # if reaches limit, then can't generate new tokens
     output_smallest = 30 * 4
     max_prompt_length = cutoff_len - output_smallest
@@ -3165,7 +3545,7 @@ def get_generate_params(model_lower,
                         system_prompt,
                         pre_prompt_query, prompt_query,
                         pre_prompt_summary, prompt_summary,
-                        temperature, top_p, top_k, num_beams,
+                        temperature, top_p, top_k, penalty_alpha, num_beams,
                         max_new_tokens, min_new_tokens, early_stopping, max_time,
                         repetition_penalty, num_return_sequences,
                         do_sample,
@@ -3176,6 +3556,11 @@ def get_generate_params(model_lower,
                         jq_schema,
                         docs_ordering_type,
                         min_max_new_tokens,
+                        max_input_tokens,
+                        docs_token_handling,
+                        docs_joiner,
+                        hyde_level,
+                        hyde_template,
                         verbose,
                         ):
     use_defaults = False
@@ -3190,7 +3575,7 @@ def get_generate_params(model_lower,
     min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
     early_stopping = early_stopping if early_stopping is not None else False
-    max_time_defaults = 60 * 3
+    max_time_defaults = 60 * 10
     max_time = max_time if max_time is not None else max_time_defaults
     if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
@@ -3270,6 +3655,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
         temperature = 1.0 if temperature is None else temperature
         top_p = 1.0 if top_p is None else top_p
         top_k = 40 if top_k is None else top_k
+        penalty_alpha = 0 if penalty_alpha is None else penalty_alpha
         num_beams = num_beams or 1
         max_new_tokens = max_new_tokens or 512
         repetition_penalty = repetition_penalty or 1.07
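penalty_alpha is threaded through get_generate_params above and defaults to 0, which leaves contrastive search disabled; any value > 0 together with a small top_k switches Hugging Face generation over to contrastive search. A hedged, standalone sketch; the model and the 0.6/4 values are illustrative, not h2oGPT defaults:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

gen_config = GenerationConfig(do_sample=False,     # contrastive search is deterministic
                              penalty_alpha=0.6,   # > 0 enables the degeneration penalty
                              top_k=4,             # candidate pool size for contrastive search
                              max_new_tokens=32)
inputs = tok("h2oGPT limits prompts so that", return_tensors="pt")
out = model.generate(**inputs, generation_config=gen_config)
print(tok.decode(out[0], skip_special_tokens=True))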
@@ -3279,6 +3665,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
         temperature = 0.1 if temperature is None else temperature
         top_p = 0.75 if top_p is None else top_p
         top_k = 40 if top_k is None else top_k
+        penalty_alpha = 0 if penalty_alpha is None else penalty_alpha
         num_beams = num_beams or 1
         max_new_tokens = max_new_tokens or 1024
         repetition_penalty = repetition_penalty or 1.07
@@ -3288,7 +3675,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
     params_list = ["", stream_output, prompt_type, prompt_dict,
-                   temperature, top_p, top_k, num_beams,
+                   temperature, top_p, top_k, penalty_alpha, num_beams,
                    max_new_tokens, min_new_tokens, early_stopping, max_time,
                    repetition_penalty, num_return_sequences,
                    do_sample]
@@ -3357,6 +3744,11 @@ y = np.random.randint(0, 1, 100)
                None,
                docs_ordering_type,
                min_max_new_tokens,
+               max_input_tokens,
+               docs_token_handling,
+               docs_joiner,
+               hyde_level,
+               hyde_template,
                ]
     # adjust examples if non-chat mode
     if not chat:
@@ -3382,7 +3774,7 @@ y = np.random.randint(0, 1, 100)
     return placeholder_instruction, placeholder_input, \
         stream_output, show_examples, \
         prompt_type, prompt_dict, \
-        temperature, top_p, top_k, num_beams, \
+        temperature, top_p, top_k, penalty_alpha, num_beams, \
         max_new_tokens, min_new_tokens, early_stopping, max_time, \
         repetition_penalty, num_return_sequences, \
         do_sample, \
@@ -3460,9 +3852,20 @@ def get_model_max_length(model_state):
         return 2048
+def get_model_max_length_from_tokenizer(tokenizer):
+    if hasattr(tokenizer, 'model_max_length'):
+        return int(tokenizer.model_max_length)
+    else:
+        return 2048
+
+
 def get_max_max_new_tokens(model_state, **kwargs):
-    if not isinstance(model_state['tokenizer'], (str, type(None))):
-        max_max_new_tokens = model_state['tokenizer'].model_max_length
+    if not isinstance(model_state['tokenizer'], (str, type(None))) or not kwargs.get('truncation_generation', False):
+        if hasattr(model_state['tokenizer'], 'model_max_length'):
+            max_max_new_tokens = model_state['tokenizer'].model_max_length
+        else:
+            # e.g. fast up, no model
+            max_max_new_tokens = None
     else:
         max_max_new_tokens = None
@@ -3482,14 +3885,14 @@ def get_max_max_new_tokens(model_state, **kwargs):
 def get_minmax_top_k_docs(is_public):
+    label_top_k_docs = "Number of document chunks (query) or pages/parts (summarize)"
     if is_public:
         min_top_k_docs = 1
         max_top_k_docs = 8
-        label_top_k_docs = "Number of document chunks"
     else:
         min_top_k_docs = -1
         max_top_k_docs = 100
-        label_top_k_docs = "Number of document chunks (-1 = auto fill model context)"
+        label_top_k_docs = label_top_k_docs + " (-1 = auto fill model context, all pages/docs for summarize)"
     return min_top_k_docs, max_top_k_docs, label_top_k_docs
@@ -3517,7 +3920,8 @@ def history_to_context(history, langchain_mode=None,
                        add_chat_history_to_context=None,
                        prompt_type=None, prompt_dict=None, chat=None, model_max_length=None,
                        memory_restriction_level=None, keep_sources_in_context=None,
-                       system_prompt=None, chat_conversation=None):
+                       system_prompt=None, chat_conversation=None,
+                       min_max_new_tokens=256):
     """
     consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
     :param history:
@@ -3531,6 +3935,7 @@ def history_to_context(history, langchain_mode=None,
     :param keep_sources_in_context:
     :param system_prompt:
     :param chat_conversation:
+    :param min_max_new_tokens:
     :return:
     """
     history = merge_chat_conversation_history(chat_conversation, history)
@@ -3543,7 +3948,8 @@ def history_to_context(history, langchain_mode=None,
     # ensure output will be unique to models
     _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level,
-                                             for_context=True, model_max_length=model_max_length)
+                                             for_context=True, model_max_length=model_max_length,
+                                             min_max_new_tokens=min_max_new_tokens)
     context1 = ''
     if max_prompt_length is not None and add_chat_history_to_context:
         context1 = ''
@@ -3587,9 +3993,23 @@ def history_to_context(history, langchain_mode=None,
     return context1
+def get_relaxed_max_new_tokens(prompt, tokenizer=None, max_new_tokens=None, max_new_tokens0=None):
+    # check if can relax max_new_tokens for this specific prompt
+    if max_new_tokens0 is not None and \
+            hasattr(tokenizer, 'model_max_length') and \
+            isinstance(tokenizer.model_max_length, (float, int)):
+        max_new_tokens = int(tokenizer.model_max_length) - get_token_count(prompt, tokenizer)
+        if max_new_tokens is not None:
+            return min(max_new_tokens0, max_new_tokens)
+        else:
+            return max_new_tokens0
+    return max_new_tokens
+
+
 def get_limited_prompt(instruction,
                        iinput,
                        tokenizer,
+                       estimated_instruction=None,
                        prompter=None,
                        inference_server=None,
                        prompt_type=None, prompt_dict=None, chat=False, max_new_tokens=None,
@@ -3601,7 +4021,27 @@ def get_limited_prompt(instruction,
                        verbose=False,
                        doc_importance=0.5,
                        min_max_new_tokens=256,
+                       max_input_tokens=-1,
+                       truncation_generation=False,
+                       gradio_server=False,
                        ):
+    if gradio_server or not inference_server:
+        # can listen to truncation_generation
+        pass
+    else:
+        # these don't support allowing going beyond total context
+        truncation_generation = True
+
+    # for templates, use estimated for counting, but adjust instruction as output
+    if estimated_instruction is None:
+        estimated_instruction = instruction
+
+    if max_input_tokens >= 0:
+        # max_input_tokens is used at runtime (via client/UI) to control actual filling of context
+        max_input_tokens = min(model_max_length - min_max_new_tokens, max_input_tokens)
+    else:
+        max_input_tokens = model_max_length - min_max_new_tokens
+
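The clamp just above decides how many prompt tokens get_limited_prompt may spend: a non-negative max_input_tokens from the client/UI is capped by what the model context can hold after reserving min_max_new_tokens for output, and -1 means auto. A tiny restatement as a helper (the function name is ours, not from this diff):

def effective_max_input_tokens(model_max_length, min_max_new_tokens=256, max_input_tokens=-1):
    # budget left for the prompt once room for generation is reserved
    budget = model_max_length - min_max_new_tokens
    if max_input_tokens >= 0:
        return min(budget, max_input_tokens)
    return budget

assert effective_max_input_tokens(4096) == 3840
assert effective_max_input_tokens(4096, max_input_tokens=1024) == 1024
assert effective_max_input_tokens(4096, max_input_tokens=100000) == 3840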
     if prompter:
         prompt_type = prompter.prompt_type
         prompt_dict = prompter.prompt_dict
@@ -3609,37 +4049,67 @@ def get_limited_prompt(instruction,
         stream_output = prompter.stream_output
         system_prompt = prompter.system_prompt
+    generate_prompt_type = prompt_type
+    external_handle_chat_conversation = False
+    if inference_server and any(
+            inference_server.startswith(x) for x in ['openai_chat', 'openai_azure_chat', 'vllm_chat']):
+        # Chat APIs do not take prompting
+        # Replicate does not need prompting if no chat history, but in general can take prompting
+        # if using prompter, prompter.system_prompt will already be filled with automatic (e.g. from llama-2),
+        # so the final prompt for replicate with a system prompt is still correct, since only the already-set prompter.system_prompt is accessed
+        # below already true for openai,
+        # but not vllm by default as that can be any model and handled by FastChat API inside vLLM itself
+        generate_prompt_type = 'plain'
+        # Chat APIs don't handle chat history via single prompt, but in messages, assumed to be handled outside this function
+        chat_conversation = []
+        external_handle_chat_conversation = True
+
     # merge handles if chat_conversation is None
     history = []
     history = merge_chat_conversation_history(chat_conversation, history)
     history_to_context_func = functools.partial(history_to_context,
                                                 langchain_mode=langchain_mode,
                                                 add_chat_history_to_context=add_chat_history_to_context,
-                                                prompt_type=prompt_type,
+                                                prompt_type=generate_prompt_type,
                                                 prompt_dict=prompt_dict,
                                                 chat=chat,
-                                                model_max_length=model_max_length,
+                                                model_max_length=max_input_tokens,
                                                 memory_restriction_level=memory_restriction_level,
                                                 keep_sources_in_context=keep_sources_in_context,
-                                                system_prompt=system_prompt)
+                                                system_prompt=system_prompt,
+                                                min_max_new_tokens=min_max_new_tokens)
     context2 = history_to_context_func(history)
     context1 = context
     if context1 is None:
         context1 = ''
+    # get how many more tokens in templated instruction, somewhat of estimate at fine level
+    num_instruction_tokens = get_token_count(instruction, tokenizer)
+    num_estimated_instruction_tokens = get_token_count(estimated_instruction, tokenizer)
+    delta_instruction = max(0, num_estimated_instruction_tokens - num_instruction_tokens)
+
+    # get estimated templated instruction tokens for counting purposes
     from h2oai_pipeline import H2OTextGenerationPipeline
-    data_point_just_instruction = dict(context='', instruction=instruction, input='')
-    prompt_just_instruction = prompter.generate_prompt(data_point_just_instruction)
-    instruction, num_instruction_tokens = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer)
-    num_instruction_tokens_real = get_token_count(prompt_just_instruction, tokenizer)
-    num_instruction_tokens += (num_instruction_tokens_real - num_instruction_tokens)
-
-    context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer)
-    context2, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer)
-    iinput, num_iinput_tokens = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer)
+    estimated_instruction, num_estimated_instruction_tokens = H2OTextGenerationPipeline.limit_prompt(
+        estimated_instruction, tokenizer,
+        max_prompt_length=max_input_tokens)
+    data_point_just_instruction = dict(context='', instruction=estimated_instruction, input='')
+    prompt_just_estimated_instruction = prompter.generate_prompt(data_point_just_instruction)
+    num_instruction_tokens = get_token_count(prompt_just_estimated_instruction, tokenizer)
+
+    # get actual instruction, limited by template limitation
+    instruction, _ = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
                                                            max_prompt_length=max_input_tokens - delta_instruction)
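delta_instruction above estimates how many extra tokens the prompt template wraps around the raw instruction, so truncation can budget for the template. A small sketch of that idea with a plain Hugging Face tokenizer; the template string and model are stand-ins, not an h2oGPT prompt_type:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilgpt2")

def template_overhead(instruction, template="### Human: {}\n### Assistant:"):
    # extra tokens contributed by the surrounding template, analogous to delta_instruction
    raw = len(tok(instruction)["input_ids"])
    templated = len(tok(template.format(instruction))["input_ids"])
    return max(0, templated - raw)

print(template_overhead("Summarize the attached report in three bullet points."))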
+
+    context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer,
+                                                                           max_prompt_length=max_input_tokens)
+    context2, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer,
+                                                                           max_prompt_length=max_input_tokens)
+    iinput, num_iinput_tokens = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer,
+                                                                       max_prompt_length=max_input_tokens)
     if text_context_list is None:
         text_context_list = []
-    num_doc_tokens = sum([get_token_count(x + '\n\n', tokenizer) for x in text_context_list])
+    num_doc_tokens = sum([get_token_count(x + docs_joiner_default, tokenizer) for x in text_context_list])
     num_prompt_tokens0 = (num_instruction_tokens or 0) + \
                          (num_context1_tokens or 0) + \
@@ -3656,10 +4126,10 @@ def get_limited_prompt(instruction,
     # allowed residual is either half of what is allowed if doc exceeds half, or is rest of what doc didn't consume
     num_non_doc_tokens = num_prompt_tokens0 - num_doc_tokens
     # to doc first then non-doc, shouldn't matter much either way
-    doc_max_length = max(model_max_length - num_non_doc_tokens, doc_importance * model_max_length)
+    doc_max_length = max(max_input_tokens - num_non_doc_tokens, int(doc_importance * max_input_tokens))
     top_k_docs, one_doc_size, num_doc_tokens = get_docs_tokens(tokenizer, text_context_list=text_context_list,
                                                                max_input_tokens=doc_max_length)
-    non_doc_max_length = max(model_max_length - num_doc_tokens, (1.0 - doc_importance) * model_max_length)
+    non_doc_max_length = max(max_input_tokens - num_doc_tokens, int((1.0 - doc_importance) * max_input_tokens))
     if num_non_doc_tokens > non_doc_max_length:
         # need to limit in some way, keep portion of history but all of context and instruction
@@ -3668,10 +4138,10 @@ def get_limited_prompt(instruction,
         # 3) reduce context1
         # 4) limit instruction so will fit
         diff1 = non_doc_max_length - (
-                num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
-        diff2 = non_doc_max_length - (num_instruction_tokens + num_context1_tokens + min_max_new_tokens)
-        diff3 = non_doc_max_length - (num_instruction_tokens + min_max_new_tokens)
-        diff4 = non_doc_max_length - min_max_new_tokens
+                num_instruction_tokens + num_context1_tokens + num_context2_tokens)
+        diff2 = non_doc_max_length - (num_instruction_tokens + num_context1_tokens)
+        diff3 = non_doc_max_length - num_instruction_tokens
+        diff4 = non_doc_max_length
         if diff1 > 0:
             # then should be able to do #1
             iinput = ''
@@ -3687,7 +4157,7 @@ def get_limited_prompt(instruction,
                 context2 = history_to_context_func(history[chat_index:])
                 num_context2_tokens = get_token_count(context2, tokenizer)
                 diff1 = non_doc_max_length - (
-                        num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
+                        num_instruction_tokens + num_context1_tokens + num_context2_tokens)
                 if diff1 > 0:
                     chat_index_final = chat_index
                     if verbose:
@@ -3716,13 +4186,14 @@ def get_limited_prompt(instruction,
             num_context1_tokens = 0
             # diff4 accounts for real prompting for instruction
             # FIXME: history_to_context could include instruction, in case system prompt long, we overcount and could have more free tokens
-            instruction, num_instruction_tokens = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
-                                                                                         max_prompt_length=diff4)
-            # get actual tokens
+
+            max_prompt_length = max(0, diff4 - delta_instruction)
+            instruction, _ = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
                                                                    max_prompt_length=max_prompt_length)
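When the non-document parts still exceed their budget, the code above reduces them in a fixed order: drop iinput, then trim chat history (context2), then drop context1, and only then truncate the instruction itself. A compact sketch of that decision order (token counts and the helper name are illustrative):

def shrink_plan(non_doc_max_length, n_instruction, n_context1, n_context2):
    # mirrors diff1..diff4 above: prefer dropping optional parts before touching the instruction
    if non_doc_max_length - (n_instruction + n_context1 + n_context2) > 0:
        return "drop iinput only"
    if non_doc_max_length - (n_instruction + n_context1) > 0:
        return "drop iinput and trim chat history"
    if non_doc_max_length - n_instruction > 0:
        return "also drop context1"
    return "truncate the instruction to fit"

print(shrink_plan(1000, n_instruction=400, n_context1=300, n_context2=200))  # -> drop iinput only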
+            # get actual instruction tokens
             data_point_just_instruction = dict(context='', instruction=instruction, input='')
             prompt_just_instruction = prompter.generate_prompt(data_point_just_instruction)
-            num_instruction_tokens_real = get_token_count(prompt_just_instruction, tokenizer)
-            num_instruction_tokens += (num_instruction_tokens_real - num_instruction_tokens)
+            num_instruction_tokens = get_token_count(prompt_just_instruction, tokenizer) + delta_instruction
     # update full context
     context = context1 + context2
@@ -3734,20 +4205,24 @@ def get_limited_prompt(instruction,
                          (num_doc_tokens or 0)
     # update max_new_tokens
-    if inference_server and inference_server.startswith('http'):
-        # assume TGI/Gradio setup to consume tokens and have long output too, even if exceeds model capacity.
-        pass
-    else:
-        # limit so max_new_tokens = prompt + new < max
-        # otherwise model can fail etc. e.g. for distilgpt2 asking for 1024 tokens is enough to fail if prompt=1 token
+    # limit so max_new_tokens = prompt + new < max
+    # otherwise model can fail etc. e.g. for distilgpt2 asking for 1024 tokens is enough to fail if prompt=1 token
+    if truncation_generation:
         max_new_tokens = min(max_new_tokens, model_max_length - num_prompt_tokens)
+    if os.getenv('HARD_ASSERTS'):
+        if max_new_tokens < min_max_new_tokens:
+            raise ValueError("Invalid max_new_tokens=%s" % max_new_tokens)
+
     if prompter is None:
         # get prompter
         debug = False
         stream_output = False  # doesn't matter
         prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
                             system_prompt=system_prompt)
+    if prompt_type != generate_prompt_type:
+        # override just this attribute, keep system_prompt etc. from original prompt_type
+        prompter.prompt_type = generate_prompt_type
     data_point = dict(context=context, instruction=instruction, input=iinput)
     # handle promptA/promptB addition if really from history.
@@ -3760,7 +4235,8 @@ def get_limited_prompt(instruction,
     return prompt, \
         instruction, iinput, context, \
         num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
-        chat_index, top_k_docs, one_doc_size
+        chat_index, external_handle_chat_conversation, \
+        top_k_docs, one_doc_size, truncation_generation
 def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
@@ -3768,7 +4244,7 @@ def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
         return 0, None, 0
     if max_input_tokens is None:
         max_input_tokens = tokenizer.model_max_length
-    tokens = [get_token_count(x + '\n\n', tokenizer) for x in text_context_list]
+    tokens = [get_token_count(x + docs_joiner_default, tokenizer) for x in text_context_list]
     tokens_cumsum = np.cumsum(tokens)
     where_res = np.where(tokens_cumsum < max_input_tokens)[0]
     # if below condition fails, then keep top_k_docs=-1 and trigger special handling next
@@ -3788,7 +4264,7 @@ def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
                                                              max_prompt_length=max_input_tokens)
         text_context_list[0] = doc_content
         one_doc_size = len(doc_content)
-        num_doc_tokens = get_token_count(doc_content + '\n\n', tokenizer)
+        num_doc_tokens = get_token_count(doc_content + docs_joiner_default, tokenizer)
         print("Unexpected large chunks and can't add to context, will add 1 anyways. Tokens %s -> %s" % (
             tokens[0], new_tokens0), flush=True)
     return top_k_docs, one_doc_size, num_doc_tokens
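get_docs_tokens above uses a cumulative sum of per-document token counts to decide how many leading chunks fit the budget. A stripped-down sketch of that selection (it omits the special handling when even the first document is too large):

import numpy as np

def fit_docs(doc_token_counts, max_input_tokens):
    # doc_token_counts: tokens per document chunk, already including the joiner
    cumsum = np.cumsum(doc_token_counts)
    fits = np.where(cumsum < max_input_tokens)[0]
    if len(fits) == 0:
        return 0, 0  # nothing fits; the real code then truncates the first doc instead
    top_k_docs = int(fits[-1]) + 1
    return top_k_docs, int(cumsum[top_k_docs - 1])

print(fit_docs([120, 340, 500, 800], max_input_tokens=1000))  # -> (3, 960)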