jeffrey committed on
Commit
37c1830
1 Parent(s): 95681ec

init commit

Browse files
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
163
+ file_cache/
164
+ data/
165
+ init_project_dir/
166
+ !src/data
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os.path
3
+ import tempfile
4
+ import uuid
5
+ from typing import List
6
+
7
+ import gradio
8
+ import gradio as gr
9
+ import openai
10
+ import pandas as pd
11
+ from autorag.evaluator import Evaluator
12
+
13
+ from src.data.chunk import chunk
14
+ from src.data.parse import parse_pdf
15
+ from src.runner import GradioStreamRunner
16
+
17
+ root_dir = os.path.dirname(os.path.realpath(__file__))
18
+
19
+ pseudo_trial_yaml_path = os.path.join(root_dir, "config", "init_project_for_pseudo_trial.yaml")
20
+ init_run_yaml = os.path.join(root_dir, "config", "init_project_for_run.yaml")
21
+
22
+ gradio_runner = None
23
+
24
+ # Code for Task 1
25
+ def file_ingest(input_files: List[str], temp_project_dir, progress=gr.Progress()):
26
+ if os.getenv("OPENAI_API_KEY") is None:
27
+ return "Please submit your OpenAI API key first."
28
+ if not input_files:
29
+ return "Please upload a file first."
30
+ progress(0.05)
31
+ # do parse
32
+ raw_df = parse_pdf(file_lists=input_files)
33
+ progress(0.3)
34
+ # do chunk
35
+ corpus_df = chunk(raw_df, method="recursivecharacter",
36
+ lang="en", chunk_size=512, chunk_overlap=128)
37
+ progress(0.5)
38
+ asyncio.sleep(0.5)
39
+
40
+ # Logic for button click
41
+ empty_qa_df = make_empty_qa(corpus_df=corpus_df)
42
+ with tempfile.TemporaryDirectory() as temp_data_dir:
43
+ empty_qa_df.to_parquet(os.path.join(temp_data_dir, "empty_qa.parquet"))
44
+ corpus_df.to_parquet(os.path.join(temp_data_dir, "corpus.parquet"))
45
+
46
+ evaluator = Evaluator(qa_data_path=os.path.join(temp_data_dir, "empty_qa.parquet"),
47
+ corpus_data_path=os.path.join(temp_data_dir, "corpus.parquet"),
48
+ project_dir=temp_project_dir)
49
+ evaluator.start_trial(pseudo_trial_yaml_path, skip_validation=True)
50
+ yield "Setting up"
51
+ progress(0.9)
52
+ set_runner(temp_project_dir)
53
+ progress(1.0)
54
+ yield "File uploaded complete. You can use it at chatbot now."
55
+
56
+
57
+ def make_empty_qa(corpus_df: pd.DataFrame):
58
+ doc_id = corpus_df["doc_id"].iloc[0]
59
+ return pd.DataFrame({
60
+ "qid": str(uuid.uuid4()),
61
+ "query": ["Who is Kai Havertz?"],
62
+ "retrieval_gt": [[[doc_id]]],
63
+ "generation_gt": [["Havertz is the greatest footballer."]],
64
+ })
65
+
66
+
67
def on_submit_openai_key(openai_key):
    """Validate the submitted OpenAI API key with a tiny completion request.

    The key is stored in the environment first (the openai client reads it
    from there), but removed again if validation fails so later calls do not
    silently use a bad key.

    :param openai_key: The key text the user entered.
    :return: Status text for the "OpenAI API status" textbox.
    """
    os.environ["OPENAI_API_KEY"] = openai_key
    # Test openai key
    try:
        client = openai.OpenAI()
        response = client.chat.completions.create(
            messages=[
                {"role": "user", "content": "What is the capital of France?"},
            ],
            model="gpt-4o-mini",
            max_tokens=3,
        )
    except openai.AuthenticationError:
        # BUG FIX: `gr.Error` is an exception class — constructing it without
        # `raise` displays nothing. gr.Warning() is a callable toast that the
        # user actually sees. Also drop the invalid key from the environment.
        os.environ.pop("OPENAI_API_KEY", None)
        gr.Warning("OpenAI API key is invalid.", duration=3)
        return "Not Set"
    if not isinstance(response.choices[0].message.content, str):
        # BUG FIX: was `assert` (stripped under `python -O`) plus an unraised
        # gr.Error; use an explicit check and a visible warning instead.
        gr.Warning("OpenAI server is not working properly.", duration=3)
        return "Not Set"
    gr.Info("OpenAI API key submitted.", duration=3)
    return "Setting complete."
88
+
89
+
90
def set_runner(project_dir):
    """Build a GradioStreamRunner for *project_dir* and install it globally.

    :param project_dir: AutoRAG project directory produced by file_ingest.
    """
    global gradio_runner
    gradio_runner = GradioStreamRunner.from_yaml(
        yaml_path=init_run_yaml, project_dir=project_dir
    )
