# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import gradio as gr
import subprocess
from ui.configuration import Configuration
import math
import uuid
import sys
import os
import kaizen_theme as kaizen
import re
import threading
import webbrowser
import socket
import random
class MainInterface:
    """Gradio-based UI for Chat with RTX.

    Renders the model picker, dataset picker, sample questions and chat
    window, wires all event chains, and manages session validation plus
    application shutdown.
    """

    # Keys used with Configuration.get_config()/get_config_from_file().
    _dataset_path_key = 'dataset'
    _models_path_key = 'models'
    # Helper text shown under the "Dataset" heading.
    _dataset_directory_label_text = ".txt, .pdf, .doc files supported"
    # Application callbacks registered through the on_*() setters below.
    _dataset_path_updated_callback = None
    _dataset_source_updated_callback = None
    _shutdown_callback = None
    _reset_chat_callback = None
    _undo_last_chat_callback = None
    _model_change_callback = None
    _regenerate_index_callback = None
    # Chat backend callable; invoked per query in _handle_chatbot_events.
    _query_handler = None
    _state = None
    _interface = None
    _streaming = False
    # NOTE(review): class-level mutable default, shared across instances.
    # Safe only while a single MainInterface is ever constructed — confirm.
    _models_list = {}
def _get_enable_disable_elemet_list(self):
ret_val = [
self._chat_query_input_textbox,
self._chat_bot_window,
self._chat_submit_button,
self._chat_retry_button,
self._chat_reset_button,
self._models_dropdown,
self._dataset_source_dropdown,
self._dataset_update_source_edit_button,
self._dataset_regenerate_index_button
]
if self._chat_undo_button is not None:
ret_val.append(self._chat_undo_button)
return ret_val
def _before_change_element_state(self, request: gr.Request):
self._validate_session(request)
ret_val = [
gr.Textbox("", interactive=False),
gr.Chatbot(),
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Dropdown(interactive=False),
gr.Dropdown(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False)
]
if self._chat_undo_button is not None:
ret_val.append(gr.Button(interactive=False))
return ret_val
def _after_change_element_state(self, request: gr.Request):
self._validate_session(request)
ret_val = [
gr.Textbox(interactive=True),
[],
gr.Button(interactive=True),
gr.Button(interactive=True),
gr.Button(interactive=True),
gr.Dropdown(interactive=True),
gr.Dropdown(interactive=True),
gr.Button(interactive=True),
gr.Button(interactive=True)
]
if self._chat_undo_button is not None:
ret_val.append(gr.Button(interactive=True))
return ret_val
def __init__(self, chatbot=None, streaming = False) -> None:
self._interface = None
self._query_handler = chatbot
self._streaming = streaming
self.config = Configuration()
self._dataset_path = self._get_dataset_path()
self._default_dataset_path = self._get_default_dataset_path()
pass
def _get_dataset_path(self):
_dataset_path = ""
dataset_config = self.config.get_config(self._dataset_path_key) or {}
if 'path' in dataset_config:
_dataset_path = dataset_config['path']
if 'isRelative' in dataset_config and dataset_config['isRelative'] is True:
_dataset_path = os.path.join(os.getcwd(), _dataset_path)
return _dataset_path
def _get_default_dataset_path(self):
_dataset_path = ""
dataset_config = self.config.get_config_from_file(self._dataset_path_key, "config/config.json") or {}
if 'path' in dataset_config:
_dataset_path = dataset_config['path']
if 'isRelative' in dataset_config and dataset_config['isRelative'] is True:
_dataset_path = os.path.join(os.getcwd(), _dataset_path)
return _dataset_path
    def on_dataset_path_updated(self, callback):
        """Register callback(source, path, None, session_id) for dataset path changes."""
        self._dataset_path_updated_callback = callback
    def on_dataset_source_updated(self, callback):
        """Register callback(source, path, session_id) for dataset source changes."""
        self._dataset_source_updated_callback = callback
    def on_shutdown(self, callback):
        """Register callback(session_id) invoked before process exit."""
        self._shutdown_callback = callback
    def on_reset_chat(self, callback):
        """Register callback(session_id) invoked when the chat is reset."""
        self._reset_chat_callback = callback
    def on_undo_last_chat(self, callback):
        """Register callback(history, session_id) invoked after undoing the last turn."""
        self._undo_last_chat_callback = callback
    def on_model_change(self, callback):
        """Register callback(model_name, model_metadata, session_id) for model switches."""
        self._model_change_callback = callback
    def on_regenerate_index(self, callback):
        """Register callback(source, path, session_id) to rebuild the vector index."""
        self._regenerate_index_callback = callback
