File size: 3,033 Bytes
ef3d4ad
 
53ebd47
529dafe
a5aec38
ef3d4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
529dafe
ef3d4ad
53ebd47
 
 
 
 
ef3d4ad
 
 
529dafe
53ebd47
529dafe
 
 
 
 
 
 
 
ef3d4ad
e946c57
 
32d4c53
e946c57
529dafe
a9490de
 
 
 
ef3d4ad
529dafe
a3c275c
 
 
 
 
 
 
 
529dafe
e776497
 
 
 
ef3d4ad
 
 
 
 
e946c57
 
 
 
a5aec38
 
 
 
 
ef3d4ad
a5aec38
529dafe
01851b4
 
 
 
529dafe
 
 
 
 
 
 
e946c57
 
 
 
529dafe
 
ef3d4ad
529dafe
8d48750
 
 
 
a9490de
24a698a
 
 
529dafe
 
e946c57
 
 
 
 
e776497
 
 
 
529dafe
ef3d4ad
 
53ebd47
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import json
import functools as ft
import collections as cl
from pathlib import Path

import gradio as gr
from openai import OpenAI

from mylib import (
    FileManager,
    ChatController,
    MessageHandler,
    NumericCitations,
)

#
#
#
# Immutable bundle of the three per-session services: the uploaded-file
# store, the response/citation formatter, and the chat driver.
ChatState = cl.namedtuple('ChatState', ['database', 'messenger', 'chat'])

@ft.cache
def scancfg():
    """Load and cache the application's JSON configuration.

    The config path is taken from the ``FILE_CHAT_CONFIG`` environment
    variable. Thanks to ``ft.cache`` the file is read exactly once per
    process; later calls return the same dict object.

    Returns:
        dict: the parsed configuration.

    Raises:
        RuntimeError: if ``FILE_CHAT_CONFIG`` is not set.
    """
    path = os.getenv('FILE_CHAT_CONFIG')
    if path is None:
        # Fail with a clear message instead of the opaque TypeError that
        # open(None) would otherwise raise.
        raise RuntimeError('FILE_CHAT_CONFIG environment variable is not set')
    with open(path, encoding='utf-8') as fp:
        return json.load(fp)

#
#
#
def load():
    """Build the per-session service bundle from the cached configuration.

    Used as the Gradio ``gr.State`` factory, so each browser session gets
    its own independent ChatState.
    """
    config = scancfg()
    openai_cfg = config.get('openai')
    chat_cfg = config.get('chat')

    client = OpenAI(api_key=openai_cfg['api_key'])
    database = FileManager(client, chat_cfg['prefix'])
    messenger = MessageHandler(client, NumericCitations)
    chat = ChatController(client, database, openai_cfg, chat_cfg)

    return ChatState(
        database=database,
        messenger=messenger,
        chat=chat,
    )

def eject(state):
    """Release a session's resources: uploaded files first, then the chat."""
    for resource in (state.database, state.chat):
        resource.cleanup()

def upload(data, state):
    """Forward uploaded files to the session's file store.

    Parameters:
        data: the file payload from the Gradio upload button.
        state: the session's ChatState.

    Returns:
        The file store's result (shown in the "Files uploaded" textbox).

    Raises:
        gr.Error: surfaced to the UI when the store rejects the upload.
    """
    try:
        return state.database(data)
    except InterruptedError as err:
        # Chain the cause so the original traceback is preserved in logs
        # instead of being discarded by the re-raise.
        raise gr.Error(str(err)) from err

def prompt(message, history, state):
    """Handle one chat turn: answer *message* and extend *history* in place.

    Returns a pair for the textbox-submit outputs: an empty string to clear
    the input box, and the (possibly updated) history for the chat display.
    """
    # Guard: nothing to answer against until documents are uploaded.
    if not state.database:
        gr.Warning('Please upload your documents to begin')
        return ('', history)

    answer = state.messenger(state.chat(message))
    history.append((message, answer))
    return ('', history)

#
#
#
# UI layout: an instructions accordion on top, then a two-column row with
# the upload controls on the left and the chatbot on the right.
with gr.Blocks() as demo:
    # Per-session state; `load` builds it lazily and `eject` tears it down
    # when the session ends.
    state = gr.State(
        value=load,
        delete_callback=eject,
    )
    # Instructions text lives in static/howto.md.
    howto = Path('static/howto').with_suffix('.md')

    with gr.Row():
        with gr.Accordion(label='Instructions', open=False):
            gr.Markdown(howto.read_text())
    with gr.Row():

        with gr.Column():
            data = gr.UploadButton(
                label='Select and upload your files',
                file_count='multiple',
            )
            # Read-only summary of what has been uploaded so far.
            repository = gr.Textbox(
                label='Files uploaded',
                placeholder='Upload your files to begin!',
                interactive=False,
            )
            data.upload(
                fn=upload,
                inputs=[
                    data,
                    state,
                ],
                outputs=repository,
            )

        # Chat column gets twice the width of the upload column.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height='70vh',
                show_copy_button=True,
            )
            # NOTE(review): no fn= is passed to this event listener; confirm
            # against the Gradio docs that this actually enables scrolling
            # rather than being a no-op.
            chatbot.change(scroll_to_output=True)
            interaction = gr.Textbox(
                label='Ask a question about your documents and press "Enter"',
            )
            interaction.submit(
                fn=prompt,
                inputs=[
                    interaction,
                    chatbot,
                    state,
                ],
                outputs=[
                    interaction,
                    chatbot,
                ],
            )

if __name__ == '__main__':
    # Server options come from the optional "gradio" section of the config.
    # Default to an empty dict so a missing section launches with Gradio's
    # defaults instead of crashing on `**None`.
    kwargs = scancfg().get('gradio') or {}
    demo.queue().launch(server_name='0.0.0.0', **kwargs)