from typing import List, Tuple, Optional import gradio as gr from langchain_core.vectorstores import VectorStore from config import ( LLM_MODEL_REPOS, EMBED_MODEL_REPOS, SUBTITLES_LANGUAGES, GENERATE_KWARGS, CONTEXT_TEMPLATE, ) from utils import ( load_llm_model, load_embed_model, load_documents_and_create_db, user_message_to_chatbot, update_user_message_with_context, get_llm_response, get_gguf_model_names, add_new_model_repo, clear_llm_folder, clear_embed_folder, get_memory_usage, ) # ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============ def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox: value = visible = db is not None return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible) def get_rag_settings( rag_mode: bool, context_template_value: str, render: bool = True, ) -> Tuple[gr.component, ...]: k = gr.Radio( choices=[1, 2, 3, 4, 5, 'all'], value=2, label='Number of relevant documents for search', visible=rag_mode, render=render, ) score_threshold = gr.Slider( minimum=0, maximum=1, value=0.5, step=0.05, label='relevance_scores_threshold', visible=rag_mode, render=render, ) context_template = gr.Textbox( value=context_template_value, label='Context Template', lines=len(context_template_value.split('\n')), visible=rag_mode, render=render, ) return k, score_threshold, context_template def get_user_message_with_context(text: str, rag_mode: bool) -> gr.component: num_lines = len(text.split('\n')) max_lines = 10 num_lines = max_lines if num_lines > max_lines else num_lines return gr.Textbox( text, visible=rag_mode, interactive=False, label='User Message With Context', lines=num_lines, ) def get_system_prompt_component(interactive: bool) -> gr.Textbox: value = '' if interactive else 'System prompt is not supported by this model' return gr.Textbox(value=value, label='System prompt', interactive=interactive) def get_generate_args(do_sample: bool) -> List[gr.component]: generate_args = [ gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample), gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample), gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample), gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample), ] return generate_args # ================ LOADING AND INITIALIZING MODELS ======================== start_llm_model, start_support_system_role, load_log = load_llm_model(LLM_MODEL_REPOS[0], 'gemma-2-2b-it-Q8_0.gguf') start_embed_model, load_log = load_embed_model(EMBED_MODEL_REPOS[0]) # ================== APPLICATION WEB INTERFACE ============================ css = '''.gradio-container {width: 60% !important}''' with gr.Blocks(css=css) as interface: # ==================== GRADIO STATES =============================== documents = gr.State([]) db = gr.State(None) user_message_with_context = gr.State('') support_system_role = gr.State(start_support_system_role) llm_model_repos = gr.State(LLM_MODEL_REPOS) embed_model_repos = gr.State(EMBED_MODEL_REPOS) llm_model = gr.State(start_llm_model) embed_model = gr.State(start_embed_model) # ==================== BOT PAGE ================================= with gr.Tab(label='Chatbot'): with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( type='messages', # new in gradio 5+ show_copy_button=True, bubble_full_width=False, height=480, ) user_message = gr.Textbox(label='User') with gr.Row(): user_message_btn = gr.Button('Send') stop_btn = gr.Button('Stop') clear_btn = gr.Button('Clear') # ------------- GENERATION PARAMETERS ------------------- with gr.Column(scale=1, min_width=80): with gr.Group(): gr.Markdown('History size') history_len = gr.Slider( minimum=0, maximum=5, value=0, step=1, info='Number of previous messages taken into account in history', label='history_len', show_label=False, ) with gr.Group(): gr.Markdown('Generation parameters') do_sample = gr.Checkbox( value=False, label='do_sample', info='Activate random sampling', ) generate_args = get_generate_args(do_sample.value) do_sample.change( fn=get_generate_args, inputs=do_sample, outputs=generate_args, show_progress=False, ) rag_mode = get_rag_mode_component(db=db.value) k, score_threshold, context_template = get_rag_settings( rag_mode=rag_mode.value, context_template_value=CONTEXT_TEMPLATE, render=False, ) rag_mode.change( fn=get_rag_settings, inputs=[rag_mode, context_template], outputs=[k, score_threshold, context_template], ) with gr.Row(): k.render() score_threshold.render() # ---------------- SYSTEM PROMPT AND USER MESSAGE ----------- with gr.Accordion('Prompt', open=True): system_prompt = get_system_prompt_component(interactive=support_system_role.value) context_template.render() user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value) # ---------------- SEND, CLEAR AND STOP BUTTONS ------------ generate_event = gr.on( triggers=[user_message.submit, user_message_btn.click], fn=user_message_to_chatbot, inputs=[user_message, chatbot], outputs=[user_message, chatbot], queue=False, ).then( fn=update_user_message_with_context, inputs=[chatbot, rag_mode, db, k, score_threshold, context_template], outputs=[user_message_with_context], ).then( fn=get_user_message_with_context, inputs=[user_message_with_context, rag_mode], outputs=[user_message_with_context], ).then( fn=get_llm_response, inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt, support_system_role, history_len, do_sample, *generate_args], outputs=[chatbot], ) stop_btn.click( fn=None, inputs=None, outputs=None, cancels=generate_event, queue=False, ) clear_btn.click( fn=lambda: (None, ''), inputs=None, outputs=[chatbot, user_message_with_context], queue=False, ) # ================= FILE DOWNLOAD PAGE ========================= with gr.Tab(label='Load documents'): with gr.Row(variant='compact'): upload_files = gr.File(file_count='multiple', label='Loading text files') web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube') with gr.Row(variant='compact'): chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size') chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap') subtitles_lang = gr.Radio( SUBTITLES_LANGUAGES, value=SUBTITLES_LANGUAGES[0], label='YouTube subtitle language', ) load_documents_btn = gr.Button(value='Upload documents and initialize database') load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False) load_documents_btn.click( fn=load_documents_and_create_db, inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model], outputs=[documents, db, load_docs_log], ).success( fn=get_rag_mode_component, inputs=[db], outputs=[rag_mode], ) gr.HTML("""