Spaces:
Runtime error
Runtime error
jerome-white
committed on
Commit
•
ef3d4ad
1
Parent(s):
9848c62
Initial commit
Browse files- app.py +91 -0
- mylib/__init__.py +5 -0
- mylib/_chat.py +86 -0
- mylib/_citations.py +47 -0
- mylib/_files.py +122 -0
- mylib/_logging.py +13 -0
- mylib/_message.py +20 -0
- prompts/system.txt +1 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tempfile import NamedTemporaryFile
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from openai import OpenAI
|
7 |
+
|
8 |
+
from mylib import (
|
9 |
+
Logger,
|
10 |
+
FileManager,
|
11 |
+
ChatController,
|
12 |
+
MessageHandler,
|
13 |
+
NumericCitations,
|
14 |
+
)
|
15 |
+
|
16 |
+
#
|
17 |
+
#
|
18 |
+
#
|
19 |
+
class ErrorLogger:
    """Persist prompt/error pairs as pretty-printed JSON files on disk."""

    def __init__(self, path):
        # path: a pathlib.Path directory that will receive one file per
        # dumped error.
        self.path = path
        # mkdir with exist_ok=True already tolerates an existing
        # directory; the original guarded it with a redundant (and
        # race-prone) exists() check.
        self.path.mkdir(parents=True, exist_ok=True)

    def dump(self, prompt, error):
        """Write `prompt` plus `error.to_dict()` to a fresh file.

        `error` must expose a to_dict() method. Returns the name of the
        file written (a unique file under self.path; delete=False keeps
        it after the handle closes).
        """
        msg = {
            'prompt': prompt,
        }
        msg.update(error.to_dict())
        output = json.dumps(msg, indent=2)

        with NamedTemporaryFile(mode='w',
                                prefix='',
                                dir=self.path,
                                delete=False) as fp:
            print(output, file=fp)
            return fp.name
38 |
+
|
39 |
+
#
|
40 |
+
#
|
41 |
+
#
|
42 |
+
class FileChat:
    """Glue between the Gradio callbacks and the OpenAI chat machinery."""

    def __init__(self, client, config):
        chat_conf = config['chat']
        self.database = FileManager(client, chat_conf['prefix'])
        self.messenger = MessageHandler(client, NumericCitations)
        self.chat = ChatController(client, config['openai'], chat_conf)

    def upload(self, *args):
        """Gradio upload callback: store the files, return the listing."""
        (data,) = args
        return self.database(data)

    def prompt(self, *args):
        """Gradio chat callback: answer the message against the documents."""
        message = args[0]
        if not self.database:
            raise gr.Error('Please upload your documents to begin')
        response = self.chat(message, self.database)
        return self.messenger(response)
58 |
+
|
59 |
+
#
|
60 |
+
#
|
61 |
+
#
|
62 |
+
# Load the JSON application config from the path named by the
# FILE_CHAT_CONFIG environment variable.
with open(os.getenv('FILE_CHAT_CONFIG')) as fp:
    config = json.load(fp)

with gr.Blocks() as demo:
    client = OpenAI(api_key=config['openai']['api_key'])
    mychat = FileChat(client, config)

    # Upload row: a multi-file button plus a read-only listing of what
    # has been stored so far (updated by the upload callback).
    with gr.Row():
        upload = gr.UploadButton(file_count='multiple')
        text = gr.Textbox(label='Files uploaded', interactive=False)
        upload.upload(mychat.upload, upload, text)

    # Chat pane; the upload widgets are passed as additional inputs so
    # the prompt callback receives them (their values are unused there).
    gr.ChatInterface(
        fn=mychat.prompt,
        additional_inputs=[
            upload,
            text,
        ],
        retry_btn=None,
        undo_btn=None,
        clear_btn=None,
        # additional_inputs_accordion=gr.Accordion(
        #     label='Upload documents',
        #     open=True,
        # ),
    )