def _get_theme(self):
primary_hue = gr.themes.Color("#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900", "#76B900")
neutral_hue = gr.themes.Color("#292929", "#292929", "#292929", "#292929", "#292929", "#292929", "#292929", "#292929", "#292929", "#292929", "#292929")
return gr.Theme(
primary_hue=primary_hue,
neutral_hue=neutral_hue
).set(
body_background_fill="#191919",
body_background_fill_dark="#191919",
block_background_fill="#292929",
block_background_fill_dark="#292929",
block_label_background_fill="#292929",
block_label_background_fill_dark="#292929",
border_color_primary="#191919",#components background
border_color_primary_dark="#191919",
background_fill_primary="#292929",#dropdown
background_fill_primary_dark="#292929",
background_fill_secondary="#393939",#response chatbot bubble
background_fill_secondary_dark="#393939",
color_accent_soft="#393939",#request chatbot bubble
color_accent_soft_dark="#393939",
#text colors
block_label_text_color="#FFFFFF",
block_label_text_color_dark="#FFFFFF",
body_text_color="#FFFFFF",
body_text_color_dark="#FFFFFF",
body_text_color_subdued="#FFFFFF",
body_text_color_subdued_dark="#FFFFFF",
button_secondary_text_color="#FFFFFF",
button_secondary_text_color_dark="#FFFFFF",
button_primary_text_color="#FFFFFF",
button_primary_text_color_dark="#FFFFFF",
input_placeholder_color="#FFFFFF",#placeholder text color
input_placeholder_color_dark="#FFFFFF",
)
def get_css(self):
return kaizen.css() + open(os.path.join(os.path.dirname(__file__), 'www/app.css')).read()
def render(self):
with gr.Blocks(
title="Chat with RTX",
analytics_enabled=False,
theme=kaizen.theme(),
css=self.get_css(),
js=os.path.join(os.path.dirname(__file__), 'www/app.js')
) as interface:
self._interface = interface
self._state = gr.State({})
(
self._shutdown_button,
self._shutdown_post_shutdown_group,
self._shutdown_memory_released_markdown,
self._shutdown_invalid_session_markdown
) = self._render_logo_shut_down()
with gr.Row():
self._models_dropdown, self._models_group = self._render_models()
(
self._dataset_source_textbox,
self._dataset_update_source_edit_button,
self._dataset_source_dropdown,
self._dataset_regenerate_index_button,
self._dataset_label_markdown,
self._dataset_group
) = self._render_dataset_picker()
(
self._sample_question_components,
self._sample_question_rows,
self._sample_question_empty_space_component,
self._sample_qustion_default_dataset_markdown
) = self._render_sample_question()
(
self._chat_bot_window,
self._chat_query_input_textbox,
self._chat_submit_button,
self._chat_retry_button,
self._chat_undo_button,
self._chat_reset_button,
self._chat_query_group,
self._chat_disclaimer_markdown
) = self._render_chatbot(show_chatbot=len(self._sample_question_components) == 0)
self._handle_events()
interface.queue()
port = self._get_free_port()
self._open_app(port)
interface.launch(
favicon_path=os.path.join(os.path.dirname(__file__), 'assets/nvidia_logo.png'),
show_api=False,
server_port=port,share=True
)
def _get_free_port(self):
# Create a socket object
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Set a short timeout for the connection attempt
sock.settimeout(1)
port = None
while port is None:
port = 1048
try:
# Attempt to bind to the port
sock.bind((socket.gethostbyname(socket.gethostname()), port))
print(socket.gethostbyname(socket.gethostname()), port)
except OSError as e:
port = None
if e.errno != 98: # errno 98: Address already in use
print('OS error', e)
break
sock.close()
return port
def _open_app(self, port):
def launch_thread(cookie):
launch_url = f'http://127.0.0.1:{port}?cookie={cookie}&__theme=dark'
print(f'Open {launch_url} in browser to start Chat with RTX')
webbrowser.open(launch_url)
return None
self._secure_cookie = str(uuid.uuid4())
threading.Thread(target=launch_thread, args=(self._secure_cookie,)).start()
return None
    def _validate_request(self, request: gr.Request):
        """Validate an incoming request's session.

        NOTE(review): currently a stub that accepts every request — the
        secure cookie generated in _open_app is never checked here. Confirm
        whether validation was intentionally disabled.
        """
        return True
