shresthasingh committed
Commit 0667fd0
Parent(s): 4e15ad1
Create app.py
app.py ADDED
@@ -0,0 +1,171 @@
import gradio as gr
import os
from llama_parse import LlamaParse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.schema import Document as LangchainDocument

# Initialize global variables
vs_dict = {}

# Helper function to load and parse the input data
def mariela_parse(files):
    parser = LlamaParse(
        api_key=os.getenv("LLAMA_API_KEY"),
        result_type="markdown",
        verbose=True
    )
    parsed_documents = []
    for file in files:
        parsed_documents.extend(parser.load_data(file.name))
    return parsed_documents

# Create vector database
def mariela_create_vector_database(parsed_documents, collection_name):
    langchain_docs = [
        LangchainDocument(page_content=doc.text, metadata=doc.metadata)
        for doc in parsed_documents
    ]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=100)
    docs = text_splitter.split_documents(langchain_docs)

    embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")

    vs = Chroma.from_documents(
        documents=docs,
        embedding=embed_model,
        persist_directory="chroma_db",
        collection_name=collection_name
    )

    return vs

# Function to handle file upload and parsing
def mariela_upload_and_parse(files, collection_name):
    global vs_dict
    if not files:
        return "Please upload at least one file."

    parsed_documents = mariela_parse(files)
    vs = mariela_create_vector_database(parsed_documents, collection_name)

    vs_dict[collection_name] = vs

    return f"Files uploaded, parsed, and stored successfully in collection: {collection_name}"

# Function to handle retrieval
def mariela_retrieve(question, collection_name):
    global vs_dict
    if collection_name not in vs_dict:
        return f"Collection '{collection_name}' not found. Please upload and parse files for this collection first."

    vs = vs_dict[collection_name]
    results = vs.similarity_search(question, k=4)

    formatted_results = []
    for i, doc in enumerate(results, 1):
        formatted_results.append(f"Result {i}:\n{doc.page_content}\n\nMetadata: {doc.metadata}\n")

    return "\n\n".join(formatted_results)

# Supported file types list
supported_file_types = """
Supported Document Types:
- Base types: pdf
- Documents and presentations: 602, abw, cgm, cwk, doc, docx, docm, dot, dotm, hwp, key, lwp, mw, mcw, pages, pbd, ppt, pptm, pptx, pot, potm, potx, rtf, sda, sdd, sdp, sdw, sgl, sti, sxi, sxw, stw, sxg, txt, uof, uop, uot, vor, wpd, wps, xml, zabw, epub
- Images: jpg, jpeg, png, gif, bmp, svg, tiff, webp, web, htm, html
- Spreadsheets: xlsx, xls, xlsm, xlsb, xlw, csv, dif, sylk, slk, prn, numbers, et, ods, fods, uos1, uos2, dbf, wk1, wk2, wk3, wk4, wks, 123, wq1, wq2, wb1, wb2, wb3, qpw, xlr, eth, tsv
"""

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Mariela: Multi-Action Retrieval and Intelligent Extraction Learning Assistant")
    gr.Markdown("This application allows you to upload documents, parse them, and then ask questions to retrieve relevant information.")

    with gr.Tab("Upload and Parse Files"):
        gr.Markdown("## Upload and Parse Files")
        gr.Markdown("Upload your documents here to create a searchable knowledge base.")
        gr.Markdown("""
### API Documentation
1. **Confirm that you have cURL installed on your system.**

```bash
$ curl --version
```
2. **Find the API endpoint below corresponding to your desired function in the app.**

**API Name: `/mariela_upload`**

```bash
curl -X POST {url_of_gradio_app}/call/mariela_upload -s -H "Content-Type: application/json" -d '{
  "data": [
    [handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf')],
    "Hello!!"
]}' \
  | awk -F'"' '{ print $4}' \
  | read EVENT_ID; curl -N {url_of_gradio_app}/call/mariela_upload/$EVENT_ID
```

**Accepts 2 parameters:**

- **[0] any (Required):** The input value that is provided in the "Upload Files" File component.
- **[1] string (Required):** The input value that is provided in the "Collection Name" Textbox component.

**Returns 1 element:**

- **string:** The output value that appears in the "Status" Textbox component.
""")
        file_input = gr.File(label="Upload Files", file_count="multiple")
        collection_name_input = gr.Textbox(label="Collection Name")
        upload_button = gr.Button("Upload and Parse")
        upload_output = gr.Textbox(label="Status")

        upload_button.click(mariela_upload_and_parse, inputs=[file_input, collection_name_input], outputs=upload_output)

    with gr.Tab("Retrieval"):
        gr.Markdown("## Retrieval")
        gr.Markdown("Ask questions about your uploaded documents here.")
        gr.Markdown("""
### API Documentation
1. **Confirm that you have cURL installed on your system.**

```bash
$ curl --version
```
2. **Find the API endpoint below corresponding to your desired function in the app.**

**API Name: `/mariela_retrieve`**

```bash
curl -X POST {url_of_gradio_app}/call/mariela_retrieve -s -H "Content-Type: application/json" -d '{
  "data": [
    "Hello!!",
    "Hello!!"
]}' \
  | awk -F'"' '{ print $4}' \
  | read EVENT_ID; curl -N {url_of_gradio_app}/call/mariela_retrieve/$EVENT_ID
```

**Accepts 2 parameters:**

- **[0] string (Required):** The input value that is provided in the "Enter a query to retrieve relevant passages" Textbox component.
- **[1] string (Required):** The input value that is provided in the "Collection Name" Textbox component.

**Returns 1 element:**

- **string:** The output value that appears in the "Retrieved Passages" Textbox component.
""")
        collection_name_retrieval = gr.Textbox(label="Collection Name")
        question_input = gr.Textbox(label="Enter a query to retrieve relevant passages")
        retrieval_output = gr.Textbox(label="Retrieved Passages")
        retrieval_button = gr.Button("Retrieve")

        retrieval_button.click(mariela_retrieve, inputs=[question_input, collection_name_retrieval], outputs=retrieval_output)

    with gr.Tab("Supported Document Types"):
        gr.Markdown("## Supported Document Types")
        gr.Markdown(supported_file_types)

demo.launch(debug=True)
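
For reference, a minimal, untested sketch of how the two endpoints documented inside the app could be called from Python with the gradio_client package instead of cURL. The Space URL is a placeholder taken from the in-app docs, and the api_name values are assumptions: the docs advertise /mariela_upload, but Gradio normally derives the endpoint name from the handler function, which would make it /mariela_upload_and_parse.

# Hypothetical client-side usage sketch; not part of this commit.
from gradio_client import Client, handle_file

client = Client("{url_of_gradio_app}")  # placeholder for the running Space's URL

# Upload and parse a document into a named collection
# (api_name assumed to follow the handler function name; see note above)
status = client.predict(
    [handle_file("sample_file.pdf")],   # list of files for the "Upload Files" component
    "my_collection",                    # "Collection Name"
    api_name="/mariela_upload_and_parse",
)
print(status)

# Retrieve relevant passages from the same collection
passages = client.predict(
    "What is this document about?",     # query text
    "my_collection",                    # "Collection Name"
    api_name="/mariela_retrieve",
)
print(passages)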