94
+
95
+
96
def get_response(message, history):
    """Stream a chatbot answer for *message* from the global runner.

    Guard clauses warn (and stop) when the runner or the API key is missing;
    otherwise the runner's streamed text deltas are yielded one by one.
    ``history`` is accepted per the ChatInterface contract but unused.
    """
    # Reading a module-level name needs no `global` declaration.
    if gradio_runner is None:
        gradio.Warning("Please set the AutoRAG server first.")
        return
    if os.getenv("OPENAI_API_KEY", None) is None:
        gradio.Warning("Please submit your OpenAI API key first.")
        return

    # stream_run yields (delta, passages); the chat UI only needs the text.
    for output in gradio_runner.stream_run(message):
        yield output[0]
107
+
108
+
109
# interface one
# BUG FIX: the original built the UI inside
# `with tempfile.TemporaryDirectory() as project_dir:` — that directory is
# deleted as soon as the `with` block exits, i.e. BEFORE demo.launch() runs and
# before any callback fires, so file_ingest would write into a removed path.
# mkdtemp() keeps the directory alive for the lifetime of the process.
project_dir = tempfile.mkdtemp()

with gr.Blocks(theme="earneleh/paris") as demo:
    # Define components
    with gr.Row():
        with gr.Column(scale=3):
            textbox = gr.Textbox(label="Please input your OpenAI API key and press Enter.", type="password",
                                 info="You can get your API key from https://platform.openai.com/account/api-keys\n"
                                      "AutoRAG do not store your API key.",
                                 autofocus=True)
            api_key_status_box = gr.Textbox(label="OpenAI API status", value="Not Set", interactive=False)

            gr.Markdown("## Ingest Your Data")

            file_input = gr.File(label="Upload Files", type="filepath", file_count="multiple")
            button = gr.Button("Submit file")
            text_output = gr.Textbox(label="Status update", interactive=False)

            # Define layout and interactions
            textbox.submit(on_submit_openai_key, inputs=[textbox], outputs=api_key_status_box)
            button.click(file_ingest, inputs=[file_input, gr.State(project_dir)], outputs=[text_output])

        with gr.Column(scale=7):
            gr.ChatInterface(
                get_response, title="This is your Naive RAG Chatbot 🚀", retry_btn=None, undo_btn=None,
            )

    gr.Markdown("## Do you like the result?\n\nIf you don't like it, try to optimize it with AutoRAG. Press below button and go to make evaluation data and optimize it. Both on the Huggingface space so you don't need to install anything.")
    with gr.Row():
        open_data_creation = gr.Button(value="1️⃣ : Data Creation",
                                       link="https://huggingface.co/spaces/AutoRAG/AutoRAG-data-creation")
        open_optimize = gr.Button(value="2️⃣ : Optimize", link="https://www.auto-rag.com/")


demo.launch(share=False, debug=True)
config/init_project_for_pseudo_trial.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ node_lines:
2
+ - node_line_name: retrieve_node_line
3
+ nodes:
4
+ - node_type: retrieval # Set Retrieval Node
5
+ strategy:
6
+ metrics: [retrieval_f1, retrieval_recall] # Set Retrieval Metrics
7
+ top_k: 3
8
+ modules:
9
+ - module_type: vectordb
10
+ embedding_model: openai
11
+ - module_type: bm25
config/init_project_for_run.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ node_lines:
2
+ - node_line_name: retrieve_node_line
3
+ nodes:
4
+ - modules:
5
+ - module_type: vectordb
6
+ embedding_model: openai
7
+ top_k: 5
8
+ node_type: retrieval
9
+ strategy:
10
+ metrics:
11
+ - retrieval_f1
12
+ - retrieval_recall
13
+ - retrieval_precision
14
+ - node_line_name: post_retrieve_node_line
15
+ nodes:
16
+ - modules:
17
+ - module_type: fstring
18
+ prompt: "You are the helpful assistant to answer the question. I will give you a context to read. The context can be unrelated to the question.
19
+ If the context is related, you must answer the question based on the context.
20
+ If there is no context that relates to the question, you must say that you don't know about the answer.
21
+ DO NOT MAKE UP THE ANSWER.
22
+ If you can solve the question with your own knowledge, you can answer the question. But please do not lie or make up the answer without relevant information.
23
+ Question: {query} \n Context: {retrieved_contents} \n Answer : "
24
+ node_type: prompt_maker
25
+ strategy:
26
+ metrics:
27
+ - bleu
28
+ - meteor
29
+ - rouge
30
+ - modules:
31
+ - llm: openai
32
+ model: gpt-4o-mini
33
+ module_type: llama_index_llm
34
+ temperature: 1.0
35
+ node_type: generator
36
+ strategy:
37
+ metrics:
38
+ - rouge
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gcc
2
+ poppler-utils
3
+ tesseract-ocr
4
+ libssl-dev
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ AutoRAG[parse,ko,ja]>=0.3.5
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/chunk.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from autorag.data.chunk import langchain_chunk
3
+
4
+
5
def chunk(raw_df: pd.DataFrame, method: str, lang: str = "en", **kwargs) -> pd.DataFrame:
    """Chunk a parsed-document DataFrame with a LangChain chunker.

    NOTE(review): ``lang`` is forwarded as ``add_file_name`` — confirm against
    autorag's ``langchain_chunk`` signature that a language code is what that
    parameter expects.

    :param raw_df: Output of the parse step.
    :param method: LangChain chunk method name (e.g. "recursivecharacter").
    :param lang: Language code forwarded to ``add_file_name``.
    :param kwargs: Extra chunker options (chunk_size, chunk_overlap, ...).
    :return: The chunked corpus DataFrame.
    """
    return langchain_chunk(raw_df, chunk_method=method, add_file_name=lang, **kwargs)