def _get_session_id(self, state):
if isinstance(state, object):
if not 'session-id' in state:
state['session-id'] = str(uuid.uuid4())
return state['session-id']
return None
def _render_models(self):
models = self.config.get_config(self._models_path_key)
model_names = []
for model in models['supported']:
if model['installed'] is True:
model_names.append(model['name'])
self._models_list[model['name']] = model
with gr.Column():
with gr.Group(elem_classes="padding-8p model-goup") as models_group:
gr.Markdown("AI model")
gr.Markdown(
'Select AI model',
elem_classes="description-secondary-markdown"
)
def get_selected_value():
selected = self.config.get_config('models/selected')
if len(self._models_list) > 0:
if not selected in self._models_list:
selected = self.config.get_config_from_file('models/selected', "config/config.json")
if not selected in self._models_list:
selected = list(self._models_list.keys())[0]
return selected
return gr.Dropdown(
model_names,
elem_classes="height-40p",
value=get_selected_value,
container=False,
filterable=False
), models_group
    def _render_logo_shut_down(self):
        """Render the header row (logo, title, shutdown button) and the
        initially hidden post-shutdown notice group.

        Returns (shutdown_button, post_shutdown_group,
        memory_released_markdown, invalid_session_markdown).
        """
        with gr.Row():
            gr.Image(os.path.join(os.path.dirname(__file__), "assets/nvidia_logo.png"),
                interactive=False,
                show_label=False,
                show_download_button=False,
                width=40,
                scale=0,
                container=False,
                min_width=40
            )
            gr.HTML("""
                Chat with RTX
            """)
            shutdown_button = gr.Button(
                "",
                scale=0,
                icon=os.path.join(os.path.dirname(__file__), 'assets/shutdown.png'),
                elem_classes="icon-button tooltip-component",
                elem_id="shutdown-btn"
            )
        # Shown only after shutdown or on an invalid session; the empty
        # gr.HTML("") elements center the info icon in the row.
        with gr.Group(visible=False, elem_classes="shutdown-group") as post_shutdown_group:
            with gr.Row():
                gr.HTML("")
                gr.Image(os.path.join(os.path.dirname(__file__), "assets/info.png"),
                    interactive=False,
                    show_label=False,
                    show_download_button=False,
                    width=40,
                    scale=0,
                    container=False,
                    min_width=40
                )
                gr.HTML("")
            with gr.Row():
                # Event handlers toggle which of the two messages is visible.
                shutdown_memory_released_markdown = gr.Markdown(
                    "Video memory released. Reopen RTX Chat from desktop to continue chatting.",
                    elem_classes="text-align-center"
                )
                shutdown_invalid_session_markdown = gr.Markdown(
                    "Invalid session. Reopen RTX Chat from desktop to continue chatting.",
                    elem_classes="text-align-center"
                )
        return shutdown_button, post_shutdown_group, shutdown_memory_released_markdown, shutdown_invalid_session_markdown
    def _render_dataset_picker(self):
        """Render the dataset group: source dropdown, path textbox, and the
        edit-path / regenerate-index buttons.

        Returns (source_textbox, edit_button, source_dropdown,
        regenerate_button, label_markdown, group).
        """
        sources = self.config.get_config("dataset/sources")
        self._dataset_selected_source = self.config.get_config("dataset/selected")
        with gr.Column(elem_classes="dataset-column"):
            with gr.Group(elem_classes="padding-8p dataset-goup") as dataset_group:
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("Dataset")
                        dataset_label_markdown = gr.Markdown(
                            self._dataset_directory_label_text,
                            elem_classes="description-secondary-markdown",
                            elem_id="dataset-description-label"
                        )
                    regenerate_vector_button = gr.Button(
                        "",
                        icon=os.path.join(os.path.dirname(__file__), 'assets/regenerate.png'),
                        elem_classes="icon-button tooltip-component",
                        elem_id="dataset-regenerate-index-btn",
                        scale=0
                    )
                # Dropdown shows display strings; handlers map them back to
                # source keys via config.get_display_string_keys().
                dataset_source_dropdown = gr.Dropdown(
                    self.config.get_display_strings(sources),
                    value=lambda: self.config.get_display_strings(self._dataset_selected_source),
                    show_label=False,
                    container=False,
                    filterable=False,
                    elem_classes="margin-bottom-8p height-40p"
                )
                with gr.Row():
                    dataset_source_textbox = gr.Textbox(
                        lambda: self._dataset_path,
                        scale=9,
                        container=False,
                        elem_classes="height-40p margin-right-8p",
                        interactive=False,
                        placeholder="Enter URL...",
                        max_lines=1,
                        autoscroll=True
                    )
                    # Folder-picker button only applies to the directory source.
                    dataset_update_source_edit_button = gr.Button(
                        "",
                        icon=os.path.join(os.path.dirname(__file__), 'assets/edit.png'),
                        elem_classes="icon-button tooltip-component",
                        elem_id="dataset-update-source-edit-button",
                        visible=self._dataset_selected_source=="directory",
                        scale=0
                    )
        return (
            dataset_source_textbox,
            dataset_update_source_edit_button,
            dataset_source_dropdown,
            regenerate_vector_button,
            dataset_label_markdown,
            dataset_group
        )