if __name__ == '__main__':
    # demo.queue().launch(server_name='0.0.0.0', **config['gradio'])
    demo.queue().launch(server_name='0.0.0.0')
|
mylib/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ._chat import ChatController
|
2 |
+
from ._files import FileManager
|
3 |
+
from ._logging import Logger
|
4 |
+
from ._message import MessageHandler
|
5 |
+
from ._citations import NumericCitations, NoCitations
|
mylib/_chat.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from ._logging import Logger
|
8 |
+
|
9 |
+
def parse_wait_time(err):
    """Extract the suggested retry delay, in seconds, from an API error.

    Only rate-limit errors carry a delay; anything else (or a message
    without the expected sentence) raises TypeError with the error code.
    """
    if err.code == 'rate_limit_exceeded':
        sentences = err.message.split('. ')
        for sentence in sentences:
            if sentence.startswith('Please try again in'):
                wait = sentence.split()[-1]
                delta = pd.to_timedelta(wait)
                return delta.total_seconds()

    raise TypeError(err.code)
19 |
+
|
20 |
+
class ChatController:
    """Drive an OpenAI Assistants-API conversation over a single thread.

    Creates the assistant (with file_search enabled) and a thread at
    construction; lazily attaches a vector store on first call; retries
    failed runs, sleeping for the delay suggested by rate-limit errors.
    """

    # Fallback GPT settings applied when the caller's config omits them.
    _gpt_defaults = {
        'model': 'gpt-4o',
        'max_completion_tokens': 2 ** 12,
    }

    def __init__(self, client, gpt, chat):
        # client: OpenAI SDK client; gpt: model options (mutated in
        # place with the defaults above; must also supply
        # 'assistant_name'); chat: chat options ('system_prompt' path,
        # 'retries').
        self.client = client
        self.gpt = gpt
        self.chat = chat

        for i in self._gpt_defaults.items():
            self.gpt.setdefault(*i)
        # Path to the plain-text system prompt.
        instructions = Path(self.chat['system_prompt'])

        self.assistant = self.client.beta.assistants.create(
            name=self.gpt['assistant_name'],
            model=self.gpt['model'],
            instructions=instructions.read_text(),
            temperature=0.1,
            tools=[{
                'type': 'file_search',
            }],
        )
        self.thread = self.client.beta.threads.create()
        # Whether the vector store has been attached to the assistant.
        self.attached = False

    def __call__(self, prompt, database):
        """Send `prompt`; first attach `database`'s vector store if needed."""
        if not self.attached:
            self.client.beta.assistants.update(
                assistant_id=self.assistant.id,
                tool_resources={
                    'file_search': {
                        'vector_store_ids': [
                            database.vector_store_id,
                        ],
                    },
                },
            )
            self.attached = True

        return self.send(prompt)

    def send(self, content):
        """Post a user message and poll the run, retrying on rate limits.

        Returns the SDK message list for the completed run. Raises
        TimeoutError when 'retries' attempts are exhausted, or TypeError
        (via parse_wait_time) when a run fails for a reason other than
        rate limiting.
        """
        self.client.beta.threads.messages.create(
            self.thread.id,
            role='user',
            content=content,
        )

        for i in range(self.chat['retries']):
            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=self.thread.id,
                assistant_id=self.assistant.id,
            )
            if run.status == 'completed':
                return self.client.beta.threads.messages.list(
                    thread_id=self.thread.id,
                    run_id=run.id,
                )
            Logger.error('%s (%d): %s', run.status, i + 1, run.last_error)

            # NOTE(review): assumes every non-completed run carries a
            # rate-limit last_error; a cancelled/expired run may not --
            # confirm before relying on this path.
            rest = math.ceil(parse_wait_time(run.last_error))
            Logger.warning('Sleeping %ds', rest)
            time.sleep(rest)

        raise TimeoutError('Message retries exceeded')
