diff --git "a/gradio_runner.py" "b/gradio_runner.py" --- "a/gradio_runner.py" +++ "b/gradio_runner.py" @@ -1,20 +1,25 @@ import copy import functools import inspect +import itertools import json import os import pprint import random +import shutil import sys +import time import traceback +import typing import uuid import filelock import pandas as pd import requests import tabulate +from iterators import TimeoutIterator -from gradio_ui.css import get_css -from gradio_ui.prompt_form import make_prompt_form +from gradio_utils.css import get_css +from gradio_utils.prompt_form import make_prompt_form, make_chatbots # This is a hack to prevent Gradio from phoning home when it gets imported os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' @@ -32,29 +37,53 @@ import gradio as gr requests.get = original_get -from enums import DocumentChoices -from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js -from prompter import Prompter, \ - prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt, non_hf_types, \ + +def fix_pydantic_duplicate_validators_error(): + try: + from pydantic import class_validators + + class_validators.in_ipython = lambda: True # type: ignore[attr-defined] + except ImportError: + pass + + +fix_pydantic_duplicate_validators_error() + +from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode +from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \ + text_xsm +from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \ get_prompt from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \ - ping, get_short_name, get_url, makedirs, get_kwargs + ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \ - inputs_kwargs_list, get_cutoffs, scratch_base_dir, evaluate_from_str, no_default_param_names, \ - eval_func_param_names_defaults, get_max_max_new_tokens + inputs_kwargs_list, scratch_base_dir, evaluate_from_str, no_default_param_names, \ + eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context from apscheduler.schedulers.background import BackgroundScheduler -def fix_newlines(text): - # ensure good visually, else markdown ignores multiple \n - # handle code blocks - ts = text.split('```') - for parti, part in enumerate(ts): - inside = parti % 2 == 1 - if not inside: - ts[parti] = ts[parti].replace('\n', '
') - return '```'.join(ts) +def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True): + if fix_latex_dollars: + ts = text.split('```') + for parti, part in enumerate(ts): + inside = parti % 2 == 1 + if not inside: + ts[parti] = ts[parti].replace('$', '﹩') + text = '```'.join(ts) + + if fix_new_lines: + # let Gradio handle code, since got improved recently + ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is. + # ensure good visually, else markdown ignores multiple \n + # handle code blocks + ts = text.split('```') + for parti, part in enumerate(ts): + inside = parti % 2 == 1 + if not inside: + ts[parti] = ts[parti].replace('\n', '
') + text = '```'.join(ts) + return text def go_gradio(**kwargs): @@ -65,6 +94,7 @@ def go_gradio(**kwargs): n_gpus = kwargs['n_gpus'] admin_pass = kwargs['admin_pass'] model_state0 = kwargs['model_state0'] + model_states = kwargs['model_states'] score_model_state0 = kwargs['score_model_state0'] dbs = kwargs['dbs'] db_type = kwargs['db_type'] @@ -93,11 +123,7 @@ def go_gradio(**kwargs): " use Enter for multiple input lines)" title = 'h2oGPT' - if 'h2ogpt-research' in kwargs['base_model']: - title += " [Research demonstration]" - more_info = """For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O-LLMStudio](https://github.com/h2oai/h2o-llmstudio)
""" - if is_public: - more_info += """""" + more_info = """h2oGPT H2O LLM Studio
🤗 Models""" if kwargs['verbose']: description = f"""Model {kwargs['base_model']} Instruct dataset. For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio). @@ -106,10 +132,10 @@ def go_gradio(**kwargs): """ else: description = more_info - description += "If this host is busy, try [12B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)
" - description += """

By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)

""" + description_bottom = "If this host is busy, try [LLaMa 65B](https://llama.h2o.ai), [Falcon 40B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)
" + description_bottom += """

By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)

""" if is_hf: - description += '''Duplicate Space''' + description_bottom += '''Duplicate Space''' if kwargs['verbose']: task_info_md = f""" @@ -119,33 +145,6 @@ def go_gradio(**kwargs): css_code = get_css(kwargs) - if kwargs['gradio_avoid_processing_markdown']: - from gradio_client import utils as client_utils - from gradio.components import Chatbot - - # gradio has issue with taking too long to process input/output for markdown etc. - # Avoid for now, allow raw html to render, good enough for chatbot. - def _postprocess_chat_messages(self, chat_message: str): - if chat_message is None: - return None - elif isinstance(chat_message, (tuple, list)): - filepath = chat_message[0] - mime_type = client_utils.get_mimetype(filepath) - filepath = self.make_temp_copy_if_needed(filepath) - return { - "name": filepath, - "mime_type": mime_type, - "alt_text": chat_message[1] if len(chat_message) > 1 else None, - "data": None, # These last two fields are filled in by the frontend - "is_file": True, - } - elif isinstance(chat_message, str): - return chat_message - else: - raise ValueError(f"Invalid message for Chatbot component: {chat_message}") - - Chatbot._postprocess_chat_messages = _postprocess_chat_messages - if kwargs['gradio_offline_level'] >= 0: # avoid GoogleFont that pulls from internet if kwargs['gradio_offline_level'] == 1: @@ -157,6 +156,17 @@ def go_gradio(**kwargs): font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace')) else: theme_kwargs = dict() + if kwargs['gradio_size'] == 'xsmall': + theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm)) + elif kwargs['gradio_size'] == 'small': + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm, + radius_size=gr.themes.sizes.spacing_sm)) + elif kwargs['gradio_size'] == 'large': + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg), + radius_size=gr.themes.sizes.spacing_lg) + elif kwargs['gradio_size'] == 'medium': + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md, + radius_size=gr.themes.sizes.spacing_md)) theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs) demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False) @@ -164,24 +174,36 @@ def go_gradio(**kwargs): model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options'] if kwargs['base_model'].strip() not in model_options: - lora_options = [kwargs['base_model'].strip()] + model_options + model_options = [kwargs['base_model'].strip()] + model_options lora_options = kwargs['extra_lora_options'] if kwargs['lora_weights'].strip() not in lora_options: lora_options = [kwargs['lora_weights'].strip()] + lora_options + server_options = kwargs['extra_server_options'] + if kwargs['inference_server'].strip() not in server_options: + server_options = [kwargs['inference_server'].strip()] + server_options + if os.getenv('OPENAI_API_KEY'): + if 'openai_chat' not in server_options: + server_options += ['openai_chat'] + if 'openai' not in server_options: + server_options += ['openai'] + # always add in no lora case # add fake space so doesn't go away in gradio dropdown - no_lora_str = no_model_str = '[None/Remove]' - lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double? + model_options = [no_model_str] + model_options + lora_options = [no_lora_str] + lora_options + server_options = [no_server_str] + server_options # always add in no model case so can free memory # add fake space so doesn't go away in gradio dropdown - model_options = [no_model_str] + model_options # transcribe, will be detranscribed before use by evaluate() + if not kwargs['base_model'].strip(): + kwargs['base_model'] = no_model_str + if not kwargs['lora_weights'].strip(): kwargs['lora_weights'] = no_lora_str - if not kwargs['base_model'].strip(): - kwargs['base_model'] = no_model_str + if not kwargs['inference_server'].strip(): + kwargs['inference_server'] = no_server_str # transcribe for gradio kwargs['gpu_id'] = str(kwargs['gpu_id']) @@ -195,13 +217,28 @@ def go_gradio(**kwargs): for k in no_default_param_names: default_kwargs[k] = '' + def dummy_fun(x): + # need dummy function to block new input from being sent until output is done, + # else gets input_list at time of submit that is old, and shows up as truncated in chatbot + return x + with demo: # avoid actual model/tokenizer here or anything that would be bad to deepcopy # https://github.com/gradio-app/gradio/issues/3558 - model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']]) - model_state2 = gr.State([None, None, None, None]) + model_state = gr.State( + dict(model='model', tokenizer='tokenizer', device=kwargs['device'], + base_model=kwargs['base_model'], + tokenizer_base_model=kwargs['tokenizer_base_model'], + lora_weights=kwargs['lora_weights'], + inference_server=kwargs['inference_server'], + prompt_type=kwargs['prompt_type'], + prompt_dict=kwargs['prompt_dict'], + ) + ) + model_state2 = gr.State(kwargs['model_state_none'].copy()) model_options_state = gr.State([model_options]) lora_options_state = gr.State([lora_options]) + server_options_state = gr.State([server_options]) my_db_state = gr.State([None, None]) chat_state = gr.State({}) # make user default first and default choice, dedup @@ -210,19 +247,17 @@ def go_gradio(**kwargs): [docs_state0.append(x) for x in docs_state00 if x not in docs_state0] docs_state = gr.State(docs_state0) # first is chosen as default gr.Markdown(f""" - {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)} - - {description} - {task_info_md} + {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)} """) - if is_hf: - gr.HTML( - ) # go button visible if base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0'] go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary") + nas = ' '.join(['NA'] * len(kwargs['model_states'])) + res_value = "Response Score: NA" if not kwargs[ + 'model_lock'] else "Response Scores: %s" % nas + normal_block = gr.Row(visible=not base_wanted) with normal_block: with gr.Tabs(): @@ -230,7 +265,7 @@ def go_gradio(**kwargs): col_nochat = gr.Column(visible=not kwargs['chat']) with col_nochat: # FIXME: for model comparison, and check rest if kwargs['langchain_mode'] == 'Disabled': - text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True) + text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True) else: # text looks a bit worse, but HTML links work text_output_nochat = gr.HTML(label=output_label0) @@ -243,70 +278,31 @@ def go_gradio(**kwargs): placeholder=kwargs['placeholder_input']) submit_nochat = gr.Button("Submit") flag_btn_nochat = gr.Button("Flag") - if not kwargs['auto_score']: - with gr.Column(visible=kwargs['score_model']): - score_btn_nochat = gr.Button("Score last prompt & response") - score_text_nochat = gr.Textbox("Response Score: NA", show_label=False) - else: - with gr.Column(visible=kwargs['score_model']): - score_text_nochat = gr.Textbox("Response Score: NA", show_label=False) + with gr.Column(visible=kwargs['score_model']): + score_text_nochat = gr.Textbox("Response Score: NA", show_label=False) col_chat = gr.Column(visible=kwargs['chat']) with col_chat: - with gr.Row(): - text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400) - text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style( - height=kwargs['height'] or 400) - instruction, submit, stop_btn = make_prompt_form(kwargs) + text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2, + **kwargs) with gr.Row(): clear = gr.Button("Save Chat / New Chat") flag_btn = gr.Button("Flag") - if not kwargs['auto_score']: # FIXME: For checkbox model2 - with gr.Column(visible=kwargs['score_model']): - with gr.Row(): - score_btn = gr.Button("Score last prompt & response").style( - full_width=False, size='sm') - score_text = gr.Textbox("Response Score: NA", show_label=False) - score_res2 = gr.Row(visible=False) - with score_res2: - score_btn2 = gr.Button("Score last prompt & response 2").style( - full_width=False, size='sm') - score_text2 = gr.Textbox("Response Score2: NA", show_label=False) - else: - with gr.Column(visible=kwargs['score_model']): - score_text = gr.Textbox("Response Score: NA", show_label=False) - score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False) - retry = gr.Button("Regenerate") + with gr.Column(visible=kwargs['score_model']): + score_text = gr.Textbox(res_value, + show_label=False, + visible=True) + score_text2 = gr.Textbox("Response Score2: NA", show_label=False, + visible=False and not kwargs['model_lock']) + retry_btn = gr.Button("Regenerate") undo = gr.Button("Undo") submit_nochat_api = gr.Button("Submit nochat API", visible=False) inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False) - text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False).style( - show_copy_button=True) - with gr.TabItem("Chat"): - with gr.Row(): - if 'mbart-' in kwargs['model_lower']: - src_lang = gr.Dropdown(list(languages_covered().keys()), - value=kwargs['src_lang'], - label="Input Language") - tgt_lang = gr.Dropdown(list(languages_covered().keys()), - value=kwargs['tgt_lang'], - label="Output Language") - radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True, - type='value') - with gr.Row(): - clear_chat_btn = gr.Button(value="Clear Chat", visible=True).style(size='sm') - export_chats_btn = gr.Button(value="Export Chats to Download").style(size='sm') - remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True).style(size='sm') - add_to_chats_btn = gr.Button("Import Chats from Upload").style(size='sm') - with gr.Row(): - chats_file = gr.File(interactive=False, label="Download Exported Chats") - chatsup_output = gr.File(label="Upload Chat File(s)", - file_types=['.json'], - file_count='multiple', - elem_id="warning", elem_classes="feedback") - with gr.TabItem("Data Source"): + text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False, + show_copy_button=True) + with gr.TabItem("Documents"): langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/docs/README_LangChain.md', from_str=True) gr.HTML(value=f"""LangChain Support Disabled

@@ -346,12 +342,9 @@ def go_gradio(**kwargs): multiselect=True, ) with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list): - get_sources_btn = gr.Button(value="Get Sources", - ).style(full_width=False, size='sm') - show_sources_btn = gr.Button(value="Show Sources", - ).style(full_width=False, size='sm') - refresh_sources_btn = gr.Button(value="Refresh Sources", - ).style(full_width=False, size='sm') + get_sources_btn = gr.Button(value="Get Sources", scale=0, size='sm') + show_sources_btn = gr.Button(value="Show Sources", scale=0, size='sm') + refresh_sources_btn = gr.Button(value="Refresh Sources", scale=0, size='sm') # import control if kwargs['langchain_mode'] != 'Disabled': @@ -360,8 +353,8 @@ def go_gradio(**kwargs): have_arxiv = False file_types = [] - upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style( - equal_height=False) + upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload, + equal_height=False) with upload_row: with gr.Column(): file_types_str = '[' + ' '.join(file_types) + ']' @@ -371,38 +364,50 @@ def go_gradio(**kwargs): elem_id="warning", elem_classes="feedback") with gr.Row(): add_to_shared_db_btn = gr.Button("Add File(s) to UserData", - visible=allow_upload_to_user_data, elem_id='small_btn') + visible=allow_upload_to_user_data, + elem_id='small_btn') add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData", - visible=allow_upload_to_my_data, + visible=allow_upload_to_my_data and + allow_upload_to_user_data, elem_id='small_btn' if allow_upload_to_user_data else None, - ).style( - size='sm' if not allow_upload_to_user_data else None) + size='sm' if not allow_upload_to_user_data else None) with gr.Column( visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload): url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)' - url_text = gr.Textbox(label=url_label, interactive=True) + url_text = gr.Textbox(label=url_label, + placeholder="Click Add to Submit" if + allow_upload_to_my_data and + allow_upload_to_user_data else + "Enter to Submit", + max_lines=1, + interactive=True) with gr.Row(): url_user_btn = gr.Button(value='Add URL content to Shared UserData', - visible=allow_upload_to_user_data, elem_id='small_btn') + visible=allow_upload_to_user_data and allow_upload_to_my_data, + elem_id='small_btn') url_my_btn = gr.Button(value='Add URL content to Scratch MyData', - visible=allow_upload_to_my_data, + visible=allow_upload_to_my_data and allow_upload_to_user_data, elem_id='small_btn' if allow_upload_to_user_data else None, - ).style(size='sm' if not allow_upload_to_user_data else None) + size='sm' if not allow_upload_to_user_data else None) with gr.Column( visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload): - user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True) + user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', + placeholder="Click Add to Submit" if + allow_upload_to_my_data and + allow_upload_to_user_data else + "Enter to Submit, Shift-Enter for more lines", + interactive=True) with gr.Row(): user_text_user_btn = gr.Button(value='Add Text to Shared UserData', - visible=allow_upload_to_user_data, + visible=allow_upload_to_user_data and allow_upload_to_my_data, elem_id='small_btn') user_text_my_btn = gr.Button(value='Add Text to Scratch MyData', - visible=allow_upload_to_my_data, + visible=allow_upload_to_my_data and allow_upload_to_user_data, elem_id='small_btn' if allow_upload_to_user_data else None, - ).style( - size='sm' if not allow_upload_to_user_data else None) + size='sm' if not allow_upload_to_user_data else None) with gr.Column(visible=False): # WIP: - with gr.Row(visible=False).style(equal_height=False): + with gr.Row(visible=False, equal_height=False): github_textbox = gr.Textbox(label="Github URL") with gr.Row(visible=True): github_shared_btn = gr.Button(value="Add Github to Shared UserData", @@ -410,18 +415,37 @@ def go_gradio(**kwargs): elem_id='small_btn') github_my_btn = gr.Button(value="Add Github to Scratch MyData", visible=allow_upload_to_my_data, elem_id='small_btn') - sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style( - equal_height=False) - with sources_row3: + sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list, + equal_height=False) + with sources_row: with gr.Column(scale=1): file_source = gr.File(interactive=False, label="Download File w/Sources [click get sources to make file]") with gr.Column(scale=2): - pass - sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style( - equal_height=False) - with sources_row: - sources_text = gr.HTML(label='Sources Added', interactive=False) + sources_text = gr.HTML(label='Sources Added', interactive=False) + + with gr.TabItem("Chat History"): + with gr.Row(): + if 'mbart-' in kwargs['model_lower']: + src_lang = gr.Dropdown(list(languages_covered().keys()), + value=kwargs['src_lang'], + label="Input Language") + tgt_lang = gr.Dropdown(list(languages_covered().keys()), + value=kwargs['tgt_lang'], + label="Output Language") + radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True, + type='value') + with gr.Row(): + clear_chat_btn = gr.Button(value="Clear Chat", visible=True, size='sm') + export_chats_btn = gr.Button(value="Export Chats to Download", size='sm') + remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True, size='sm') + add_to_chats_btn = gr.Button("Import Chats from Upload", size='sm') + with gr.Row(): + chats_file = gr.File(interactive=False, label="Download Exported Chats") + chatsup_output = gr.File(label="Upload Chat File(s)", + file_types=['.json'], + file_count='multiple', + elem_id="warning", elem_classes="feedback") with gr.TabItem("Expert"): with gr.Row(): @@ -430,22 +454,25 @@ def go_gradio(**kwargs): value=kwargs['stream_output']) prompt_type = gr.Dropdown(prompt_types_strings, value=kwargs['prompt_type'], label="Prompt Type", - visible=not is_public) + visible=not kwargs['model_lock'], + interactive=not is_public, + ) prompt_type2 = gr.Dropdown(prompt_types_strings, value=kwargs['prompt_type'], label="Prompt Type Model 2", - visible=not is_public and False) + visible=False and not kwargs['model_lock'], + interactive=not is_public) do_sample = gr.Checkbox(label="Sample", info="Enable sampler, required for use of temperature, top_p, top_k", value=kwargs['do_sample']) - temperature = gr.Slider(minimum=0.01, maximum=3, + temperature = gr.Slider(minimum=0.01, maximum=2, value=kwargs['temperature'], label="Temperature", info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)") - top_p = gr.Slider(minimum=0, maximum=1, + top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3, value=kwargs['top_p'], label="Top p", info="Cumulative probability of tokens to sample from") top_k = gr.Slider( - minimum=0, maximum=100, step=1, + minimum=1, maximum=100, step=1, value=kwargs['top_k'], label="Top k", info='Num. tokens to sample from' ) @@ -457,7 +484,8 @@ def go_gradio(**kwargs): num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1, value=min(max_beams, kwargs['num_beams']), label="Beams", info="Number of searches for optimal overall probability. " - "Uses more GPU memory/compute") + "Uses more GPU memory/compute", + interactive=False) max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs) max_new_tokens = gr.Slider( minimum=1, maximum=max_max_new_tokens, step=1, @@ -470,20 +498,18 @@ def go_gradio(**kwargs): max_new_tokens2 = gr.Slider( minimum=1, maximum=max_max_new_tokens, step=1, value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2", - visible=False, + visible=False and not kwargs['model_lock'], ) min_new_tokens2 = gr.Slider( minimum=0, maximum=max_max_new_tokens, step=1, value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2", - visible=False, + visible=False and not kwargs['model_lock'], ) early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search", value=kwargs['early_stopping']) - max_max_time = 60 * 20 if not is_public else 60 * 2 - if is_hf: - max_max_time = min(max_max_time, 60 * 1) - max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1, - value=min(max_max_time, kwargs['max_time']), label="Max. time", + max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1, + value=min(kwargs['max_max_time'], + kwargs['max_time']), label="Max. time", info="Max. time to search optimal output.") repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0, value=kwargs['repetition_penalty'], @@ -491,105 +517,137 @@ def go_gradio(**kwargs): num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1, value=kwargs['num_return_sequences'], label="Number Returns", info="Must be <= num_beams", - visible=not is_public) + interactive=not is_public) iinput = gr.Textbox(lines=4, label="Input", placeholder=kwargs['placeholder_input'], - visible=not is_public) + interactive=not is_public) context = gr.Textbox(lines=3, label="System Pre-Context", info="Directly pre-appended without prompt processing", - visible=not is_public) + interactive=not is_public) chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'], - visible=not is_public) - count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", visible=not is_public) + visible=not kwargs['model_lock'], + interactive=not is_public, + ) + count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", + visible=not is_public and not kwargs['model_lock'], + interactive=not is_public) chat_token_count = gr.Textbox(label="Chat Token Count", value=None, - visible=not is_public, interactive=False) + visible=not is_public and not kwargs['model_lock'], + interactive=False) chunk = gr.components.Checkbox(value=kwargs['chunk'], label="Whether to chunk documents", info="For LangChain", - visible=not is_public) - top_k_docs = gr.Slider(minimum=0, maximum=100, step=1, + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public) + min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) + top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1, value=kwargs['top_k_docs'], - label="Number of document chunks", + label=label_top_k_docs, info="For LangChain", - visible=not is_public) + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public) chunk_size = gr.Number(value=kwargs['chunk_size'], label="Chunk size for document chunking", info="For LangChain (ignored if chunk=False)", - visible=not is_public, + minimum=128, + maximum=2048, + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public, precision=0) with gr.TabItem("Models"): + model_lock_msg = gr.Textbox(lines=1, label="Model Lock Notice", + placeholder="Started in model_lock mode, no model changes allowed.", + visible=bool(kwargs['model_lock']), interactive=False) load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO" load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2" + variant_load_msg = 'primary' if not is_public else 'secondary' compare_checkbox = gr.components.Checkbox(label="Compare Mode", - value=False, visible=not is_public) + value=kwargs['model_lock'], + visible=not is_public and not kwargs['model_lock']) with gr.Row(): n_gpus_list = [str(x) for x in list(range(-1, n_gpus))] with gr.Column(): with gr.Row(): - with gr.Column(scale=20): + with gr.Column(scale=20, visible=not kwargs['model_lock']): model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", value=kwargs['base_model']) lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora']) - with gr.Column(scale=1): - load_model_button = gr.Button(load_msg).style(full_width=False, size='sm') + server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server", + value=kwargs['inference_server'], visible=not is_public) + with gr.Column(scale=1, visible=not kwargs['model_lock']): + load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0, + size='sm', interactive=not is_public) model_load8bit_checkbox = gr.components.Checkbox( label="Load 8-bit [requires support]", - value=kwargs['load_8bit']) + value=kwargs['load_8bit'], interactive=not is_public) model_infer_devices_checkbox = gr.components.Checkbox( label="Choose Devices [If not Checked, use all GPUs]", - value=kwargs['infer_devices']) + value=kwargs['infer_devices'], interactive=not is_public) model_gpu = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs, if Choose is enabled]", - value=kwargs['gpu_id']) + value=kwargs['gpu_id'], interactive=not is_public) model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'], interactive=False) lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'], interactive=False) + server_used = gr.Textbox(label="Current Server", + value=kwargs['inference_server'], + visible=bool(kwargs['inference_server']) and not is_public, + interactive=False) prompt_dict = gr.Textbox(label="Prompt (or Custom)", value=pprint.pformat(kwargs['prompt_dict'], indent=4), - interactive=True, lines=4) + interactive=not is_public, lines=4) col_model2 = gr.Column(visible=False) with col_model2: with gr.Row(): - with gr.Column(scale=20): + with gr.Column(scale=20, visible=not kwargs['model_lock']): model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2", value=no_model_str) lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2", value=no_lora_str, visible=kwargs['show_lora']) - with gr.Column(scale=1): - load_model_button2 = gr.Button(load_msg2).style(full_width=False, size='sm') + server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2", + value=no_server_str, + visible=not is_public) + with gr.Column(scale=1, visible=not kwargs['model_lock']): + load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0, + size='sm', interactive=not is_public) model_load8bit_checkbox2 = gr.components.Checkbox( label="Load 8-bit 2 [requires support]", - value=kwargs['load_8bit']) + value=kwargs['load_8bit'], interactive=not is_public) model_infer_devices_checkbox2 = gr.components.Checkbox( label="Choose Devices 2 [If not Checked, use all GPUs]", value=kwargs[ - 'infer_devices']) + 'infer_devices'], interactive=not is_public) model_gpu2 = gr.Dropdown(n_gpus_list, label="GPU ID 2 [-1 = all GPUs, if choose is enabled]", - value=kwargs['gpu_id']) + value=kwargs['gpu_id'], interactive=not is_public) # no model/lora loaded ever in model2 by default - model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str) + model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str, + interactive=False) lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str, - visible=kwargs['show_lora']) + visible=kwargs['show_lora'], interactive=False) + server_used2 = gr.Textbox(label="Current Server 2", value=no_server_str, + interactive=False, + visible=not is_public) prompt_dict2 = gr.Textbox(label="Prompt (or Custom) 2", value=pprint.pformat(kwargs['prompt_dict'], indent=4), - interactive=True, lines=4) - with gr.Row(): + interactive=not is_public, lines=4) + with gr.Row(visible=not kwargs['model_lock']): with gr.Column(scale=50): - new_model = gr.Textbox(label="New Model HF name/path") - with gr.Row(): - add_model_button = gr.Button("Add new model name").style(full_width=False, size='sm') + new_model = gr.Textbox(label="New Model name/path", interactive=not is_public) + with gr.Column(scale=50): + new_lora = gr.Textbox(label="New LORA name/path", visible=kwargs['show_lora'], + interactive=not is_public) with gr.Column(scale=50): - new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora']) + new_server = gr.Textbox(label="New Server url:port", interactive=not is_public) with gr.Row(): - add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora']).style( - full_width=False, size='sm') + add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0, + size='sm', interactive=not is_public) with gr.TabItem("System"): admin_row = gr.Row() with admin_row: @@ -600,8 +658,17 @@ def go_gradio(**kwargs): with gr.Column(): with gr.Row(): system_btn = gr.Button(value='Get System Info') - system_text = gr.Textbox(label='System Info', interactive=False).style( - show_copy_button=True) + system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True) + with gr.Row(): + system_input = gr.Textbox(label='System Info Dict Password', interactive=True, + visible=not is_public) + system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public) + system_text2 = gr.Textbox(label='System Info Dict', interactive=False, + visible=not is_public, show_copy_button=True) + with gr.Row(): + system_btn3 = gr.Button(value='Get Hash', visible=not is_public) + system_text3 = gr.Textbox(label='Hash', interactive=False, + visible=not is_public, show_copy_button=True) with gr.Row(): zip_btn = gr.Button("Zip") @@ -621,6 +688,11 @@ def go_gradio(**kwargs): description += """

  • By using h2oGPT, you accept our Terms of Service
  • """ gr.Markdown(value=description, show_label=False, interactive=False) + gr.Markdown(f""" + {description_bottom} + {task_info_md} + """) + # Get flagged data zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']]) zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False, @@ -628,20 +700,20 @@ def go_gradio(**kwargs): s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False, api_name='s3up_data' if allow_api else None) - def make_add_visible(x): - return gr.update(visible=x is not None) - def clear_file_list(): return None - def make_invisible(): - return gr.update(visible=False) - - def make_visible(): - return gr.update(visible=True) + def make_non_interactive(*args): + if len(args) == 1: + return gr.update(interactive=False) + else: + return tuple([gr.update(interactive=False)] * len(args)) - def update_radio_to_user(): - return gr.update(value='UserData') + def make_interactive(*args): + if len(args) == 1: + return gr.update(interactive=True) + else: + return tuple([gr.update(interactive=True)] * len(args)) # Add to UserData update_user_db_func = functools.partial(update_user_db, @@ -653,44 +725,75 @@ def go_gradio(**kwargs): enable_ocr=enable_ocr, caption_loader=caption_loader, verbose=kwargs['verbose'], + user_path=kwargs['user_path'], ) + add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn] + add_file_kwargs = dict(fn=update_user_db_func, + inputs=[fileup_output, my_db_state, add_to_shared_db_btn, + add_to_my_db_btn, + chunk, chunk_size], + outputs=add_file_outputs + [sources_text], + queue=queue, + api_name='add_to_shared' if allow_api and allow_upload_to_user_data else None) - # note for update_user_db_func output is ignored for db - add_to_shared_db_btn.click(update_user_db_func, - inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_to_shared' if allow_api and allow_upload_to_user_data else None) \ - .then(clear_file_list, outputs=fileup_output, queue=queue) \ - .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False) + if allow_upload_to_user_data and not allow_upload_to_my_data: + func1 = fileup_output.change + else: + func1 = add_to_shared_db_btn.click + # then no need for add buttons, only single changeable db + eventdb1a = func1(make_non_interactive, inputs=add_file_outputs, outputs=add_file_outputs, + show_progress='minimal') + eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='minimal') + eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs, show_progress='minimal') - # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue) - # .then(make_visible, outputs=upload_button, queue=queue) + # note for update_user_db_func output is ignored for db def clear_textbox(): return gr.Textbox.update(value='') update_user_db_url_func = functools.partial(update_user_db_func, is_url=True) - url_user_btn.click(update_user_db_url_func, - inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_url_to_shared' if allow_api and allow_upload_to_user_data else None) \ - .then(clear_textbox, outputs=url_text, queue=queue) \ - .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False) + + add_url_outputs = [url_text, langchain_mode, url_user_btn, url_my_btn] + add_url_kwargs = dict(fn=update_user_db_url_func, + inputs=[url_text, my_db_state, url_user_btn, url_my_btn, + chunk, chunk_size], + outputs=add_url_outputs + [sources_text], + queue=queue, + api_name='add_url_to_shared' if allow_api and allow_upload_to_user_data else None) + + if allow_upload_to_user_data and not allow_upload_to_my_data: + func2 = url_text.submit + else: + func2 = url_user_btn.click + eventdb2a = func2(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue, + show_progress='minimal') + # work around https://github.com/gradio-app/gradio/issues/4733 + eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs, + show_progress='minimal') + eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='minimal') + eventdb2.then(make_interactive, inputs=add_url_outputs, outputs=add_url_outputs, show_progress='minimal') update_user_db_txt_func = functools.partial(update_user_db_func, is_txt=True) - user_text_user_btn.click(update_user_db_txt_func, - inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_text_to_shared' if allow_api and allow_upload_to_user_data else None) \ - .then(clear_textbox, outputs=user_text_text, queue=queue) \ - .then(update_radio_to_user, inputs=None, outputs=langchain_mode, queue=False) + add_text_outputs = [user_text_text, langchain_mode, user_text_user_btn, user_text_my_btn] + add_text_kwargs = dict(fn=update_user_db_txt_func, + inputs=[user_text_text, my_db_state, user_text_user_btn, user_text_my_btn, + chunk, chunk_size], + outputs=add_text_outputs + [sources_text], + queue=queue, + api_name='add_text_to_shared' if allow_api and allow_upload_to_user_data else None + ) + if allow_upload_to_user_data and not allow_upload_to_my_data: + func3 = user_text_text.submit + else: + func3 = user_text_user_btn.click - # Add to MyData - def update_radio_to_my(): - return gr.update(value='MyData') + eventdb3a = func3(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue, + show_progress='minimal') + eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs, + show_progress='minimal') + eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='minimal') + eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs, + show_progress='minimal') update_my_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='MyData', use_openai_embedding=use_openai_embedding, @@ -700,35 +803,71 @@ def go_gradio(**kwargs): enable_ocr=enable_ocr, caption_loader=caption_loader, verbose=kwargs['verbose'], + user_path=kwargs['user_path'], ) - add_to_my_db_btn.click(update_my_db_func, - inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_to_my' if allow_api and allow_upload_to_my_data else None) \ - .then(clear_file_list, outputs=fileup_output, queue=queue) \ - .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False) - # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue) - # .then(make_visible, outputs=upload_button, queue=queue) + add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn] + add_my_file_kwargs = dict(fn=update_my_db_func, + inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, + chunk, chunk_size], + outputs=add_my_file_outputs + [sources_text], + queue=queue, + api_name='add_to_my' if allow_api and allow_upload_to_my_data else None) + + if not allow_upload_to_user_data and allow_upload_to_my_data: + func4 = fileup_output.change + else: + func4 = add_to_my_db_btn.click + + eventdb4a = func4(make_non_interactive, inputs=add_my_file_outputs, + outputs=add_my_file_outputs, + show_progress='minimal') + eventdb4 = eventdb4a.then(**add_my_file_kwargs, show_progress='minimal') + eventdb4.then(make_interactive, inputs=add_my_file_outputs, outputs=add_my_file_outputs, + show_progress='minimal') update_my_db_url_func = functools.partial(update_my_db_func, is_url=True) - url_my_btn.click(update_my_db_url_func, - inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None) \ - .then(clear_textbox, outputs=url_text, queue=queue) \ - .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False) + add_my_url_outputs = [url_text, langchain_mode, my_db_state, url_user_btn, url_my_btn] + add_my_url_kwargs = dict(fn=update_my_db_url_func, + inputs=[url_text, my_db_state, url_user_btn, url_my_btn, + chunk, chunk_size], + outputs=add_my_url_outputs + [sources_text], + queue=queue, + api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None) + if not allow_upload_to_user_data and allow_upload_to_my_data: + func5 = url_text.submit + else: + func5 = url_my_btn.click + eventdb5a = func5(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue, + show_progress='minimal') + eventdb5b = eventdb5a.then(make_non_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs, + show_progress='minimal') + eventdb5 = eventdb5b.then(**add_my_url_kwargs, show_progress='minimal') + eventdb5.then(make_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs, + show_progress='minimal') update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True) - user_text_my_btn.click(update_my_db_txt_func, - inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn, - chunk, chunk_size], - outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue, - api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None) \ - .then(clear_textbox, outputs=user_text_text, queue=queue) \ - .then(update_radio_to_my, inputs=None, outputs=langchain_mode, queue=False) + + add_my_text_outputs = [user_text_text, langchain_mode, my_db_state, user_text_user_btn, + user_text_my_btn] + add_my_text_kwargs = dict(fn=update_my_db_txt_func, + inputs=[user_text_text, my_db_state, user_text_user_btn, user_text_my_btn, + chunk, chunk_size], + outputs=add_my_text_outputs + [sources_text], + queue=queue, + api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None) + if not allow_upload_to_user_data and allow_upload_to_my_data: + func6 = user_text_text.submit + else: + func6 = user_text_my_btn.click + + eventdb6a = func6(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue, + show_progress='minimal') + eventdb6b = eventdb6a.then(make_non_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs, + show_progress='minimal') + eventdb6 = eventdb6b.then(**add_my_text_kwargs, show_progress='minimal') + eventdb6.then(make_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs, + show_progress='minimal') get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0) @@ -741,14 +880,15 @@ def go_gradio(**kwargs): def update_dropdown(x): return gr.Dropdown.update(choices=x, value=[docs_state0[0]]) - get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state], - queue=queue, - api_name='get_sources' if allow_api else None) \ + eventdb7 = get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], + outputs=[file_source, docs_state], + queue=queue, + api_name='get_sources' if allow_api else None) \ .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs) - show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text, - api_name='show_sources' if allow_api else None) + eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text, + api_name='show_sources' if allow_api else None) # Get inputs to evaluate() and make_db() # don't deepcopy, can contain model itself @@ -759,8 +899,9 @@ def go_gradio(**kwargs): **get_kwargs(update_and_get_source_files_given_langchain_mode, exclude_names=['db1', 'langchain_mode'], **all_kwargs)) - refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text, - api_name='refresh_sources' if allow_api else None) + eventdb9 = refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode], + outputs=sources_text, + api_name='refresh_sources' if allow_api else None) def check_admin_pass(x): return gr.update(visible=x == admin_pass) @@ -781,7 +922,10 @@ def go_gradio(**kwargs): def evaluate_gradio(*args1, **kwargs1): for res_dict in evaluate(*args1, **kwargs1): - yield '
    ' + fix_newlines(res_dict['response']) + if kwargs['langchain_mode'] == 'Disabled': + yield fix_text_for_gradio(res_dict['response']) + else: + yield '
    ' + fix_text_for_gradio(res_dict['response']) fun = partial(evaluate_gradio, **kwargs_evaluate) @@ -792,9 +936,7 @@ def go_gradio(**kwargs): **kwargs_evaluate ) - dark_mode_btn = gr.Button("Dark Mode", variant="primary").style( - size="sm", - ) + dark_mode_btn = gr.Button("Dark Mode", variant="primary", size="sm") # FIXME: Could add exceptions for non-chat but still streaming exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False) dark_mode_btn.click( @@ -811,7 +953,7 @@ def go_gradio(**kwargs): return gr.Column.update(visible=not x) def col_chat_fun(x): - return gr.Column.update(visible=x) + return gr.Column.update(visible=bool(x)) def context_fun(x): return gr.Textbox.update(visible=not x) @@ -826,24 +968,47 @@ def go_gradio(**kwargs): gr.Examples(examples=kwargs['examples'], inputs=inputs_list) # Score - def score_last_response(*args, nochat=False, model2=False): + def score_last_response(*args, nochat=False, num_model_lock=0): + try: + if num_model_lock > 0: + # then lock way + args_list = list(args).copy() + outputs = args_list[-num_model_lock:] + score_texts1 = [] + for output in outputs: + # same input, put into form good for _score_last_response() + args_list[-1] = output + score_texts1.append( + _score_last_response(*tuple(args_list), nochat=nochat, + num_model_lock=num_model_lock, prefix='')) + if len(score_texts1) > 1: + return "Response Scores: %s" % ' '.join(score_texts1) + else: + return "Response Scores: %s" % score_texts1[0] + else: + return _score_last_response(*args, nochat=nochat, num_model_lock=num_model_lock) + finally: + clear_torch_cache() + + def _score_last_response(*args, nochat=False, num_model_lock=0, prefix='Response Score: '): """ Similar to user() """ args_list = list(args) + smodel = score_model_state0['model'] + stokenizer = score_model_state0['tokenizer'] + sdevice = score_model_state0['device'] if memory_restriction_level > 0: max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256 + elif hasattr(stokenizer, 'model_max_length'): + max_length_tokenize = stokenizer.model_max_length else: - max_length_tokenize = 2048 - 256 + # limit to 1024, not worth OOMing on reward score + max_length_tokenize = 2048 - 1024 cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM - smodel = score_model_state0[0] - stokenizer = score_model_state0[1] - sdevice = score_model_state0[2] + if not nochat: history = args_list[-1] if history is None: - if not model2: - # maybe only doing first model, no need to complain - print("Bad history in scoring last response, fix for now", flush=True) history = [] if smodel is not None and \ stokenizer is not None and \ @@ -857,23 +1022,23 @@ def go_gradio(**kwargs): answer = history[-1][1] else: - return 'Response Score: NA' + return '%sNA' % prefix else: answer = args_list[-1] instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat') question = args_list[instruction_nochat_arg_id] if question is None: - return 'Response Score: Bad Question' + return '%sBad Question' % prefix if answer is None: - return 'Response Score: Bad Answer' + return '%sBad Answer' % prefix try: score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len) finally: clear_torch_cache() if isinstance(score, str): - return 'Response Score: NA' - return 'Response Score: {:.1%}'.format(score) + return '%sNA' % prefix + return '{}{:.1%}'.format(prefix, score) def noop_score_last_response(*args, **kwargs): return "Response Score: Disabled" @@ -887,24 +1052,22 @@ def go_gradio(**kwargs): inputs=inputs_list + [text_output], outputs=[score_text], ) - score_args2 = dict(fn=partial(score_fun, model2=True), + score_args2 = dict(fn=partial(score_fun), inputs=inputs_list2 + [text_output2], outputs=[score_text2], ) + score_fun_func = functools.partial(score_fun, num_model_lock=len(text_outputs)) + all_score_args = dict(fn=score_fun_func, + inputs=inputs_list + text_outputs, + outputs=score_text, + ) score_args_nochat = dict(fn=partial(score_fun, nochat=True), inputs=inputs_list + [text_output_nochat], outputs=[score_text_nochat], ) - if not kwargs['auto_score']: - score_event = score_btn.click(**score_args, queue=queue, api_name='score' if allow_api else None) \ - .then(**score_args2, queue=queue, api_name='score2' if allow_api else None) \ - .then(clear_torch_cache) - score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=queue, - api_name='score_nochat' if allow_api else None) \ - .then(clear_torch_cache) - def user(*args, undo=False, sanitize_user_prompt=True, model2=False): + def update_history(*args, undo=False, retry=False, sanitize_user_prompt=False): """ User that fills history for bot :param args: @@ -916,11 +1079,12 @@ def go_gradio(**kwargs): args_list = list(args) user_message = args_list[eval_func_param_names.index('instruction')] # chat only input1 = args_list[eval_func_param_names.index('iinput')] # chat only - context1 = args_list[eval_func_param_names.index('context')] prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] - prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')] - chat1 = args_list[eval_func_param_names.index('chat')] - stream_output1 = args_list[eval_func_param_names.index('stream_output')] + if not prompt_type1: + # shouldn't have to specify if CLI launched model + prompt_type1 = kwargs['prompt_type'] + # apply back + args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 if input1 and not user_message.endswith(':'): user_message1 = user_message + ":" + input1 elif input1: @@ -930,77 +1094,52 @@ def go_gradio(**kwargs): if sanitize_user_prompt: from better_profanity import profanity user_message1 = profanity.censor(user_message1) - # FIXME: WIP to use desired seperator when user enters nothing - prompter = Prompter(prompt_type1, prompt_dict1, debug=kwargs['debug'], chat=chat1, - stream_output=stream_output1) - if user_message1 in ['']: - # e.g. when user just hits enter in textbox, - # else will have : : on single line, which seems to be "ok" for LLM but not usual - user_message1 = '\n' - user_message1 = fix_newlines(user_message1) history = args_list[-1] - if undo and history: - history.pop() - args_list = args_list[:-1] # FYI, even if unused currently if history is None: - if not model2: - # no need to complain so often unless model1 - print("Bad history, fix for now", flush=True) + # bad history history = [] - # ensure elements not mixed across models as output, - # even if input is currently same source history = history.copy() + if undo: + if len(history) > 0: + history.pop() return history - else: - # FIXME: compare, same history for now - return history + [[user_message1, None]] - - def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1): - # ensure output will be unique to models - _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level, - for_context=True, model_max_length=model_max_length1) - history = copy.deepcopy(history) - - context1 = '' - if max_prompt_length is not None and langchain_mode1 not in ['LLM']: - context1 = '' - # - 1 below because current instruction already in history from user() - for histi in range(0, len(history) - 1): - data_point = dict(instruction=history[histi][0], input='', output=history[histi][1]) - prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, - prompt_type1, - prompt_dict1, - chat1, reduced=True) - # md -> back to text, maybe not super important if model trained enough - if not kwargs['keep_sources_in_context']: - from gpt_langchain import source_prefix, source_postfix - import re - prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt, - flags=re.DOTALL) - if prompt.endswith('\n

    '): - prompt = prompt[:-4] - prompt = prompt.replace('
    ', chat_sep) - if not prompt.endswith(chat_sep): - prompt += chat_sep - # most recent first, add older if can - # only include desired chat history - if len(prompt + context1) > max_prompt_length: - break - context1 = prompt + context1 + if retry: + if history: + history[-1][1] = None + return history + if user_message1 in ['', None, '\n']: + # reject non-retry submit/enter + return history + user_message1 = fix_text_for_gradio(user_message1) + return history + [[user_message1, None]] - _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, prompt_dict1, - chat1, reduced=True) - if context1 and not context1.endswith(chat_sep): - context1 += chat_sep # ensure if terminates abruptly, then human continues on next line - return context1 + def user(*args, undo=False, retry=False, sanitize_user_prompt=False): + return update_history(*args, undo=undo, retry=retry, sanitize_user_prompt=sanitize_user_prompt) + + def all_user(*args, undo=False, retry=False, sanitize_user_prompt=False, num_model_lock=0): + args_list = list(args) + history_list = args_list[-num_model_lock:] + assert len(history_list) > 0, "Bad history list: %s" % history_list + for hi, history in enumerate(history_list): + if num_model_lock > 0: + hargs = args_list[:-num_model_lock].copy() + else: + hargs = args_list.copy() + hargs += [history] + history_list[hi] = update_history(*hargs, undo=undo, retry=retry, + sanitize_user_prompt=sanitize_user_prompt) + if len(history_list) > 1: + return tuple(history_list) + else: + return history_list[0] def get_model_max_length(model_state1): - if model_state1 and not isinstance(model_state1[1], str): - tokenizer = model_state1[1] - elif model_state0 and not isinstance(model_state0[1], str): - tokenizer = model_state0[1] + if model_state1 and not isinstance(model_state1["tokenizer"], str): + tokenizer = model_state1["tokenizer"] + elif model_state0 and not isinstance(model_state0["tokenizer"], str): + tokenizer = model_state0["tokenizer"] else: tokenizer = None if tokenizer is not None: @@ -1008,61 +1147,82 @@ def go_gradio(**kwargs): else: return 2000 - def bot(*args, retry=False): + def prep_bot(*args, retry=False): """ - bot that consumes history for user input - instruction (from input_list) itself is not consumed by bot + :param args: :param retry: - :return: + :return: last element is True if should run bot, False if should just yield history """ # don't deepcopy, can contain model itself args_list = list(args).copy() model_state1 = args_list[-3] my_db_state1 = args_list[-2] history = args_list[-1] + langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] - if model_state1[0] is None or model_state1[0] == no_model_str: - history = [] - yield history, '' - return + if model_state1['model'] is None or model_state1['model'] == no_model_str: + return history, None, None, None args_list = args_list[:-3] # only keep rest needed for evaluate() - langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] - if retry and history: - history.pop() - if not args_list[eval_func_param_names.index('do_sample')]: - # if was not sampling, no point in retry unless change to sample - args_list[eval_func_param_names.index('do_sample')] = True if not history: print("No history", flush=True) history = [] - yield history, '' - return + return history, None, None, None instruction1 = history[-1][0] - if not instruction1: - # reject empty query, can sometimes go nuts - history = [] - yield history, '' - return - prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] - prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')] + if retry and history: + # if retry, pop history and move onto bot stuff + instruction1 = history[-1][0] + history[-1][1] = None + elif not instruction1: + # if not retrying, then reject empty query + return history, None, None, None + elif len(history) > 0 and history[-1][1] not in [None, '']: + # reject submit button if already filled and not retrying + # None when not filling with '' to keep client happy + return history, None, None, None + + # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it + prompt_type1 = kwargs.get('prompt_type', args_list[eval_func_param_names.index('prompt_type')]) + # prefer model specific prompt type instead of global one, and apply back to args_list for evaluate() + args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 = \ + model_state1.get('prompt_type', prompt_type1) + + prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')]) + args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \ + model_state1.get('prompt_dict', prompt_dict1) + chat1 = args_list[eval_func_param_names.index('chat')] model_max_length1 = get_model_max_length(model_state1) context1 = history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, - model_max_length1) + model_max_length1, memory_restriction_level, + kwargs['keep_sources_in_context']) args_list[0] = instruction1 # override original instruction with history from user args_list[2] = context1 + fun1 = partial(evaluate, model_state1, my_db_state1, + *tuple(args_list), **kwargs_evaluate) + + return history, fun1, langchain_mode1, my_db_state1 + + def get_response(fun1, history): + """ + bot that consumes history for user input + instruction (from input_list) itself is not consumed by bot + :return: + """ + if not fun1: + yield history, '' + return try: - for output_fun in fun1(*tuple(args_list)): + for output_fun in fun1(): output = output_fun['response'] extra = output_fun['sources'] # FIXME: can show sources in separate text box etc. # ensure good visually, else markdown ignores multiple \n - bot_message = fix_newlines(output) + bot_message = fix_text_for_gradio(output) history[-1][1] = bot_message yield history, '' except StopIteration: @@ -1088,6 +1248,94 @@ def go_gradio(**kwargs): clear_torch_cache() return + def clear_embeddings(langchain_mode1, my_db): + # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache + if db_type == 'chroma' and langchain_mode1 not in ['ChatLLM', 'LLM', 'Disabled', None, '']: + from gpt_langchain import clear_embedding + db = dbs.get('langchain_mode1') + if db is not None and not isinstance(db, str): + clear_embedding(db) + if langchain_mode1 == LangChainMode.MY_DATA.value and my_db is not None: + clear_embedding(my_db[0]) + + def bot(*args, retry=False): + history, fun1, langchain_mode1, my_db_state1 = prep_bot(*args, retry=retry) + try: + for res in get_response(fun1, history): + yield res + finally: + clear_embeddings(langchain_mode1, my_db_state1) + + def all_bot(*args, retry=False, model_states1=None): + args_list = list(args).copy() + chatbots = args_list[-len(model_states1):] + args_list0 = args_list[:-len(model_states1)] # same for all models + exceptions = [] + stream_output1 = args_list[eval_func_param_names.index('stream_output')] + max_time1 = args_list[eval_func_param_names.index('max_time')] + langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] + my_db_state1 = None # will be filled below by some bot + try: + gen_list = [] + for chatbot1, model_state1 in zip(chatbots, model_states1): + args_list1 = args_list0.copy() + args_list1.insert(-1, model_state1) # insert at -1 so is at -2 + # if at start, have None in response still, replace with '' so client etc. acts like normal + # assumes other parts of code treat '' and None as if no response yet from bot + # can't do this later in bot code as racy with threaded generators + if len(chatbot1) > 0 and len(chatbot1[-1]) == 2 and chatbot1[-1][1] is None: + chatbot1[-1][1] = '' + args_list1.append(chatbot1) + # so consistent with prep_bot() + # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1 + # langchain_mode1 and my_db_state1 should be same for every bot + history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry) + gen1 = get_response(fun1, history) + if stream_output1: + gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False) + # else timeout will truncate output for non-streaming case + gen_list.append(gen1) + + bots_old = chatbots.copy() + exceptions_old = [''] * len(bots_old) + tgen0 = time.time() + for res1 in itertools.zip_longest(*gen_list): + if time.time() - tgen0 > max_time1: + break + + bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in + zip(res1, bots_old)] + bots_old = bots.copy() + + def larger_str(x, y): + return x if len(x) > len(y) else y + + exceptions = [x[1] if x is not None and not isinstance(x, BaseException) else larger_str(str(x), y) + for x, y in zip(res1, exceptions_old)] + exceptions_old = exceptions.copy() + + def choose_exc(x): + # don't expose ports etc. to exceptions window + if is_public: + return "Endpoint unavailable or failed" + else: + return x + + exceptions_str = '\n'.join( + ['Model %s: %s' % (iix, choose_exc(x)) for iix, x in enumerate(exceptions) if + x not in [None, '', 'None']]) + if len(bots) > 1: + yield tuple(bots + [exceptions_str]) + else: + yield bots[0], exceptions_str + if exceptions: + exceptions = [x for x in exceptions if x not in ['', None, 'None']] + if exceptions: + print("Generate exceptions: %s" % exceptions, flush=True) + finally: + clear_torch_cache() + clear_embeddings(langchain_mode1, my_db_state1) + # NORMAL MODEL user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), inputs=inputs_list + [text_output], @@ -1101,13 +1349,17 @@ def go_gradio(**kwargs): inputs=inputs_list + [model_state, my_db_state] + [text_output], outputs=[text_output, exception_text], ) + retry_user_args = dict(fn=functools.partial(user, retry=True), + inputs=inputs_list + [text_output], + outputs=text_output, + ) undo_user_args = dict(fn=functools.partial(user, undo=True), inputs=inputs_list + [text_output], outputs=text_output, ) # MODEL2 - user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True), + user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), inputs=inputs_list2 + [text_output2], outputs=text_output2, ) @@ -1119,89 +1371,174 @@ def go_gradio(**kwargs): inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2], outputs=[text_output2, exception_text], ) + retry_user_args2 = dict(fn=functools.partial(user, retry=True), + inputs=inputs_list2 + [text_output2], + outputs=text_output2, + ) undo_user_args2 = dict(fn=functools.partial(user, undo=True), inputs=inputs_list2 + [text_output2], outputs=text_output2, ) + # MODEL N + all_user_args = dict(fn=functools.partial(all_user, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states), + inputs=inputs_list + [my_db_state] + text_outputs, + outputs=text_outputs + [exception_text], + ) + all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True), + inputs=inputs_list + [my_db_state] + text_outputs, + outputs=text_outputs + [exception_text], + ) + all_retry_user_args = dict(fn=functools.partial(all_user, retry=True, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + all_undo_user_args = dict(fn=functools.partial(all_user, undo=True, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + def clear_instruct(): return gr.Textbox.update(value='') - if kwargs['auto_score']: - score_args_submit = score_args - score_args2_submit = score_args2 - else: - score_args_submit = dict(fn=lambda: None, inputs=None, outputs=None) - score_args2_submit = dict(fn=lambda: None, inputs=None, outputs=None) - def deselect_radio_chats(): return gr.update(value=None) - # in case 2nd model, consume instruction first, so can clear quickly - # bot doesn't consume instruction itself, just history from user, so why works - submit_event1a = instruction.submit(**user_args, queue=queue, - api_name='instruction' if allow_api else None) - submit_event1b = submit_event1a.then(**user_args2, api_name='instruction2' if allow_api else None) - submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) - submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None, - queue=queue) - submit_event1d2 = submit_event1d.then(clear_torch_cache) - submit_event1e = submit_event1d2.then(**score_args_submit, - api_name='instruction_bot_score' if allow_api else None, - queue=queue) - submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, - queue=queue) - submit_event1f2 = submit_event1f.then(clear_torch_cache) - submit_event1g = submit_event1f2.then(**score_args2_submit, - api_name='instruction_bot_score2' if allow_api else None, queue=queue) - submit_event1h = submit_event1g.then(clear_torch_cache) - # if hit enter on new instruction for submitting new query, no longer the saved chat - submit_event1i = submit_event1h.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) - - submit_event2a = submit.click(**user_args, api_name='submit' if allow_api else None) - submit_event2b = submit_event2a.then(**user_args2, api_name='submit2' if allow_api else None) - submit_event2c = submit_event2b.then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) - submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) - submit_event2d2 = submit_event2d.then(clear_torch_cache) - submit_event2e = submit_event2d2.then(**score_args_submit, api_name='submit_bot_score' if allow_api else None, - queue=queue) - submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) - submit_event2f2 = submit_event2f.then(clear_torch_cache) - submit_event2g = submit_event2f2.then(**score_args2_submit, api_name='submit_bot_score2' if allow_api else None, - queue=queue) - submit_event2h = submit_event2g.then(clear_torch_cache) - # if submit new query, no longer the saved chat - submit_event2i = submit_event2h.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) - - submit_event3a = retry.click(**user_args, api_name='retry' if allow_api else None) - submit_event3b = submit_event3a.then(**user_args2, api_name='retry2' if allow_api else None) - submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) - submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, - queue=queue) - submit_event3d2 = submit_event3d.then(clear_torch_cache) - submit_event3e = submit_event3d2.then(**score_args_submit, api_name='retry_bot_score' if allow_api else None, - queue=queue) - submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, - queue=queue) - submit_event3f2 = submit_event3f.then(clear_torch_cache) - submit_event3g = submit_event3f2.then(**score_args2_submit, api_name='retry_bot_score2' if allow_api else None, - queue=queue) - submit_event3h = submit_event3g.then(clear_torch_cache) - # if retry, no longer the saved chat - submit_event3i = submit_event3h.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) - - # if undo, no longer the saved chat - submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \ - .then(**undo_user_args2, api_name='undo2' if allow_api else None) \ - .then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) \ - .then(**score_args_submit, api_name='undo_score' if allow_api else None) \ - .then(**score_args2_submit, api_name='undo_score2' if allow_api else None) \ - .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ - .then(clear_torch_cache) + def clear_all(): + return gr.Textbox.update(value=''), gr.Textbox.update(value=''), gr.update(value=None), \ + gr.Textbox.update(value=''), gr.Textbox.update(value='') + + if kwargs['model_states']: + submits1 = submits2 = submits3 = [] + submits4 = [] + + fun_source = [instruction.submit, submit.click, retry_btn.click] + fun_name = ['instruction', 'submit', 'retry'] + user_args = [all_user_args, all_user_args, all_retry_user_args] + bot_args = [all_bot_args, all_bot_args, all_retry_bot_args] + for userargs1, botarg1, funn1, funs1 in zip(user_args, bot_args, fun_name, fun_source): + submit_event11 = funs1(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) + submit_event1a = submit_event11.then(**userargs1, queue=queue, + api_name='%s' % funn1 if allow_api else None) + # if hit enter on new instruction for submitting new query, no longer the saved chat + submit_event1b = submit_event1a.then(clear_all, inputs=None, + outputs=[instruction, iinput, radio_chats, score_text, + score_text2], + queue=queue) + submit_event1c = submit_event1b.then(**botarg1, + api_name='%s_bot' % funn1 if allow_api else None, + queue=queue) + submit_event1d = submit_event1c.then(**all_score_args, + api_name='%s_bot_score' % funn1 if allow_api else None, + queue=queue) + + submits1.extend([submit_event1a, submit_event1b, submit_event1c, submit_event1d]) + + # if undo, no longer the saved chat + submit_event4 = undo.click(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) \ + .then(**all_undo_user_args, api_name='undo' if allow_api else None) \ + .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, + score_text2], queue=queue) \ + .then(**all_score_args, api_name='undo_score' if allow_api else None) + submits4 = [submit_event4] + + else: + # in case 2nd model, consume instruction first, so can clear quickly + # bot doesn't consume instruction itself, just history from user, so why works + submit_event11 = instruction.submit(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) + submit_event1a = submit_event11.then(**user_args, queue=queue, + api_name='instruction' if allow_api else None) + # if hit enter on new instruction for submitting new query, no longer the saved chat + submit_event1a2 = submit_event1a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event1b = submit_event1a2.then(**user_args2, api_name='instruction2' if allow_api else None) + submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \ + .then(clear_instruct, None, iinput) + submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None, + queue=queue) + submit_event1e = submit_event1d.then(**score_args, + api_name='instruction_bot_score' if allow_api else None, + queue=queue) + submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, + queue=queue) + submit_event1g = submit_event1f.then(**score_args2, + api_name='instruction_bot_score2' if allow_api else None, queue=queue) + + submits1 = [submit_event1a, submit_event1a2, submit_event1b, submit_event1c, submit_event1d, + submit_event1e, + submit_event1f, submit_event1g] + + submit_event21 = submit.click(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) + submit_event2a = submit_event21.then(**user_args, api_name='submit' if allow_api else None) + # if submit new query, no longer the saved chat + submit_event2a2 = submit_event2a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event2b = submit_event2a2.then(**user_args2, api_name='submit2' if allow_api else None) + submit_event2c = submit_event2b.then(clear_all, inputs=None, + outputs=[instruction, iinput, radio_chats, score_text, score_text2], + queue=queue) + submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) + submit_event2e = submit_event2d.then(**score_args, + api_name='submit_bot_score' if allow_api else None, + queue=queue) + submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, + queue=queue) + submit_event2g = submit_event2f.then(**score_args2, + api_name='submit_bot_score2' if allow_api else None, + queue=queue) + + submits2 = [submit_event2a, submit_event2a2, submit_event2b, submit_event2c, submit_event2d, + submit_event2e, + submit_event2f, submit_event2g] + + submit_event31 = retry_btn.click(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) + submit_event3a = submit_event31.then(**user_args, api_name='retry' if allow_api else None) + # if retry, no longer the saved chat + submit_event3a2 = submit_event3a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event3b = submit_event3a2.then(**user_args2, api_name='retry2' if allow_api else None) + submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \ + .then(clear_instruct, None, iinput) + submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, + queue=queue) + submit_event3e = submit_event3d.then(**score_args, + api_name='retry_bot_score' if allow_api else None, + queue=queue) + submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, + queue=queue) + submit_event3g = submit_event3f.then(**score_args2, + api_name='retry_bot_score2' if allow_api else None, + queue=queue) + + submits3 = [submit_event3a, submit_event3a2, submit_event3b, submit_event3c, submit_event3d, + submit_event3e, + submit_event3f, submit_event3g] + + # if undo, no longer the saved chat + submit_event4 = undo.click(fn=dummy_fun, + inputs=instruction, outputs=instruction, queue=queue) \ + .then(**undo_user_args, api_name='undo' if allow_api else None) \ + .then(**undo_user_args2, api_name='undo2' if allow_api else None) \ + .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, + score_text2], queue=queue) \ + .then(**score_args, api_name='undo_score' if allow_api else None) \ + .then(**score_args2, api_name='undo_score2' if allow_api else None) + submits4 = [submit_event4] # MANAGE CHATS def dedup(short_chat, short_chats): @@ -1227,47 +1564,80 @@ def go_gradio(**kwargs): #

    etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation is_same = True # length of conversation has to be same + if len(x) != len(y): + return False if len(x) != len(y): return False for stepx, stepy in zip(x, y): if len(stepx) != len(stepy): # something off with a conversation return False - if len(stepx) != 2: - # something off - return False - if len(stepy) != 2: - # something off - return False - questionx = stepx[0].replace('

    ', '').replace('

    ', '') if stepx[0] is not None else None - answerx = stepx[1].replace('

    ', '').replace('

    ', '') if stepx[1] is not None else None - - questiony = stepy[0].replace('

    ', '').replace('

    ', '') if stepy[0] is not None else None - answery = stepy[1].replace('

    ', '').replace('

    ', '') if stepy[1] is not None else None - - if questionx != questiony or answerx != answery: - return False + for stepxx, stepyy in zip(stepx, stepy): + if len(stepxx) != len(stepyy): + # something off with a conversation + return False + if len(stepxx) != 2: + # something off + return False + if len(stepyy) != 2: + # something off + return False + questionx = stepxx[0].replace('

    ', '').replace('

    ', '') if stepxx[0] is not None else None + answerx = stepxx[1].replace('

    ', '').replace('

    ', '') if stepxx[1] is not None else None + + questiony = stepyy[0].replace('

    ', '').replace('

    ', '') if stepyy[0] is not None else None + answery = stepyy[1].replace('

    ', '').replace('

    ', '') if stepyy[1] is not None else None + + if questionx != questiony or answerx != answery: + return False return is_same - def save_chat(chat1, chat2, chat_state1): + def save_chat(*args): + args_list = list(args) + chat_list = args_list[:-1] # list of chatbot histories + # remove None histories + chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] + chat_state1 = args_list[ + -1] # dict with keys of short chat names, values of list of list of chatbot histories short_chats = list(chat_state1.keys()) - for chati in [chat1, chat2]: - if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None: - short_chat = get_short_chat(chati, short_chats) - if short_chat: - already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()]) - if not already_exists: - chat_state1[short_chat] = chati - return chat_state1 + if len(chat_list_not_none) > 0: + # make short_chat key from only first history, based upon question that is same anyways + chat_first = chat_list_not_none[0] + short_chat = get_short_chat(chat_first, short_chats) + if short_chat: + old_chat_lists = list(chat_state1.values()) + already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) + if not already_exists: + chat_state1[short_chat] = chat_list.copy() + # clear chat_list so saved and then new conversation starts + chat_list = [[]] * len(chat_list) + ret_list = chat_list + [chat_state1] + return tuple(ret_list) def update_radio_chats(chat_state1): return gr.update(choices=list(chat_state1.keys()), value=None) - def switch_chat(chat_key, chat_state1): + def switch_chat(chat_key, chat_state1, num_model_lock=0): chosen_chat = chat_state1[chat_key] - return chosen_chat, chosen_chat - - radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2]) + # deal with possible different size of chat list vs. current list + ret_chat = [None] * (2 + num_model_lock) + for chati in range(0, 2 + num_model_lock): + ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] + return tuple(ret_chat) + + def clear_texts(*args): + return tuple([gr.Textbox.update(value='')] * len(args)) + + def clear_scores(): + return gr.Textbox.update(value=res_value), \ + gr.Textbox.update(value='Response Score: NA'), \ + gr.Textbox.update(value='Response Score: NA') + + switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) + radio_chats.input(switch_chat_fun, + inputs=[radio_chats, chat_state], + outputs=[text_output, text_output2] + text_outputs) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) def remove_chat(chat_key, chat_state1): chat_state1.pop(chat_key, None) @@ -1304,9 +1674,11 @@ def go_gradio(**kwargs): new_chats = json.loads(f.read()) for chat1_k, chat1_v in new_chats.items(): # ignore chat1_k, regenerate and de-dup to avoid loss - chat_state1 = save_chat(chat1_v, None, chat_state1) + _, chat_state1 = save_chat(chat1_v, chat_state1) except BaseException as e: - print("Add chats exception: %s" % str(e), flush=True) + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + print("Add chats exception: %s" % str(ex), flush=True) return chat_state1, add_btn # note for update_user_db_func output is ignored for db @@ -1317,17 +1689,22 @@ def go_gradio(**kwargs): .then(clear_file_list, outputs=chatsup_output, queue=False) \ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False) - clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \ - .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \ - .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) + clear_chat_btn.click(fn=clear_texts, + inputs=[text_output, text_output2] + text_outputs, + outputs=[text_output, text_output2] + text_outputs, + queue=False, api_name='clear' if allow_api else None) \ + .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) # does both models - clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state, + clear.click(save_chat, + inputs=[text_output, text_output2] + text_outputs + [chat_state], + outputs=[text_output, text_output2] + text_outputs + [chat_state], api_name='save_chat' if allow_api else None) \ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, api_name='update_chats' if allow_api else None) \ - .then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \ - .then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None) + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + # NOTE: clear of instruction/iinput for nochat has to come after score, # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() no_chat_args = dict(fn=fun, @@ -1356,28 +1733,29 @@ def go_gradio(**kwargs): api_name='submit_nochat_api' if allow_api else None) \ .then(clear_torch_cache) - def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id): + def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit, + infer_devices, gpu_id): # ensure old model removed from GPU memory if kwargs['debug']: print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) - model0 = model_state0[0] - if isinstance(model_state_old[0], str) and model0 is not None: + model0 = model_state0['model'] + if isinstance(model_state_old['model'], str) and model0 is not None: # best can do, move model loaded at first to CPU model0.cpu() - if model_state_old[0] is not None and not isinstance(model_state_old[0], str): + if model_state_old['model'] is not None and not isinstance(model_state_old['model'], str): try: - model_state_old[0].cpu() + model_state_old['model'].cpu() except Exception as e: # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! print("Unable to put model on CPU: %s" % str(e), flush=True) - del model_state_old[0] - model_state_old[0] = None + del model_state_old['model'] + model_state_old['model'] = None - if model_state_old[1] is not None and not isinstance(model_state_old[1], str): - del model_state_old[1] - model_state_old[1] = None + if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): + del model_state_old['tokenizer'] + model_state_old['tokenizer'] = None clear_torch_cache() if kwargs['debug']: @@ -1387,7 +1765,11 @@ def go_gradio(**kwargs): # no-op if no model, just free memory # no detranscribe needed for model, never go into evaluate lora_weights = no_lora_str - return [None, None, None, model_name], model_name, lora_weights, prompt_type_old + server_name = no_server_str + return [None, None, None, model_name, server_name], \ + model_name, lora_weights, server_name, prompt_type_old, \ + gr.Slider.update(maximum=256), \ + gr.Slider.update(maximum=256) # don't deepcopy, can contain model itself all_kwargs1 = all_kwargs.copy() @@ -1404,34 +1786,50 @@ def go_gradio(**kwargs): # detranscribe if lora_weights == no_lora_str: lora_weights = '' - all_kwargs1['lora_weights'] = lora_weights.strip() + if server_name == no_server_str: + server_name = '' + all_kwargs1['inference_server'] = server_name.strip() + model1, tokenizer1, device1 = get_model(reward_type=False, **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs1)) clear_torch_cache() - model_state_new = [model1, tokenizer1, device1, model_name] + tokenizer_base_model = model_name + prompt_dict1, error0 = get_prompt(prompt_type1, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True) + model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, + base_model=model_name, tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, inference_server=server_name, + prompt_type=prompt_type1, prompt_dict=prompt_dict1, + ) max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) if kwargs['debug']: print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) - return model_state_new, model_name, lora_weights, prompt_type1, \ + return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ gr.Slider.update(maximum=max_max_new_tokens1), \ gr.Slider.update(maximum=max_max_new_tokens1) - def get_prompt_str(prompt_type1, prompt_dict1): + def get_prompt_str(prompt_type1, prompt_dict1, which=0): + if prompt_type1 in ['', None]: + print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) + return str({}) prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', - reduced=False, return_dict=True) + reduced=False, making_context=False, return_dict=True) if prompt_dict_error: return str(prompt_dict_error) else: # return so user can manipulate if want and use as custom return str(prompt_dict1) - prompt_type.change(fn=get_prompt_str, inputs=[prompt_type, prompt_dict], outputs=prompt_dict) - prompt_type2.change(fn=get_prompt_str, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2) + get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) + get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) + prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict) + prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2) def dropdown_prompt_type_list(x): return gr.Dropdown.update(value=x) @@ -1440,9 +1838,9 @@ def go_gradio(**kwargs): return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') load_model_args = dict(fn=load_model, - inputs=[model_choice, lora_choice, model_state, prompt_type, + inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu], - outputs=[model_state, model_used, lora_used, + outputs=[model_state, model_used, lora_used, server_used, # if prompt_type changes, prompt_dict will change via change rule prompt_type, max_new_tokens, min_new_tokens, ]) @@ -1457,9 +1855,9 @@ def go_gradio(**kwargs): .then(clear_torch_cache) load_model_args2 = dict(fn=load_model, - inputs=[model_choice2, lora_choice2, model_state2, prompt_type2, + inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2], - outputs=[model_state2, model_used2, lora_used2, + outputs=[model_state2, model_used2, lora_used2, server_used2, # if prompt_type2 changes, prompt_dict2 will change via change rule prompt_type2, max_new_tokens2, min_new_tokens2 ]) @@ -1472,32 +1870,51 @@ def go_gradio(**kwargs): .then(**chatbot_update_args2) \ .then(clear_torch_cache) - def dropdown_model_list(list0, x): - new_state = [list0[0] + [x]] - new_options = [*new_state[0]] - return gr.Dropdown.update(value=x, choices=new_options), \ - gr.Dropdown.update(value=x, choices=new_options), \ - '', new_state - - add_model_event = add_model_button.click(fn=dropdown_model_list, - inputs=[model_options_state, new_model], - outputs=[model_choice, model_choice2, new_model, model_options_state], - queue=False) - - def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2): - new_state = [list0[0] + [x]] - new_options = [*new_state[0]] + def dropdown_model_lora_server_list(model_list0, model_x, + lora_list0, lora_x, + server_list0, server_x, + model_used1, lora_used1, server_used1, + model_used2, lora_used2, server_used2, + ): + model_new_state = [model_list0[0] + [model_x]] + model_new_options = [*model_new_state[0]] + x1 = model_x if model_used1 == no_model_str else model_used1 + x2 = model_x if model_used2 == no_model_str else model_used2 + ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), + gr.Dropdown.update(value=x2, choices=model_new_options), + '', model_new_state] + + lora_new_state = [lora_list0[0] + [lora_x]] + lora_new_options = [*lora_new_state[0]] # don't switch drop-down to added lora if already have model loaded - x1 = x if model_used1 == no_model_str else lora_used1 - x2 = x if model_used2 == no_model_str else lora_used2 - return gr.Dropdown.update(value=x1, choices=new_options), \ - gr.Dropdown.update(value=x2, choices=new_options), \ - '', new_state - - add_lora_event = add_lora_button.click(fn=dropdown_lora_list, - inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, - lora_used2], - outputs=[lora_choice, lora_choice2, new_lora, lora_options_state], + x1 = lora_x if model_used1 == no_model_str else lora_used1 + x2 = lora_x if model_used2 == no_model_str else lora_used2 + ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), + gr.Dropdown.update(value=x2, choices=lora_new_options), + '', lora_new_state] + + server_new_state = [server_list0[0] + [server_x]] + server_new_options = [*server_new_state[0]] + # don't switch drop-down to added server if already have model loaded + x1 = server_x if model_used1 == no_model_str else server_used1 + x2 = server_x if model_used2 == no_model_str else server_used2 + ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), + gr.Dropdown.update(value=x2, choices=server_new_options), + '', server_new_state] + + return tuple(ret1 + ret2 + ret3) + + add_model_lora_server_event = \ + add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, + inputs=[model_options_state, new_model] + + [lora_options_state, new_lora] + + [server_options_state, new_server] + + [model_used, lora_used, server_used] + + [model_used2, lora_used2, server_used2], + outputs=[model_choice, model_choice2, new_model, model_options_state] + + [lora_choice, lora_choice2, new_lora, lora_options_state] + + [server_choice, server_choice2, new_server, + server_options_state], queue=False) go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, queue=False) \ @@ -1526,8 +1943,9 @@ def go_gradio(**kwargs): # FIXME: add score_res2 in condition, but do better # callback for logging flagged input/output - callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points") - flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None, + callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") + flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, + None, preprocess=False, api_name='flag' if allow_api else None, queue=False) flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, @@ -1535,27 +1953,64 @@ def go_gradio(**kwargs): api_name='flag_nochat' if allow_api else None, queue=False) def get_system_info(): + if is_public: + time.sleep(10) # delay to avoid spam since queue=False return gr.Textbox.update(value=system_info_print()) system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info' if allow_api else None, queue=False) + def get_system_info_dict(system_input1, **kwargs1): + if system_input1 != os.getenv("ADMIN_PASS", ""): + return json.dumps({}) + exclude_list = ['admin_pass', 'examples'] + sys_dict = {k: v for k, v in kwargs1.items() if + isinstance(v, (str, int, bool, float)) and k not in exclude_list} + try: + sys_dict.update(system_info()) + except Exception as e: + # protection + print("Exception: %s" % str(e), flush=True) + return json.dumps(sys_dict) + + get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) + + system_dict_event = system_btn2.click(get_system_info_dict_func, + inputs=system_input, + outputs=system_text2, + api_name='system_info_dict' if allow_api else None, + queue=False, # queue to avoid spam + ) + + def get_hash(): + return kwargs['git_hash'] + + system_btn3.click(get_hash, + outputs=system_text3, + api_name='system_hash' if allow_api else None, + queue=False, + ) + # don't pass text_output, don't want to clear output, just stop it # cancel only stops outer generation, not inner generation or non-generation stop_btn.click(lambda: None, None, None, - cancels=[submit_event1d, submit_event1f, - submit_event2d, submit_event2f, - submit_event3d, submit_event3f, - submit_event_nochat, - submit_event_nochat2, - ], + cancels=submits1 + submits2 + submits3 + + submits4 + + [submit_event_nochat, submit_event_nochat2] + + [eventdb1, eventdb2, eventdb3, + eventdb4, eventdb5, eventdb6] + + [eventdb7, eventdb8, eventdb9] + , queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) - def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1): - if model_state1 and not isinstance(model_state1[1], str): - tokenizer = model_state1[1] - elif model_state0 and not isinstance(model_state0[1], str): - tokenizer = model_state0[1] + def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, + memory_restriction_level1=0, + keep_sources_in_context1=False, + ): + if model_state1 and not isinstance(model_state1['tokenizer'], str): + tokenizer = model_state1['tokenizer'] + elif model_state0 and not isinstance(model_state0['tokenizer'], str): + tokenizer = model_state0['tokenizer'] else: tokenizer = None if tokenizer is not None: @@ -1565,18 +2020,26 @@ def go_gradio(**kwargs): chat1 = chat1 + [['user_message1', None]] model_max_length1 = tokenizer.model_max_length context1 = history_to_context(chat1, langchain_mode1, prompt_type1, prompt_dict1, chat1, - model_max_length1) + model_max_length1, + memory_restriction_level1, keep_sources_in_context1) return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1]) else: return "N/A" - count_chat_tokens_btn.click(fn=count_chat_tokens, inputs=[model_state, text_output, prompt_type, prompt_dict], + count_chat_tokens_func = functools.partial(count_chat_tokens, + memory_restriction_level1=memory_restriction_level, + keep_sources_in_context1=kwargs['keep_sources_in_context']) + count_chat_tokens_btn.click(fn=count_chat_tokens, + inputs=[model_state, text_output, prompt_type, prompt_dict], outputs=chat_token_count, api_name='count_tokens' if allow_api else None) - demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None) + demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) favicon_path = "h2o-logo.svg" + if not os.path.isfile(favicon_path): + print("favicon_path=%s not found" % favicon_path, flush=True) + favicon_path = None scheduler = BackgroundScheduler() scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) @@ -1585,6 +2048,7 @@ def go_gradio(**kwargs): # FIXME: disable for gptj, langchain or gpt4all modify print itself # FIXME: and any multi-threaded/async print will enter model output! scheduler.add_job(func=ping, trigger="interval", seconds=60) + scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) scheduler.start() # import control @@ -1648,18 +2112,22 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None): " Ask jon.mckinney@h2o.ai for file if required." source_list = [] elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None: - db_get = db1[0].get() - source_list = sorted(set([x['source'] for x in db_get['metadatas']])) + from gpt_langchain import get_metadatas + metadatas = get_metadatas(db1[0]) + source_list = sorted(set([x['source'] for x in metadatas])) source_files_added = '\n'.join(source_list) elif langchain_mode in dbs and dbs[langchain_mode] is not None: + from gpt_langchain import get_metadatas db1 = dbs[langchain_mode] - db_get = db1.get() - source_list = sorted(set([x['source'] for x in db_get['metadatas']])) + metadatas = get_metadatas(db1) + source_list = sorted(set([x['source'] for x in metadatas])) source_files_added = '\n'.join(source_list) else: source_list = [] source_files_added = "None" - sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4())) + sources_dir = "sources_dir" + makedirs(sources_dir) + sources_file = os.path.join(sources_dir, 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))) with open(sources_file, "wt") as f: f.write(source_files_added) source_list = docs_state0 + source_list @@ -1686,14 +2154,15 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', """.format(ex_str) if langchain_mode == 'MyData': - return db1, x, y, source_files_added + return None, langchain_mode, db1, x, y, source_files_added else: - return x, y, source_files_added + return None, langchain_mode, x, y, source_files_added finally: clear_torch_cache() def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData', + user_path=None, use_openai_embedding=None, hf_embedding_model=None, caption_loader=None, @@ -1710,8 +2179,10 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, assert enable_ocr is not None assert verbose is not None + if dbs is None: + dbs = {} assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs)) - assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type + # assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type from gpt_langchain import add_to_db, get_db, path_to_docs # handle case of list of temp buffer if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'): @@ -1719,6 +2190,23 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, # handle single file of temp buffer if hasattr(file, 'name'): file = file.name + if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str): + file = [file] + + if langchain_mode == 'UserData' and user_path is not None: + # move temp files from gradio upload to stable location + for fili, fil in enumerate(file): + if isinstance(fil, str): + if fil.startswith('/tmp/gradio/'): + new_fil = os.path.join(user_path, os.path.basename(fil)) + if os.path.isfile(new_fil): + remove(new_fil) + try: + shutil.move(fil, new_fil) + except FileExistsError: + pass + file[fili] = new_fil + if verbose: print("Adding %s" % file, flush=True) sources = path_to_docs(file if not is_url and not is_txt else None, @@ -1742,7 +2230,10 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, use_openai_embedding=use_openai_embedding, hf_embedding_model=hf_embedding_model) else: - assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1 + # in testing expect: + # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1 + # for production hit, when user gets clicky: + assert len(db1) == 2, "Bad MyData db: %s" % db1 # then create # assign fresh hash for this user session, so not shared # if added has to original state and didn't change, then would be shared db for all users @@ -1758,7 +2249,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, else: db1[0] = db source_files_added = get_source_files(db=db1[0], exceptions=exceptions) - return db1, x, y, source_files_added + return None, langchain_mode, db1, x, y, source_files_added else: from gpt_langchain import get_persist_directory persist_directory = get_persist_directory(langchain_mode) @@ -1779,7 +2270,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, # return dbs[langchain_mode], x, y # db in this code path is updated in place source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions) - return x, y, source_files_added + return None, langchain_mode, x, y, source_files_added def get_db(db1, langchain_mode, dbs=None): @@ -1789,7 +2280,7 @@ def get_db(db1, langchain_mode, dbs=None): db = None elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None: db = db1[0] - elif langchain_mode in dbs and dbs[langchain_mode] is not None: + elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None: db = dbs[langchain_mode] else: db = None @@ -1808,12 +2299,16 @@ def get_source_files(db=None, exceptions=None, metadatas=None): exceptions = [] # only should be one source, not confused - assert db is not None or metadatas is not None + # assert db is not None or metadatas is not None + # clicky user + if db is None and metadatas is None: + return "No Sources at all" if metadatas is None: source_label = "Sources:" if db is not None: - metadatas = db.get()['metadatas'] + from gpt_langchain import get_metadatas + metadatas = get_metadatas(db) else: metadatas = [] adding_new = False