def _render_sample_question(self):
question_butons = []
question_rows = []
sample_questions: list = self.config.get_config("sample_questions")
if sample_questions is None or len(sample_questions) == 0:
return question_butons, question_rows
chat_window_height = 500
row_height = 16
button_height = 42
elem_per_row = 2
row_count = math.ceil(len(sample_questions) / elem_per_row)
height = chat_window_height - (row_count*(row_height + button_height))
empty_space_component = gr.HTML(f'')
default_dataset_label = gr.Markdown(
"Default dataset is a sampling of articles recently published on GeForce News",
elem_classes="description-secondary-markdown chat-disclaimer-message margin-"
)
for i in range(0, len(sample_questions), 2):
row_questions = sample_questions[:2]
sample_questions = sample_questions[2:]
with gr.Row() as question_row:
for index, question in enumerate(row_questions):
query = question["query"]
button = gr.Button(
query,
elem_classes="sample-question-button"
)
question_butons.append({
"question": query,
"component": button
})
question_rows.append(question_row)
return question_butons, question_rows, empty_space_component, default_dataset_label
    def _render_chatbot(self, show_chatbot):
        """Render the chat window, query box and the action-button row.

        show_chatbot -- initial visibility of the chat window (hidden while
        sample questions are displayed).

        Returns (chatbot_window, query_input, submit_button, retry_button,
        undo_button_or_None, reset_button, query_group, disclaimer_markdown).
        """
        chatbot_window = gr.Chatbot(
            show_label=False,
            elem_classes="chat-window",
            visible=show_chatbot,
            elem_id="main-chatbot-window",
            sanitize_html=True
        )
        with gr.Group() as query_group:
            with gr.Row():
                query_input = gr.Textbox(placeholder="Chat with RTX...", scale=9, container=False)
                submit_button = gr.Button("SEND", variant="primary", scale=1)
            with gr.Row():
                # Empty HTML elements pad the row so the icon buttons sit
                # together on the right.
                gr.HTML("")
                retry_button = gr.Button(
                    "",
                    elem_classes="icon-button tooltip-component",
                    elem_id="chatbot-retry-button",
                    scale=0,
                    icon=os.path.join(os.path.dirname(__file__), 'assets/retry.png'),
                )
                # Undo is rendered only when the app is not using the chat
                # engine (per app_config); otherwise it stays None and
                # callers must handle that.
                undo_button = None
                if self.config.get_config_from_file("is_chat_engine", os.path.join(os.path.curdir, "config/app_config.json")) == False:
                    undo_button = gr.Button(
                        "",
                        scale=0,
                        icon=os.path.join(os.path.dirname(__file__), 'assets/undo.png'),
                        elem_classes="icon-button tooltip-component",
                        elem_id="chatbot-undo-button"
                    )
                reset_button = gr.Button(
                    "",
                    elem_classes="icon-button tooltip-component",
                    elem_id="chatbot-reset-button",
                    scale=0,
                    icon=os.path.join(os.path.dirname(__file__), 'assets/reset.png'),
                )
                gr.HTML("")
        chat_disclaimer_markdown = gr.Markdown(
            "Chat with RTX response quality depends on the AI model's accuracy and the input dataset. Please verify important information.",
            elem_classes="description-secondary-markdown chat-disclaimer-message margin-"
        )
        return (chatbot_window, query_input, submit_button, retry_button, undo_button, reset_button, query_group, chat_disclaimer_markdown)