src/data/parse.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, List
2
+
3
+ from autorag.data.parse import langchain_parse
4
+ from autorag.data.parse.base import _add_last_modified_datetime
5
+ from autorag.utils import result_to_dataframe
6
+
7
+
8
@result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
def original_parse(fn: Callable, **kwargs):
    """Run a raw parse callable and stamp records with last-modified datetimes.

    :param fn: Undecorated autorag parse function to invoke.
    :param kwargs: Arguments forwarded verbatim to ``fn``.
    :return: Parse result (converted to a DataFrame by the decorator).
    """
    parsed = fn(**kwargs)
    return _add_last_modified_datetime(parsed)
13
+
14
def parse_pdf(file_lists: List[str], parse_method: str = "pdfminer"):
    """Parse PDF files into a raw DataFrame.

    Uses ``langchain_parse.__wrapped__`` (the undecorated function) so
    ``original_parse`` can apply its own DataFrame conversion.

    :param file_lists: Paths of the PDF files to parse.
    :param parse_method: LangChain parse backend name.
    :return: Raw parsed DataFrame.
    """
    return original_parse(
        langchain_parse.__wrapped__,
        data_path_list=file_lists,
        parse_method=parse_method,
    )
src/runner.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from typing import List, Dict, Optional
4
+
5
+ import pandas as pd
6
+ from autorag.deploy import GradioRunner
7
+ from autorag.deploy.api import RetrievedPassage
8
+ from autorag.nodes.generator.base import BaseGenerator
9
+ from autorag.utils import fetch_contents
10
+
11
# Placeholder passage yielded alongside streamed tokens; real passage
# extraction (extract_retrieve_passage) is currently commented out below.
empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)
14
+
15
class GradioStreamRunner(GradioRunner):
    """GradioRunner variant that streams the generator module's output.

    Attributes:
        corpus_df: the project's corpus, loaded once so retrieved doc ids can
            be resolved back to passage text in extract_retrieve_passage.
    """

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """Initialize the runner and load the project's corpus parquet.

        :param config: Runner configuration dict (passed to GradioRunner).
        :param project_dir: AutoRAG project directory containing data/corpus.parquet.
        """
        super().__init__(config, project_dir)

        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )

    def stream_run(self, query: str):
        """Run the pipeline on *query*, streaming the final generator's deltas.

        Non-generator modules run to completion, each module's output merged
        into the accumulated result frame; the generator module is streamed.

        Yields:
            (delta_text, [RetrievedPassage]) tuples — currently the passage
            list is always the placeholder ``empty_retrieved_passage``.
        """
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )  # pseudo qa data for execution

        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            if not isinstance(module_instance, BaseGenerator):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                # Columns produced by the new module replace same-named ones
                # from the previous result before the horizontal concat.
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(
                    columns=duplicated_columns
                )
                previous_result = pd.concat(
                    [drop_previous_result, new_result], axis=1
                )
            else:
                # retrieved_passages = self.extract_retrieve_passage(
                #     previous_result
                # )
                # yield "", retrieved_passages
                # Start streaming of the result
                # Single-query pipeline: exactly one row is expected here.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                for delta in module_instance.stream(prompt=prompt, **module_param):
                    yield delta, [empty_retrieved_passage]

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """Convert the first row's retrieved ids into RetrievedPassage objects.

        Optional corpus columns ("path", "start_end_idx") fall back to None
        per passage when absent.

        :param df: Result frame containing a "retrieved_ids" column.
        :return: One RetrievedPassage per retrieved doc id.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
                0
            ]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        return list(
            map(
                lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    filepath=path,
                    file_page=metadata.get("page", None),
                    start_idx=start_end_idx[0] if start_end_idx else None,
                    end_idx=start_end_idx[1] if start_end_idx else None,
                ),
                contents,
                retrieved_ids,
                paths,
                metadatas,
                start_end_indices,
            )
        )