Spaces:
Running
Running
jeffrey
commited on
Commit
•
37c1830
1
Parent(s):
95681ec
init commit
Browse files- .gitignore +166 -0
- app.py +143 -0
- config/init_project_for_pseudo_trial.yaml +11 -0
- config/init_project_for_run.yaml +38 -0
- packages.txt +4 -0
- requirements.txt +1 -0
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/chunk.py +7 -0
- src/data/parse.py +16 -0
- src/runner.py +96 -0
.gitignore
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
.idea/
|
163 |
+
file_cache/
|
164 |
+
data/
|
165 |
+
init_project_dir/
|
166 |
+
!src/data
|
app.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os.path
|
3 |
+
import tempfile
|
4 |
+
import uuid
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import gradio
|
8 |
+
import gradio as gr
|
9 |
+
import openai
|
10 |
+
import pandas as pd
|
11 |
+
from autorag.evaluator import Evaluator
|
12 |
+
|
13 |
+
from src.data.chunk import chunk
|
14 |
+
from src.data.parse import parse_pdf
|
15 |
+
from src.runner import GradioStreamRunner
|
16 |
+
|
17 |
+
root_dir = os.path.dirname(os.path.realpath(__file__))
|
18 |
+
|
19 |
+
pseudo_trial_yaml_path = os.path.join(root_dir, "config", "init_project_for_pseudo_trial.yaml")
|
20 |
+
init_run_yaml = os.path.join(root_dir, "config", "init_project_for_run.yaml")
|
21 |
+
|
22 |
+
gradio_runner = None
|
23 |
+
|
24 |
+
# Code for Task 1
|
25 |
+
def file_ingest(input_files: List[str], temp_project_dir, progress=gr.Progress()):
|
26 |
+
if os.getenv("OPENAI_API_KEY") is None:
|
27 |
+
return "Please submit your OpenAI API key first."
|
28 |
+
if not input_files:
|
29 |
+
return "Please upload a file first."
|
30 |
+
progress(0.05)
|
31 |
+
# do parse
|
32 |
+
raw_df = parse_pdf(file_lists=input_files)
|
33 |
+
progress(0.3)
|
34 |
+
# do chunk
|
35 |
+
corpus_df = chunk(raw_df, method="recursivecharacter",
|
36 |
+
lang="en", chunk_size=512, chunk_overlap=128)
|
37 |
+
progress(0.5)
|
38 |
+
asyncio.sleep(0.5)
|
39 |
+
|
40 |
+
# Logic for button click
|
41 |
+
empty_qa_df = make_empty_qa(corpus_df=corpus_df)
|
42 |
+
with tempfile.TemporaryDirectory() as temp_data_dir:
|
43 |
+
empty_qa_df.to_parquet(os.path.join(temp_data_dir, "empty_qa.parquet"))
|
44 |
+
corpus_df.to_parquet(os.path.join(temp_data_dir, "corpus.parquet"))
|
45 |
+
|
46 |
+
evaluator = Evaluator(qa_data_path=os.path.join(temp_data_dir, "empty_qa.parquet"),
|
47 |
+
corpus_data_path=os.path.join(temp_data_dir, "corpus.parquet"),
|
48 |
+
project_dir=temp_project_dir)
|
49 |
+
evaluator.start_trial(pseudo_trial_yaml_path, skip_validation=True)
|
50 |
+
yield "Setting up"
|
51 |
+
progress(0.9)
|
52 |
+
set_runner(temp_project_dir)
|
53 |
+
progress(1.0)
|
54 |
+
yield "File uploaded complete. You can use it at chatbot now."
|
55 |
+
|
56 |
+
|
57 |
+
def make_empty_qa(corpus_df: pd.DataFrame):
|
58 |
+
doc_id = corpus_df["doc_id"].iloc[0]
|
59 |
+
return pd.DataFrame({
|
60 |
+
"qid": str(uuid.uuid4()),
|
61 |
+
"query": ["Who is Kai Havertz?"],
|
62 |
+
"retrieval_gt": [[[doc_id]]],
|
63 |
+
"generation_gt": [["Havertz is the greatest footballer."]],
|
64 |
+
})
|
65 |
+
|
66 |
+
|
67 |
+
def on_submit_openai_key(openai_key):
|
68 |
+
os.environ["OPENAI_API_KEY"] = openai_key
|
69 |
+
# Test openai key
|
70 |
+
try:
|
71 |
+
client = openai.OpenAI()
|
72 |
+
response = client.chat.completions.create(
|
73 |
+
messages=[
|
74 |
+
{"role": "user", "content": "What is the capital of France?"},
|
75 |
+
],
|
76 |
+
model="gpt-4o-mini",
|
77 |
+
max_tokens=3,
|
78 |
+
)
|
79 |
+
assert isinstance(response.choices[0].message.content, str)
|
80 |
+
gr.Info("OpenAI API key submitted.", duration=3)
|
81 |
+
return "Setting complete."
|
82 |
+
except openai.AuthenticationError as e:
|
83 |
+
gr.Error("OpenAI API key is invalid.", duration=3)
|
84 |
+
return "Not Set"
|
85 |
+
except AssertionError as e:
|
86 |
+
gr.Error("OpenAI server is not working properly.", duration=3)
|
87 |
+
return "Not Set"
|
88 |
+
|
89 |
+
|
90 |
+
def set_runner(project_dir):
|
91 |
+
runner = GradioStreamRunner.from_yaml(yaml_path=init_run_yaml, project_dir=project_dir)
|
92 |
+
global gradio_runner
|
93 |
+
gradio_runner = runner
|
94 |
+
|
95 |
+
|
96 |
+
def get_response(message, history):
|
97 |
+
global gradio_runner
|
98 |
+
if gradio_runner is None:
|
99 |
+
gradio.Warning("Please set the AutoRAG server first.")
|
100 |
+
return
|
101 |
+
if os.getenv("OPENAI_API_KEY", None) is None:
|
102 |
+
gradio.Warning("Please submit your OpenAI API key first.")
|
103 |
+
return
|
104 |
+
|
105 |
+
for output in gradio_runner.stream_run(message):
|
106 |
+
yield output[0]
|
107 |
+
|
108 |
+
|
109 |
+
# interface one
|
110 |
+
with gr.Blocks(theme="earneleh/paris") as demo:
|
111 |
+
with tempfile.TemporaryDirectory() as project_dir:
|
112 |
+
# Define components
|
113 |
+
with gr.Row():
|
114 |
+
with gr.Column(scale=3):
|
115 |
+
textbox = gr.Textbox(label="Please input your OpenAI API key and press Enter.", type="password",
|
116 |
+
info="You can get your API key from https://platform.openai.com/account/api-keys\n"
|
117 |
+
"AutoRAG do not store your API key.",
|
118 |
+
autofocus=True)
|
119 |
+
api_key_status_box = gr.Textbox(label="OpenAI API status", value="Not Set", interactive=False)
|
120 |
+
|
121 |
+
gr.Markdown("## Ingest Your Data")
|
122 |
+
|
123 |
+
file_input = gr.File(label="Upload Files", type="filepath", file_count="multiple")
|
124 |
+
button = gr.Button("Submit file")
|
125 |
+
text_output = gr.Textbox(label="Status update", interactive=False)
|
126 |
+
|
127 |
+
# Define layout and interactions
|
128 |
+
textbox.submit(on_submit_openai_key, inputs=[textbox], outputs=api_key_status_box)
|
129 |
+
button.click(file_ingest, inputs=[file_input, gr.State(project_dir)], outputs=[text_output])
|
130 |
+
|
131 |
+
with gr.Column(scale=7):
|
132 |
+
gr.ChatInterface(
|
133 |
+
get_response, title="This is your Naive RAG Chatbot 🚀", retry_btn=None, undo_btn=None,
|
134 |
+
)
|
135 |
+
|
136 |
+
gr.Markdown("## Do you like the result?\n\nIf you don't like it, try to optimize it with AutoRAG. Press below button and go to make evaluation data and optimize it. Both on the Huggingface space so you don't need to install anything.")
|
137 |
+
with gr.Row():
|
138 |
+
open_data_creation = gr.Button(value="1️⃣ : Data Creation",
|
139 |
+
link="https://huggingface.co/spaces/AutoRAG/AutoRAG-data-creation")
|
140 |
+
open_optimize = gr.Button(value="2️⃣ : Optimize", link="https://www.auto-rag.com/")
|
141 |
+
|
142 |
+
|
143 |
+
demo.launch(share=False, debug=True)
|
config/init_project_for_pseudo_trial.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
node_lines:
|
2 |
+
- node_line_name: retrieve_node_line
|
3 |
+
nodes:
|
4 |
+
- node_type: retrieval # Set Retrieval Node
|
5 |
+
strategy:
|
6 |
+
metrics: [retrieval_f1, retrieval_recall] # Set Retrieval Metrics
|
7 |
+
top_k: 3
|
8 |
+
modules:
|
9 |
+
- module_type: vectordb
|
10 |
+
embedding_model: openai
|
11 |
+
- module_type: bm25
|
config/init_project_for_run.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
node_lines:
|
2 |
+
- node_line_name: retrieve_node_line
|
3 |
+
nodes:
|
4 |
+
- modules:
|
5 |
+
- module_type: vectordb
|
6 |
+
embedding_model: openai
|
7 |
+
top_k: 5
|
8 |
+
node_type: retrieval
|
9 |
+
strategy:
|
10 |
+
metrics:
|
11 |
+
- retrieval_f1
|
12 |
+
- retrieval_recall
|
13 |
+
- retrieval_precision
|
14 |
+
- node_line_name: post_retrieve_node_line
|
15 |
+
nodes:
|
16 |
+
- modules:
|
17 |
+
- module_type: fstring
|
18 |
+
prompt: "You are the helpful assistant to answer the question. I will give you a context to read. The context can be unrelated to the question.
|
19 |
+
If the context is related, you must answer the question base on the context.
|
20 |
+
If there is no context that relates to the question, you must say that you don't know about the answer.
|
21 |
+
DO NOT MAKE UP THE ANSWER.
|
22 |
+
If you can solve the question with your own knowledge, you can answer the question. But please do not lie or make up the answer without relevant information.
|
23 |
+
Question: {query} \n Context: {retrieved_contents} \n Answer : "
|
24 |
+
node_type: prompt_maker
|
25 |
+
strategy:
|
26 |
+
metrics:
|
27 |
+
- bleu
|
28 |
+
- meteor
|
29 |
+
- rouge
|
30 |
+
- modules:
|
31 |
+
- llm: openai
|
32 |
+
model: gpt-4o-mini
|
33 |
+
module_type: llama_index_llm
|
34 |
+
temperature: 1.0
|
35 |
+
node_type: generator
|
36 |
+
strategy:
|
37 |
+
metrics:
|
38 |
+
- rouge
|
packages.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gcc
|
2 |
+
poppler-utils
|
3 |
+
tesseract-ocr
|
4 |
+
libssl-dev
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
AutoRAG[parse,ko,ja]>=0.3.5
|
src/__init__.py
ADDED
File without changes
|
src/data/__init__.py
ADDED
File without changes
|
src/data/chunk.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from autorag.data.chunk import langchain_chunk
|
3 |
+
|
4 |
+
|
5 |
+
def chunk(raw_df: pd.DataFrame, method: str, lang: str = "en", **kwargs) -> pd.DataFrame:
|
6 |
+
corpus_df = langchain_chunk(raw_df, chunk_method=method, add_file_name=lang, **kwargs)
|
7 |
+
return corpus_df
|
src/data/parse.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, List
|
2 |
+
|
3 |
+
from autorag.data.parse import langchain_parse
|
4 |
+
from autorag.data.parse.base import _add_last_modified_datetime
|
5 |
+
from autorag.utils import result_to_dataframe
|
6 |
+
|
7 |
+
|
8 |
+
@result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
|
9 |
+
def original_parse(fn: Callable, **kwargs):
|
10 |
+
result = fn(**kwargs)
|
11 |
+
result = _add_last_modified_datetime(result)
|
12 |
+
return result
|
13 |
+
|
14 |
+
def parse_pdf(file_lists: List[str], parse_method: str = "pdfminer"):
|
15 |
+
raw_df = original_parse(langchain_parse.__wrapped__, data_path_list=file_lists, parse_method=parse_method)
|
16 |
+
return raw_df
|
src/runner.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
from typing import List, Dict, Optional
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from autorag.deploy import GradioRunner
|
7 |
+
from autorag.deploy.api import RetrievedPassage
|
8 |
+
from autorag.nodes.generator.base import BaseGenerator
|
9 |
+
from autorag.utils import fetch_contents
|
10 |
+
|
11 |
+
empty_retrieved_passage = RetrievedPassage(
|
12 |
+
content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
|
13 |
+
)
|
14 |
+
|
15 |
+
class GradioStreamRunner(GradioRunner):
|
16 |
+
def __init__(self, config: Dict, project_dir: Optional[str] = None):
|
17 |
+
super().__init__(config, project_dir)
|
18 |
+
|
19 |
+
data_dir = os.path.join(project_dir, "data")
|
20 |
+
self.corpus_df = pd.read_parquet(
|
21 |
+
os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
|
22 |
+
)
|
23 |
+
|
24 |
+
def stream_run(self, query: str):
|
25 |
+
previous_result = pd.DataFrame(
|
26 |
+
{
|
27 |
+
"qid": str(uuid.uuid4()),
|
28 |
+
"query": [query],
|
29 |
+
"retrieval_gt": [[]],
|
30 |
+
"generation_gt": [""],
|
31 |
+
}
|
32 |
+
) # pseudo qa data for execution
|
33 |
+
|
34 |
+
for module_instance, module_param in zip(
|
35 |
+
self.module_instances, self.module_params
|
36 |
+
):
|
37 |
+
if not isinstance(module_instance, BaseGenerator):
|
38 |
+
new_result = module_instance.pure(
|
39 |
+
previous_result=previous_result, **module_param
|
40 |
+
)
|
41 |
+
duplicated_columns = previous_result.columns.intersection(
|
42 |
+
new_result.columns
|
43 |
+
)
|
44 |
+
drop_previous_result = previous_result.drop(
|
45 |
+
columns=duplicated_columns
|
46 |
+
)
|
47 |
+
previous_result = pd.concat(
|
48 |
+
[drop_previous_result, new_result], axis=1
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
# retrieved_passages = self.extract_retrieve_passage(
|
52 |
+
# previous_result
|
53 |
+
# )
|
54 |
+
# yield "", retrieved_passages
|
55 |
+
# Start streaming of the result
|
56 |
+
assert len(previous_result) == 1
|
57 |
+
prompt: str = previous_result["prompts"].tolist()[0]
|
58 |
+
for delta in module_instance.stream(prompt=prompt, **module_param):
|
59 |
+
yield delta, [empty_retrieved_passage]
|
60 |
+
|
61 |
+
|
62 |
+
def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
|
63 |
+
retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
|
64 |
+
contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
|
65 |
+
if "path" in self.corpus_df.columns:
|
66 |
+
paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
|
67 |
+
0
|
68 |
+
]
|
69 |
+
else:
|
70 |
+
paths = [None] * len(retrieved_ids)
|
71 |
+
metadatas = fetch_contents(
|
72 |
+
self.corpus_df, [retrieved_ids], column_name="metadata"
|
73 |
+
)[0]
|
74 |
+
if "start_end_idx" in self.corpus_df.columns:
|
75 |
+
start_end_indices = fetch_contents(
|
76 |
+
self.corpus_df, [retrieved_ids], column_name="start_end_idx"
|
77 |
+
)[0]
|
78 |
+
else:
|
79 |
+
start_end_indices = [None] * len(retrieved_ids)
|
80 |
+
return list(
|
81 |
+
map(
|
82 |
+
lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
|
83 |
+
content=content,
|
84 |
+
doc_id=doc_id,
|
85 |
+
filepath=path,
|
86 |
+
file_page=metadata.get("page", None),
|
87 |
+
start_idx=start_end_idx[0] if start_end_idx else None,
|
88 |
+
end_idx=start_end_idx[1] if start_end_idx else None,
|
89 |
+
),
|
90 |
+
contents,
|
91 |
+
retrieved_ids,
|
92 |
+
paths,
|
93 |
+
metadatas,
|
94 |
+
start_end_indices,
|
95 |
+
)
|
96 |
+
)
|