def _handle_events(self):
self._handle_load_events()
self._handle_shutdown_events()
self._handle_model_events()
self._handle_dataset_events()
self._handle_chatbot_events()
return None
    def _validate_session_and_raise(self, request: gr.Request):
        """Validate the session, surfacing failure as a gr.Error popup.

        Used as the gate in event chains: a raised gr.Error aborts the
        chain so the following .success() steps never run.
        """
        try:
            self._validate_request(request)
        except Exception as e:
            raise gr.Error('Invalid session')
    def _validate_session(self, request: gr.Request):
        """Return component updates reflecting session validity.

        On an invalid session the whole UI is hidden and only the
        invalid-session notice is shown; on a valid session every update is
        a no-op. The update order must match
        _get_validate_session_output().
        """
        try:
            self._validate_request(request)
        except Exception as e:
            # Hide the app; show the post-shutdown group with the
            # invalid-session message.
            return [
                gr.Group(visible=False),
                gr.Group(visible=False),
                gr.Chatbot(visible=False),
                gr.Group(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Group(visible=True),
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=True),
                gr.Markdown(visible=False)
            ] + self._get_sample_question_components_new(True)
        # Valid session: emit no-op updates so nothing changes.
        return [
            gr.Group(),
            gr.Group(),
            gr.Chatbot(),
            gr.Group(),
            gr.Button(),
            gr.Button(),
            gr.Group(),
            gr.Button(),
            gr.Button(),
            gr.Markdown(),
            gr.Markdown(),
            gr.Markdown()
        ] + self._get_sample_question_components_new()
    def _get_validate_session_output(self):
        """Output components for _validate_session, in matching order."""
        return [
            self._models_group,
            self._dataset_group,
            self._chat_bot_window,
            self._chat_query_group,
            self._chat_reset_button,
            self._chat_retry_button,
            self._shutdown_post_shutdown_group,
            self._shutdown_button,
            self._chat_undo_button,
            self._chat_disclaimer_markdown,
            self._shutdown_invalid_session_markdown,
            self._shutdown_memory_released_markdown
        ] + self._get_sample_question_components()
    def _handle_load_events(self):
        """On page load: validate the session, then sync sample-question
        visibility with the current state."""
        self._interface.load(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._show_hide_sample_questions,
            self._get_show_hide_sample_questions_inputs(),
            self._get_show_hide_sample_questions_outputs()
        )
        return None
    def _handle_shutdown_events(self):
        """Wire the shutdown button: hide the UI, notify the app, then exit."""
        def close_thread(session_id):
            # Runs on a worker thread so the HTTP response can flush before
            # the process is killed with os._exit.
            if self._shutdown_callback:
                self._shutdown_callback(session_id)
            self._interface.close()
            self._interface = None
            print('exiting')
            os._exit(0)
        def handle_shutdown(state, request: gr.Request):
            self._validate_session(request)
            if self._interface is not None:
                _close_thread = threading.Thread(target=close_thread, args=(self._get_session_id(state),))
                _close_thread.start()
            else:
                print("Interface not initialized or already closed")
            return state
        def before_shutdown(request: gr.Request):
            # Hide every interactive group and show the "memory released"
            # notice before the server goes away.
            self._validate_session(request)
            ret_val = [
                gr.Group(visible=False),
                gr.Group(visible=False),
                gr.Chatbot(visible=False),
                gr.Group(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Group(visible=True),
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=True)
            ] + self._get_sample_question_components_new(True)
            return ret_val
        self._shutdown_button.click(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            before_shutdown,
            None,
            [
                self._models_group,
                self._dataset_group,
                self._chat_bot_window,
                self._chat_query_group,
                self._chat_reset_button,
                self._chat_retry_button,
                self._shutdown_post_shutdown_group,
                self._shutdown_button,
                self._chat_undo_button,
                self._chat_disclaimer_markdown,
                self._shutdown_invalid_session_markdown,
                self._shutdown_memory_released_markdown
            ] + self._get_sample_question_components()
        ).then(
            handle_shutdown,
            self._state,
            self._state
        )
    def _handle_model_events(self):
        """Wire the model dropdown: notify the app and persist the choice,
        disabling the UI while the switch is in progress."""
        def on_selection_change(newModel, state, request: gr.Request):
            self._validate_session(request)
            if self._model_change_callback:
                self._model_change_callback(
                    self._models_list[newModel]['name'],
                    self._models_list[newModel]['metadata'],
                    self._get_session_id(state)
                )
            # Persist so the selection survives restarts.
            self.config.set_config("models/selected", newModel)
            return newModel, state
        self._models_dropdown.change(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._before_change_element_state,
            None,
            self._get_enable_disable_elemet_list()
        ).then(
            on_selection_change,
            [self._models_dropdown, self._state],
            [self._models_dropdown, self._state]
        ).then(
            self._after_change_element_state,
            None,
            self._get_enable_disable_elemet_list(),
            show_progress=False
        )
    def _handle_dataset_events(self):
        """Wire dataset events: folder picking, source switching and vector
        index regeneration."""
        def select_folder(path, state, request: gr.Request):
            self._validate_session(request)
            if self._dataset_selected_source == "directory":
                # Spawn a helper process that shows the native folder dialog
                # and prints the chosen path on stdout.
                command = [sys.executable, "./ui/select_folder.py"]
                process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                output, _ = process.communicate()
                # Check if the command was successful
                result_string = ""
                if process.returncode == 0:
                    result_string = output.decode().strip()
                else:
                    print("Error executing script:", process.returncode)
                if len(result_string) > 0:
                    self._dataset_path = result_string
                    self.config.set_config(self._dataset_path_key, {"path": self._dataset_path, "isRelative": False})
                else:
                    # Dialog cancelled: keep the path from the textbox.
                    self._dataset_path = path
            if self._dataset_path_updated_callback:
                self._dataset_path_updated_callback(
                    self._dataset_selected_source,
                    self._dataset_path,
                    None,
                    self._get_session_id(state)
                )
            return self._dataset_path, state
        self._dataset_update_source_edit_button.click(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._before_change_element_state,
            None,
            self._get_enable_disable_elemet_list()
        ).then(
            select_folder,
            [self._dataset_source_textbox, self._state],
            [self._dataset_source_textbox, self._state]
        ).then(
            self._after_change_element_state,
            None,
            self._get_enable_disable_elemet_list(),
            show_progress=False
        ).then(
            self._show_hide_sample_questions,
            self._get_show_hide_sample_questions_inputs(),
            self._get_show_hide_sample_questions_outputs(),
            show_progress=False
        )
        def on_dataset_source_changed(source, state, request: gr.Request):
            self._validate_session(request)
            # The dropdown carries display strings; map back to the key.
            self._dataset_selected_source = self.config.get_display_string_keys(source)
            source = self._dataset_selected_source
            self._dataset_path = self._get_dataset_path() if source=="directory" else ""
            if self._dataset_source_updated_callback:
                self._dataset_source_updated_callback(
                    self._dataset_selected_source,
                    self._dataset_path,
                    self._get_session_id(state)
                )
            # Updates: path textbox, edit button, regenerate button, state.
            return [
                gr.Textbox(
                    interactive=False,
                    visible=source!="nodataset",
                    value=self._dataset_path
                ),
                gr.Button(visible=source=="directory"),
                gr.Button(visible=source!="nodataset"),
                state
            ]
        # Client-side JS that refreshes the dataset description label.
        change_source_description = """
        (source) => {
            let element = document.querySelector('[id="dataset-description-label"]');
            let label = ".txt, .pdf, .doc files supported";
            if(element) {
                element.querySelector('p').textContent = label;
            } else {
                console.error("failed to update");
            }
            return source;
        }
        """
        self._dataset_source_dropdown.change(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._before_change_element_state,
            None,
            self._get_enable_disable_elemet_list(),
            show_progress=False
        ).then(
            on_dataset_source_changed,
            [self._dataset_source_dropdown, self._state],
            [
                self._dataset_source_textbox,
                self._dataset_update_source_edit_button,
                self._dataset_regenerate_index_button,
                self._state
            ],
            show_progress=False
        ).then(
            lambda x: x,
            self._dataset_source_dropdown,
            self._dataset_source_dropdown,
            show_progress=False,
            js=change_source_description
        ).then(
            self._after_change_element_state,
            None,
            self._get_enable_disable_elemet_list(),
            show_progress=False
        ).then(
            self._show_hide_sample_questions,
            self._get_show_hide_sample_questions_inputs(),
            self._get_show_hide_sample_questions_outputs(),
            show_progress=False
        )
        def regenerate_index(state, request: gr.Request):
            self._validate_session(request)
            if self._regenerate_index_callback:
                self._regenerate_index_callback(self._dataset_selected_source, self._dataset_path, self._get_session_id(state))
            return self._dataset_path, state
        self._dataset_regenerate_index_button.click(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._before_change_element_state,
            None,
            self._get_enable_disable_elemet_list()
        ).then(
            regenerate_index,
            self._state,
            [self._dataset_source_textbox, self._state]
        ).then(
            self._after_change_element_state,
            None,
            self._get_enable_disable_elemet_list(),
            show_progress=False
        )
        # dataset events ends