|
mylib/_citations.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class CitationManager:
    """Base class turning file-search annotations into reference markers.

    Collects a text->marker map (`body`) and a bibliography listing
    (`citations`). Subclasses decide how markers render (__iter__) and
    how the bibliography prints (__str__).
    """

    def __init__(self, annotations, client, start=1):
        self.start = start
        self.body = {}
        self.citations = []

        for (number, annotation) in enumerate(annotations, start):
            marker = f'[{number}]'
            self.body[annotation.text] = marker
            document = client.files.retrieve(annotation.file_citation.file_id)
            entry = '{} {}:{}--{}'.format(
                marker,
                document.filename,
                annotation.start_index,
                annotation.end_index,
            )
            self.citations.append(entry)

    def __len__(self):
        return len(self.citations)

    def __str__(self):
        raise NotImplementedError()

    def __iter__(self):
        raise NotImplementedError()

    def replace(self, body):
        """Apply every (text, replacement) pair from __iter__ to body."""
        for pair in self:
            body = body.replace(*pair)
        return body
33 |
+
|
34 |
+
class NumericCitations(CitationManager):
    """Render citations as bracketed numbers plus a trailing bibliography."""

    def __str__(self):
        listing = '\n'.join(self.citations)
        return f'\n\n{listing}'

    def __iter__(self):
        # Replace each annotation span with ' [n]' (leading space
        # separates the marker from the preceding word).
        for (text, marker) in self.body.items():
            yield (text, f' {marker}')
41 |
+
|
42 |
+
class NoCitations(CitationManager):
    """Citation policy that strips reference markers entirely."""

    def __str__(self):
        # No bibliography is appended to the message body.
        return ''

    def __iter__(self):
        # Map every annotation span to the empty string so replace()
        # deletes the markers. The original used `it.repeat('')`, but
        # this module never imports itertools, so any message carrying
        # annotations raised NameError; a plain loop needs no import.
        for text in self.body:
            yield (text, '')
|
mylib/_files.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
import hashlib
|
3 |
+
import warnings
|
4 |
+
import itertools as it
|
5 |
+
import functools as ft
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
class FileObject:
    """A binary file handle that can fingerprint its own contents."""

    # Chunk-size exponent: reads happen in 2**20-byte (1 MiB) pieces.
    _window = 20

    def __init__(self, path):
        self.fp = path.open('rb')
        self.chunk = 2 ** self._window

    def close(self):
        self.fp.close()

    @ft.cached_property
    def checksum(self):
        """Hex BLAKE2b digest of the file; rewinds the handle when done."""
        digest = hashlib.blake2b()
        while data := self.fp.read(self.chunk):
            digest.update(data)
        self.fp.seek(0)
        return digest.hexdigest()
30 |
+
|
31 |
+
class FileStream:
    """Context manager that opens FileObjects lazily and closes them all.

    Iteration opens one FileObject per path and records it; leaving the
    context closes every handle opened so far.
    """

    def __init__(self, paths):
        self.paths = paths
        self.streams = []

    def __len__(self):
        return len(self.streams)

    def __iter__(self):
        for path in self.paths:
            handle = FileObject(path)
            self.streams.append(handle)
            yield handle

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for handle in self.streams:
            handle.close()
        self.streams.clear()
