jerome-white committed on
Commit ef3d4ad
1 Parent(s): 9848c62

Initial commit

app.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import json
+ from tempfile import NamedTemporaryFile
+
+ import gradio as gr
+ from openai import OpenAI
+
+ from mylib import (
+     Logger,
+     FileManager,
+     ChatController,
+     MessageHandler,
+     NumericCitations,
+ )
+
+ #
+ #
+ #
+ class ErrorLogger:
+     def __init__(self, path):
+         self.path = path
+         if not self.path.exists():
+             self.path.mkdir(parents=True, exist_ok=True)
+
+     def dump(self, prompt, error):
+         msg = {
+             'prompt': prompt,
+         }
+         msg.update(error.to_dict())
+         output = json.dumps(msg, indent=2)
+
+         with NamedTemporaryFile(mode='w',
+                                 prefix='',
+                                 dir=self.path,
+                                 delete=False) as fp:
+             print(output, file=fp)
+             return fp.name
+
+ #
+ #
+ #
+ class FileChat:
+     def __init__(self, client, config):
+         self.database = FileManager(client, config['chat']['prefix'])
+         self.messenger = MessageHandler(client, NumericCitations)
+         self.chat = ChatController(client, config['openai'], config['chat'])
+
+     def upload(self, *args):
+         (data, ) = args
+         return self.database(data)
+
+     def prompt(self, *args):
+         (message, *_) = args
+         if not self.database:
+             raise gr.Error('Please upload your documents to begin')
+
+         return self.messenger(self.chat(message, self.database))
+
+ #
+ #
+ #
+ with open(os.getenv('FILE_CHAT_CONFIG')) as fp:
+     config = json.load(fp)
+
+ with gr.Blocks() as demo:
+     client = OpenAI(api_key=config['openai']['api_key'])
+     mychat = FileChat(client, config)
+
+     with gr.Row():
+         upload = gr.UploadButton(file_count='multiple')
+         text = gr.Textbox(label='Files uploaded', interactive=False)
+         upload.upload(mychat.upload, upload, text)
+
+     gr.ChatInterface(
+         fn=mychat.prompt,
+         additional_inputs=[
+             upload,
+             text,
+         ],
+         retry_btn=None,
+         undo_btn=None,
+         clear_btn=None,
+         # additional_inputs_accordion=gr.Accordion(
+         #     label='Upload documents',
+         #     open=True,
+         # ),
+     )
+
+ if __name__ == '__main__':
+     # demo.queue().launch(server_name='0.0.0.0', **config['gradio'])
+     demo.queue().launch(server_name='0.0.0.0')
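
Note: app.py loads its settings from the JSON file named by the FILE_CHAT_CONFIG environment variable. The sketch below only illustrates the keys that app.py and mylib actually read (openai.api_key, openai.assistant_name, chat.prefix, chat.system_prompt, chat.retries); every value is a placeholder, and the optional gradio section is consulted only by the commented-out launch line.

    # A possible config.json for FILE_CHAT_CONFIG (all values are placeholders).
    import json

    config = {
        'openai': {
            'api_key': 'sk-...',              # OpenAI API key
            'assistant_name': 'file-chat',    # used by ChatController
        },
        'chat': {
            'prefix': 'file-chat-',           # vector store name prefix
            'system_prompt': 'prompts/system.txt',
            'retries': 3,                     # run attempts before TimeoutError
        },
    }

    with open('config.json', 'w') as fp:
        json.dump(config, fp, indent=2)

    # then: export FILE_CHAT_CONFIG=config.json and run app.py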
mylib/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from ._chat import ChatController
+ from ._files import FileManager
+ from ._logging import Logger
+ from ._message import MessageHandler
+ from ._citations import NumericCitations, NoCitations
mylib/_chat.py ADDED
@@ -0,0 +1,86 @@
+ import math
+ import time
+ from pathlib import Path
+
+ import pandas as pd
+
+ from ._logging import Logger
+
+ def parse_wait_time(err):
+     if err.code == 'rate_limit_exceeded':
+         for i in err.message.split('. '):
+             if i.startswith('Please try again in'):
+                 (*_, wait) = i.split()
+                 return (pd
+                         .to_timedelta(wait)
+                         .total_seconds())
+
+     raise TypeError(err.code)
+
+ class ChatController:
+     _gpt_defaults = {
+         'model': 'gpt-4o',
+         'max_completion_tokens': 2 ** 12,
+     }
+
+     def __init__(self, client, gpt, chat):
+         self.client = client
+         self.gpt = gpt
+         self.chat = chat
+
+         for i in self._gpt_defaults.items():
+             self.gpt.setdefault(*i)
+         instructions = Path(self.chat['system_prompt'])
+
+         self.assistant = self.client.beta.assistants.create(
+             name=self.gpt['assistant_name'],
+             model=self.gpt['model'],
+             instructions=instructions.read_text(),
+             temperature=0.1,
+             tools=[{
+                 'type': 'file_search',
+             }],
+         )
+         self.thread = self.client.beta.threads.create()
+         self.attached = False
+
+     def __call__(self, prompt, database):
+         if not self.attached:
+             self.client.beta.assistants.update(
+                 assistant_id=self.assistant.id,
+                 tool_resources={
+                     'file_search': {
+                         'vector_store_ids': [
+                             database.vector_store_id,
+                         ],
+                     },
+                 },
+             )
+             self.attached = True
+
+         return self.send(prompt)
+
+     def send(self, content):
+         self.client.beta.threads.messages.create(
+             self.thread.id,
+             role='user',
+             content=content,
+         )
+
+         for i in range(self.chat['retries']):
+             run = self.client.beta.threads.runs.create_and_poll(
+                 thread_id=self.thread.id,
+                 assistant_id=self.assistant.id,
+             )
+             if run.status == 'completed':
+                 return self.client.beta.threads.messages.list(
+                     thread_id=self.thread.id,
+                     run_id=run.id,
+                 )
+             Logger.error('%s (%d): %s', run.status, i + 1, run.last_error)
+
+             rest = math.ceil(parse_wait_time(run.last_error))
+             Logger.warning('Sleeping %ds', rest)
+             time.sleep(rest)
+
+         raise TimeoutError('Message retries exceeded')
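
Note: send retries a failed run after sleeping for the interval that parse_wait_time extracts from a rate-limit error. A minimal sketch with a stand-in error object (the message text here is illustrative, not a verbatim API response):

    from types import SimpleNamespace

    from mylib._chat import parse_wait_time

    # Stand-in for run.last_error; only .code and the
    # 'Please try again in ...' clause of .message are consulted.
    err = SimpleNamespace(
        code='rate_limit_exceeded',
        message='Rate limit reached. Please try again in 1.5s. See the docs',
    )
    print(parse_wait_time(err))  # 1.5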
mylib/_citations.py ADDED
@@ -0,0 +1,47 @@
+ import itertools as it
+
+ class CitationManager:
+     def __init__(self, annotations, client, start=1):
+         self.start = start
+         self.body = {}
+         self.citations = []
+
+         for a in annotations:
+             reference = f'[{start}]'
+             self.body[a.text] = reference
+             document = client.files.retrieve(a.file_citation.file_id)
+             self.citations.append('{} {}:{}--{}'.format(
+                 reference,
+                 document.filename,
+                 a.start_index,
+                 a.end_index,
+             ))
+             start += 1
+
+     def __len__(self):
+         return len(self.citations)
+
+     def __str__(self):
+         raise NotImplementedError()
+
+     def __iter__(self):
+         raise NotImplementedError()
+
+     def replace(self, body):
+         for i in self:
+             body = body.replace(*i)
+
+         return body
+
+ class NumericCitations(CitationManager):
+     def __str__(self):
+         return '\n\n{}'.format('\n'.join(self.citations))
+
+     def __iter__(self):
+         for (k, v) in self.body.items():
+             yield (k, f' {v}')
+
+ class NoCitations(CitationManager):
+     def __str__(self):
+         return ''
+
+     def __iter__(self):
+         yield from zip(self.body, it.repeat(''))
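
Note: CitationManager maps each annotation's marker text to a bracketed reference and looks up the cited file's name through the client; NumericCitations appends the reference list, while NoCitations strips the markers. A sketch with stub objects shaped like the attributes the class reads ('[source]' stands in for whatever marker the model emits):

    from types import SimpleNamespace

    from mylib._citations import NumericCitations

    annotation = SimpleNamespace(
        text='[source]',    # stand-in for the marker inserted by the model
        file_citation=SimpleNamespace(file_id='file-123'),
        start_index=10,
        end_index=25,
    )

    class StubFiles:
        def retrieve(self, file_id):
            return SimpleNamespace(filename='report.pdf')

    client = SimpleNamespace(files=StubFiles())

    cite = NumericCitations([annotation], client)
    print(cite.replace('Sales rose in Q3[source].') + str(cite))
    # Sales rose in Q3 [1].
    #
    # [1] report.pdf:10--25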
mylib/_files.py ADDED
@@ -0,0 +1,122 @@
+ import uuid
+ import hashlib
+ import warnings
+ import itertools as it
+ import functools as ft
+ from pathlib import Path
+
+ class FileObject:
+     _window = 20
+
+     def __init__(self, path):
+         self.fp = path.open('rb')
+         self.chunk = 2 ** self._window
+
+     def close(self):
+         self.fp.close()
+
+     @ft.cached_property
+     def checksum(self):
+         csum = hashlib.blake2b()
+
+         while True:
+             data = self.fp.read(self.chunk)
+             if not data:
+                 break
+             csum.update(data)
+         self.fp.seek(0)
+
+         return csum.hexdigest()
+
+ class FileStream:
+     def __init__(self, paths):
+         self.paths = paths
+         self.streams = []
+
+     def __len__(self):
+         return len(self.streams)
+
+     def __iter__(self):
+         for p in self.paths:
+             stream = FileObject(p)
+             self.streams.append(stream)
+             yield stream
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         for s in self.streams:
+             s.close()
+         self.streams.clear()
+
+ class FileManager:
+     def __init__(self, client, prefix, batch_size=20):
+         self.client = client
+         self.prefix = prefix
+         self.batch_size = batch_size
+
+         self.storage = set()
+         self.vector_store_id = None
+
+     def __bool__(self):
+         return self.vector_store_id is not None
+
+     def __iter__(self):
+         if self:
+             kwargs = {}
+             while True:
+                 vs_files = self.client.beta.vector_stores.files.list(
+                     vector_store_id=self.vector_store_id,
+                     **kwargs,
+                 )
+                 for f in vs_files.data:
+                     result = self.client.files.retrieve(f.id)
+                     yield result.filename
+
+                 if not vs_files.has_more:
+                     break
+                 kwargs['after'] = vs_files.after
+
+     def __call__(self, paths):
+         files = []
+         self.test_and_setup()
+
+         for p in self.ls(paths):
+             with FileStream(p) as stream:
+                 for s in stream:
+                     if s.checksum not in self.storage:
+                         files.append(s.fp)
+                         self.storage.add(s.checksum)
+                 if files:
+                     self.put(files)
+                     files.clear()
+
+         return '\n'.join(self)
+
+     def test_and_setup(self):
+         if self:
+             msg = f'Vector store already exists ({self.vector_store_id})'
+             warnings.warn(msg)
+         else:
+             name = f'{self.prefix}{uuid.uuid4()}'
+             vector_store = self.client.beta.vector_stores.create(
+                 name=name,
+             )
+             self.vector_store_id = vector_store.id
+
+     def ls(self, paths):
+         left = 0
+         while left < len(paths):
+             right = left + self.batch_size
+             yield list(map(Path, it.islice(paths, left, right)))
+             left = right
+
+     def put(self, files):
+         batch = self.client.beta.vector_stores.file_batches.upload_and_poll(
+             vector_store_id=self.vector_store_id,
+             files=files,
+         )
+         if batch.file_counts.completed != len(files):
+             err = f'Error uploading documents: {batch.file_counts}'
+             raise InterruptedError(err)
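
Note: FileManager deduplicates uploads by tracking FileObject.checksum, a BLAKE2b digest of each file's contents, and ls feeds the uploads to the API in batches of batch_size. A small sketch of the checksum behaviour on its own, with no OpenAI client involved:

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from mylib._files import FileObject

    with TemporaryDirectory() as tmp:
        (a, b) = (Path(tmp, 'a.txt'), Path(tmp, 'b.txt'))
        for p in (a, b):
            p.write_text('same contents')

        # Identical bytes hash identically, so FileManager would skip the second.
        (fa, fb) = (FileObject(a), FileObject(b))
        print(fa.checksum == fb.checksum)  # True
        for f in (fa, fb):
            f.close()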
mylib/_logging.py ADDED
@@ -0,0 +1,13 @@
+ import os
+ import logging
+
+ #
+ #
+ #
+ logging.basicConfig(
+     format='[ %(asctime)s %(levelname)s %(filename)s ] %(message)s',
+     datefmt='%H:%M:%S',
+     level=os.environ.get('PYTHONLOGLEVEL', 'WARNING').upper(),
+ )
+ logging.captureWarnings(True)
+ Logger = logging.getLogger(__name__)
mylib/_message.py ADDED
@@ -0,0 +1,20 @@
+ from ._citations import NoCitations
+
+ class MessageHandler:
+     def __init__(self, client, citecls=None):
+         self.client = client
+         self.citecls = citecls or NoCitations
+
+     def __call__(self, message):
+         return '\n'.join(self.each(message))
+
+     def each(self, message):
+         refn = 1
+
+         for m in message:
+             for c in m.content:
+                 cite = self.citecls(c.text.annotations, self.client, refn)
+                 body = cite.replace(c.text.value)
+                 # keep citation numbering continuous across message parts
+                 refn += len(cite)
+
+                 yield f'{body}{cite}'
prompts/system.txt ADDED
@@ -0,0 +1 @@
+ You are an expert file search assistant
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio
+ openai
+ pandas