def _show_hide_sample_questions(self, query, history, dataset_source, state, request: gr.Request):
self._validate_session(request)
dataset_source = self.config.get_display_string_keys(dataset_source)
sample_question_shown = state['sample_question_shown'] if isinstance(state, object) and 'sample_question_shown' in state else False
hide_sample_ques = \
len(query) > 0 or \
len(history) > 0 or \
sample_question_shown or \
(os.path.normpath(self._dataset_path) != os.path.normpath(self._default_dataset_path))
if isinstance(state, object):
state['sample_question_shown'] = hide_sample_ques
ret_val = [gr.Button(visible=not hide_sample_ques) for _ in self._sample_question_components]
ret_val.insert(0, gr.Chatbot(history, visible=hide_sample_ques))
[ret_val.append(gr.Row(visible=not hide_sample_ques)) for _ in self._sample_question_rows]
ret_val.append(gr.HTML(visible=not hide_sample_ques))
ret_val.append(gr.Markdown(visible=not hide_sample_ques))
ret_val.append(state)
return ret_val
def _get_show_hide_sample_questions_inputs(self):
return [
self._chat_query_input_textbox ,self._chat_bot_window, self._dataset_source_dropdown, self._state
]
def _get_sample_question_components_new(self, hide_sample_ques: bool = None):
if hide_sample_ques is None: # neither show nor hide
ret_val = [gr.Button() for _ in self._sample_question_components]
[ret_val.append(gr.Row()) for _ in self._sample_question_rows]
ret_val.append(gr.HTML())
ret_val.append(gr.Markdown())
else:
ret_val = [gr.Button(visible=not hide_sample_ques) for _ in self._sample_question_components]
[ret_val.append(gr.Row(visible=not hide_sample_ques)) for _ in self._sample_question_rows]
ret_val.append(gr.HTML(visible=not hide_sample_ques))
ret_val.append(gr.Markdown(visible=not hide_sample_ques))
return ret_val
def _get_sample_question_components(self):
sample_questions_buttons = [question['component'] for question in self._sample_question_components]
return sample_questions_buttons + self._sample_question_rows + [self._sample_question_empty_space_component, self._sample_qustion_default_dataset_markdown]
def _get_show_hide_sample_questions_outputs(self):
return [self._chat_bot_window] + self._get_sample_question_components() + [self._state]
    # chat bot events
    def _handle_chatbot_events(self):
        """Wire chat events: submit/send, retry, undo, reset and the
        sample-question buttons."""
        def process_input(query, history, request: gr.Request):
            # Append the query with a None answer placeholder and clear the
            # input textbox.
            self._validate_session(request)
            if len(query) == 0:
                return "", history
            history.append([query, None])
            return "", history
        def process_output(history, state, request: gr.Request):
            # Generator: streams the answer into the last history slot.
            self._validate_session(request)
            if len(history) == 0:
                yield history, state
            else:
                query = history[-1]
                if query[1] != None:
                    # Last turn already answered; nothing to generate.
                    yield history, state
                elif self._query_handler:
                    for response in self._query_handler(query[0], history[:-1], self._get_session_id(state)):
                        history[-1][1] = response
                        yield history, state
                else:
                    history[-1][1] = "ChatBot not ready..."
                    yield history, state
        #undo handler
        def process_undo_last_chat(history: list, state, request: gr.Request):
            self._validate_session(request)
            if len(history) == 0:
                return history, state
            history = history[:len(history) - 1]
            if self._undo_last_chat_callback:
                self._undo_last_chat_callback(history, self._get_session_id(state))
            return history, state
        #retry handler
        def process_retry(history: list, request: gr.Request):
            # Drop the last turn and re-queue its query; process_output then
            # regenerates the answer.
            self._validate_session(request)
            if len(history) == 0:
                return history
            lastChat = history[-1]
            history = history[:len(history) - 1]
            _, history = process_input(lastChat[0], history, request)
            return history
        def reset(state, request: gr.Request):
            # Clear the input box and the chat history.
            self._validate_session(request)
            if self._reset_chat_callback:
                self._reset_chat_callback(self._get_session_id(state))
            return "", [], state
        # Enter key and SEND button share the same chain.
        gr.on(
            [self._chat_query_input_textbox.submit, self._chat_submit_button.click],
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            self._show_hide_sample_questions,
            self._get_show_hide_sample_questions_inputs(),
            self._get_show_hide_sample_questions_outputs()
        ).then(
            process_input,
            [self._chat_query_input_textbox, self._chat_bot_window],
            [self._chat_query_input_textbox, self._chat_bot_window]
        ).then(
            process_output,
            [self._chat_bot_window, self._state],
            [self._chat_bot_window, self._state]
        )
        self._chat_retry_button.click(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            process_retry,
            [self._chat_bot_window],
            [self._chat_bot_window]
        ).then(
            process_output,
            [self._chat_bot_window, self._state],
            [self._chat_bot_window, self._state]
        )
        # Undo button may be None when the chat engine is enabled.
        if self._chat_undo_button:
            self._chat_undo_button.click(
                self._validate_session,
                None,
                self._get_validate_session_output()
            ).then(
                self._validate_session_and_raise,
                None,
                None
            ).success(
                process_undo_last_chat,
                [self._chat_bot_window, self._state],
                [self._chat_bot_window, self._state]
            )
        self._chat_reset_button.click(
            self._validate_session,
            None,
            self._get_validate_session_output()
        ).then(
            self._validate_session_and_raise,
            None,
            None
        ).success(
            reset,
            self._state,
            [self._chat_query_input_textbox, self._chat_bot_window, self._state]
        )
        def handle_sample_question_click(evt: gr.EventData, request: gr.Request):
            # Copy the clicked question's text into the query textbox.
            self._validate_session(request)
            return evt.target.value
        # Each sample-question button feeds the normal submit pipeline.
        for sample in self._sample_question_components:
            button: gr.Button = sample['component']
            button.click(
                self._validate_session,
                None,
                self._get_validate_session_output()
            ).then(
                self._validate_session_and_raise,
                None,
                None
            ).success(
                handle_sample_question_click,
                None,
                self._chat_query_input_textbox
            ).then(
                self._show_hide_sample_questions,
                self._get_show_hide_sample_questions_inputs(),
                self._get_show_hide_sample_questions_outputs()
            ).then(
                process_input,
                [self._chat_query_input_textbox, self._chat_bot_window],
                [self._chat_query_input_textbox, self._chat_bot_window]
            ).then(
                process_output,
                [self._chat_bot_window, self._state],
                [self._chat_bot_window, self._state]
            )