Upload app.py
app.py
ADDED
@@ -0,0 +1,187 @@
import os
from typing import List, TypedDict

import chromadb
import gradio as gr
import pandas as pd
from PyPDF2 import PdfReader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate  # ChatPromptTemplate is used by the grader below
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field  # langchain_core.pydantic_v1 is deprecated

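# The imports above assume the post-0.2 LangChain split packages are installed,
# roughly (a sketch; exact versions are an assumption, not pinned by this Space):
#   pip install gradio PyPDF2 chromadb pandas langchain-community \
#       langchain-huggingface langchain-nvidia-ai-endpoints langgraph
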
# Set API keys
os.environ["TAVILY_API_KEY"] = "YOUR_Tavily_API_KEY"
os.environ["NVIDIA_API_KEY"] = "YOUR_NVIDIA_API_KEY"
os.environ["LANGCHAIN_PROJECT"] = "RAG Compliance Checker"

# Initialize embedding model
model_name = "dunzhang/stella_en_1.5B_v5"
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={"trust_remote_code": True},
    show_progress=True,
)

# Define data models
class GradeDocuments(BaseModel):
    binary_score: str = Field(description="Relevance score 'yes' or 'no'")

class GraphState(TypedDict):
    question: str
    generation: str
    decision: str
    documents: List[Document]

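# GraphState is the shared state threaded through every workflow node below;
# each node returns a partial dict that LangGraph merges back into this state.
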
def create_workflow(retriever):
    # Define workflow nodes
    def retrieve(state):
        print("---RETRIEVING DOCUMENTS---")
        question = state["question"]
        documents = retriever.invoke(question)
        return {"documents": documents, "question": question}

    def grade_documents(state):
        print("---GRADING DOCUMENTS---")
        question = state["question"]
        documents = state["documents"]

        llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct")
        grader = llm.with_structured_output(GradeDocuments)

        system = """You are a relevance grader. Determine if the document contains
        information related to the question. Answer 'yes' or 'no'."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Document:\n{document}\n\nQuestion: {question}"),
        ])

        filtered_docs = []
        for doc in documents:
            response = (prompt | grader).invoke({
                "question": question,
                "document": doc.page_content,
            })
            if response.binary_score == "yes":
                filtered_docs.append(doc)

        return {"documents": filtered_docs, "question": question}

    def generate_response(state):
        print("---GENERATING RESPONSE---")
        question = state["question"]
        documents = state["documents"]

        template = """Answer the question using only the context below:
        Context: {context}
        Question: {question}"""

        prompt = PromptTemplate.from_template(template)
        llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct")

        chain = (
            {
                "context": lambda _: "\n\n".join(d.page_content for d in documents),
                "question": RunnablePassthrough(),
            }
            | prompt
            | llm
            | StrOutputParser()
        )

        return {"generation": chain.invoke(question)}

    # Build workflow
    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade", grade_documents)
    workflow.add_node("generate", generate_response)

    workflow.set_entry_point("retrieve")  # required; compile() fails without an entry point
    workflow.add_edge("retrieve", "grade")
    workflow.add_conditional_edges(
        "grade",
        lambda state: "generate" if len(state["documents"]) > 0 else END,
        # END must appear in the path map, or routing to it raises an error
        {"generate": "generate", END: END},
    )
    workflow.add_edge("generate", END)

    return workflow.compile()

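# Compiled graph topology (a sketch of the flow built above):
#
#   retrieve -> grade -> generate -> END
#                  \--------------> END   (when no documents survive grading)
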
def process_documents(file_paths):
    """Process PDF files from the uploaded folder.

    Note: gr.File(file_count="directory") hands back a list of file paths
    rather than a folder path, so we iterate over the list directly.
    """
    documents = []
    for path in file_paths or []:
        filename = os.path.basename(path)
        if filename.endswith(".pdf"):
            try:
                reader = PdfReader(path)
                # extract_text() may return None for image-only pages
                text = "\n".join(page.extract_text() or "" for page in reader.pages)
                documents.append(Document(
                    page_content=text,
                    metadata={"source": filename}
                ))
            except Exception as e:
                print(f"Error processing {filename}: {str(e)}")
    return documents

def analyze_requirements(csv_file, documents):
    """Main analysis function"""
    # Create vector store
    client = chromadb.PersistentClient()
    vector_store = Chroma(
        client=client,
        collection_name="dynamic_rag",
        embedding_function=embedding_model
    )

    # Add documents in batches; offset ids by the batch start so they stay
    # unique across batches instead of overwriting earlier entries
    batch_size = 500
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i + batch_size]
        vector_store.add_documents(batch, ids=[str(i + n) for n in range(len(batch))])

    retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    app = create_workflow(retriever)

    # Process requirements
    csv_path = getattr(csv_file, "name", csv_file)  # gradio may pass a path string or a file-like object
    df = pd.read_csv(csv_path)
    results = []

    for req in df["Requirement"]:
        response = app.invoke({"question": req})
        results.append({
            "Requirement": req,
            # "generation" is absent when grading filtered out every document
            "Response": response.get("generation", "No relevant documents found"),
            "Status": "Processed"
        })

    return pd.DataFrame(results)

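# Expected requirements CSV layout (illustrative rows; only the
# 'Requirement' column is read by analyze_requirements):
#
#   Requirement
#   "The system shall encrypt data at rest."
#   "Access to records shall be logged and auditable."
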
# Gradio interface
with gr.Blocks(title="RAG Compliance Checker") as interface:
    gr.Markdown("# AI Compliance Assistant")
    gr.Markdown("Upload documents and a requirements CSV for compliance analysis")

    with gr.Row():
        with gr.Column():
            doc_upload = gr.File(label="Upload Documents Folder", file_count="directory")
            csv_upload = gr.File(label="Upload Requirements CSV", file_types=[".csv"])
            submit_btn = gr.Button("Analyze", variant="primary")

        with gr.Column():
            results_table = gr.DataFrame(
                label="Analysis Results",
                headers=["Requirement", "Response", "Status"],
                interactive=False
            )
            status = gr.Textbox(label="Processing Status")

    submit_btn.click(
        fn=lambda doc, csv: analyze_requirements(csv, process_documents(doc)),
        inputs=[doc_upload, csv_upload],
        outputs=results_table,
        api_name="analyze"
    )

if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, share=True)
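
# Usage sketch (assumes real TAVILY/NVIDIA keys have been substituted above):
#   python app.py
# The Gradio UI is then served on port 7860, the default port a Hugging Face
# Space expects.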