|
52 |
+
|
53 |
+
class FileManager:
    """Upload local files into an OpenAI vector store.

    Tracks content checksums to skip duplicates within a session and
    lazily creates the backing vector store on first upload.
    """

    def __init__(self, client, prefix, batch_size=20):
        self.client = client
        # Prefix for generated vector store names.
        self.prefix = prefix
        # Maximum number of files uploaded per batch.
        self.batch_size = batch_size

        # BLAKE2b checksums of files already uploaded this session.
        self.storage = set()
        # Set once the vector store exists (see test_and_setup).
        self.vector_store_id = None

    def __bool__(self):
        # Truthy once a vector store has been created.
        return self.vector_store_id is not None

    def __iter__(self):
        """Yield the filename of every file in the vector store."""
        if self:
            kwargs = {}
            while True:
                vs_files = self.client.beta.vector_stores.files.list(
                    vector_store_id=self.vector_store_id,
                    **kwargs,
                )
                for f in vs_files.data:
                    # A vector-store file's id resolves via the Files
                    # API to recover the original filename.
                    result = self.client.files.retrieve(f.id)
                    yield result.filename

                if not vs_files.has_more:
                    break
                # NOTE(review): pagination cursor -- confirm the page
                # object exposes `after` (some SDK versions expose the
                # cursor as `last_id` instead).
                kwargs['after'] = vs_files.after

    def __call__(self, paths):
        """Upload any not-yet-seen files; return the full listing text."""
        files = []
        self.test_and_setup()

        for p in self.ls(paths):
            with FileStream(p) as stream:
                for s in stream:
                    if s.checksum not in self.storage:
                        files.append(s.fp)
                        self.storage.add(s.checksum)
                # Upload while the handles are still open; FileStream
                # closes them on exit.
                if files:
                    self.put(files)
                    files.clear()

        return '\n'.join(self)

    def test_and_setup(self):
        """Create the vector store on first use; warn if it already exists."""
        if self:
            msg = f'Vector store already exists ({self.vector_store_id})'
            warnings.warn(msg)
        else:
            name = f'{self.prefix}{uuid.uuid4()}'
            vector_store = self.client.beta.vector_stores.create(
                name=name,
            )
            self.vector_store_id = vector_store.id

    def ls(self, paths):
        """Yield `paths` as lists of Path objects, batch_size at a time."""
        left = 0
        while left < len(paths):
            right = left + self.batch_size
            yield list(map(Path, it.islice(paths, left, right)))
            left = right

    def put(self, files):
        """Upload one batch of open file handles and poll to completion.

        Raises InterruptedError when any file in the batch fails.
        """
        batch = self.client.beta.vector_stores.file_batches.upload_and_poll(
            vector_store_id=self.vector_store_id,
            files=files,
        )
        if batch.file_counts.completed != len(files):
            err = f'Error uploading documents: {batch.file_counts}'
            raise InterruptedError(err)
|
mylib/_logging.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import logging

# Configure root logging once, at import time. The level comes from the
# PYTHONLOGLEVEL environment variable (default WARNING).
logging.basicConfig(
    level=os.getenv('PYTHONLOGLEVEL', 'WARNING').upper(),
    format='[ %(asctime)s %(levelname)s %(filename)s ] %(message)s',
    datefmt='%H:%M:%S',
)
# Route warnings.warn() output through the logging system as well.
logging.captureWarnings(True)
Logger = logging.getLogger(__name__)
|
mylib/_message.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ._citations import NoCitations
|
2 |
+
|
3 |
+
class MessageHandler:
    """Render assistant messages to text, resolving citation annotations."""

    def __init__(self, client, citecls=None):
        self.client = client
        # Citation policy class; defaults to stripping citations.
        self.citecls = citecls or NoCitations

    def __call__(self, message):
        """Join every rendered content part with newlines."""
        return '\n'.join(self.each(message))

    def each(self, message):
        """Yield one rendered string per message content part.

        Reference numbers continue across parts: if one part used
        citations [1]-[3], the next part starts at [4].
        """
        refn = 1
        for m in message:
            for c in m.content:
                cite = self.citecls(c.text.annotations, self.client, refn)
                body = cite.replace(c.text.value)
                # Advance the running reference counter. The original
                # `refn = len(cite) + 1` restarted numbering relative to
                # 1, so any part after the first reused earlier numbers.
                refn += len(cite)
                yield f'{body}{cite}'
|
prompts/system.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
You are an expert file search assistant
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
openai
|
3 |
+
pandas
|