Codequestt committed · verified
Commit 987c40f · 1 parent: 5c72581

Upload app.py

Files changed (1): app.py (+187, -0)
app.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import gradio as gr
+ from PyPDF2 import PdfReader
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_core.documents import Document
+ import chromadb
+ from langchain_community.vectorstores import Chroma
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ from pydantic import BaseModel, Field  # langchain_core.pydantic_v1 is deprecated
+ from langgraph.graph import StateGraph, END
+ from typing import List, TypedDict
+ import pandas as pd
+
+ # Set API keys (replace the placeholders before running)
+ os.environ["TAVILY_API_KEY"] = "YOUR_Tavily_API_KEY"
+ os.environ["NVIDIA_API_KEY"] = "YOUR_NVIDIA_API_KEY"
+ os.environ["LANGCHAIN_PROJECT"] = "RAG Compliance Checker"
+
+ # Initialize embedding model
+ model_name = "dunzhang/stella_en_1.5B_v5"
+ embedding_model = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs={'trust_remote_code': True},
+     show_progress=True
+ )
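+
+ # Note on the embedding model above: trust_remote_code=True lets transformers
+ # execute the custom modeling code that the stella repo ships on the Hub, so
+ # only enable it for model repos you trust.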
+
+ # Define data models
+ class GradeDocuments(BaseModel):
+     """Binary relevance grade for a single retrieved document."""
+     binary_score: str = Field(description="Relevance score 'yes' or 'no'")
+
+ class GraphState(TypedDict):
+     """State carried between the LangGraph nodes."""
+     question: str
+     generation: str
+     decision: str
+     documents: List[Document]
+
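+ # The workflow below moves a GraphState through three nodes,
+ # retrieve -> grade -> generate: grading prunes `documents`, and
+ # generation only runs when at least one document survives.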
+ def create_workflow(retriever):
+     # Define workflow nodes
+     def retrieve(state):
+         print("---RETRIEVING DOCUMENTS---")
+         question = state["question"]
+         documents = retriever.invoke(question)
+         return {"documents": documents, "question": question}
+
+     def grade_documents(state):
+         print("---GRADING DOCUMENTS---")
+         question = state["question"]
+         documents = state["documents"]
+
+         llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct")
+         grader = llm.with_structured_output(GradeDocuments)
+
+         system = """You are a relevance grader. Determine if the document contains
+         information related to the question. Answer 'yes' or 'no'."""
+         prompt = ChatPromptTemplate.from_messages([
+             ("system", system),
+             ("human", "Document:\n{document}\n\nQuestion: {question}")
+         ])
+
+         # Keep only the documents the grader marks relevant
+         filtered_docs = []
+         for doc in documents:
+             response = (prompt | grader).invoke({
+                 "question": question,
+                 "document": doc.page_content
+             })
+             if response.binary_score == "yes":
+                 filtered_docs.append(doc)
+
+         return {"documents": filtered_docs, "question": question}
+
+     def generate_response(state):
+         print("---GENERATING RESPONSE---")
+         question = state["question"]
+         documents = state["documents"]
+
+         template = """Answer the question using only the context below:
+         Context: {context}
+         Question: {question}"""
+
+         prompt = PromptTemplate.from_template(template)
+         llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct")
+
+         chain = (
+             {"context": lambda _: "\n\n".join(d.page_content for d in documents), "question": RunnablePassthrough()}
+             | prompt
+             | llm
+             | StrOutputParser()
+         )
+
+         return {"generation": chain.invoke(question)}
+
+     # Build workflow: retrieve -> grade -> (generate | END)
+     workflow = StateGraph(GraphState)
+     workflow.add_node("retrieve", retrieve)
+     workflow.add_node("grade", grade_documents)
+     workflow.add_node("generate", generate_response)
+
+     workflow.set_entry_point("retrieve")
+     workflow.add_edge("retrieve", "grade")
+     workflow.add_conditional_edges(
+         "grade",
+         lambda state: "generate" if len(state["documents"]) > 0 else END,
+         # END must appear in the mapping so the graph can stop when no
+         # documents survive grading
+         {"generate": "generate", END: END}
+     )
+     workflow.add_edge("generate", END)
+
+     return workflow.compile()
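+
+ # A minimal smoke test for the compiled graph (hypothetical usage, assuming a
+ # populated vector store):
+ #   graph = create_workflow(vector_store.as_retriever())
+ #   result = graph.invoke({"question": "Is data encrypted at rest?"})
+ #   print(result.get("generation", "no relevant documents"))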
+
+ def process_documents(file_list):
+     """Process uploaded PDFs (gr.File with file_count="directory" yields a list of files, not a folder path)"""
+     documents = []
+     for f in file_list or []:
+         path = f if isinstance(f, str) else f.name  # Gradio versions differ: str paths vs tempfile objects
+         filename = os.path.basename(path)
+         if filename.endswith(".pdf"):
+             try:
+                 reader = PdfReader(path)
+                 # extract_text() returns None for image-only pages; guard against it
+                 text = "\n".join([page.extract_text() or "" for page in reader.pages])
+                 documents.append(Document(
+                     page_content=text,
+                     metadata={"source": filename}
+                 ))
+             except Exception as e:
+                 print(f"Error processing {filename}: {str(e)}")
+     return documents
+
+ def analyze_requirements(csv_file, documents):
+     """Main analysis function"""
+     # Create vector store
+     client = chromadb.PersistentClient()
+     vector_store = Chroma(
+         client=client,
+         collection_name="dynamic_rag",
+         embedding_function=embedding_model
+     )
+
+     # Add documents in batches; offset ids by the batch start so later
+     # batches don't overwrite earlier ones
+     batch_size = 500
+     for i in range(0, len(documents), batch_size):
+         batch = documents[i:i+batch_size]
+         vector_store.add_documents(batch, ids=[str(i + n) for n in range(len(batch))])
+
+     retriever = vector_store.as_retriever(search_kwargs={"k": 5})
+     app = create_workflow(retriever)
+
+     # Process requirements (csv_file may be a path string or a tempfile object)
+     df = pd.read_csv(getattr(csv_file, "name", csv_file))
+     results = []
+
+     for req in df['Requirement']:
+         response = app.invoke({"question": req})
+         results.append({
+             "Requirement": req,
+             # "generation" is absent when grading filtered out every document
+             "Response": response.get("generation", "No relevant documents found"),
+             "Status": "Processed"
+         })
+
+     return pd.DataFrame(results)
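+
+ # The uploaded CSV must contain a 'Requirement' column; hypothetical example:
+ #   Requirement
+ #   "All personal data must be encrypted at rest."
+ #   "Access to audit logs is restricted to administrators."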
+
+ # Gradio interface
+ with gr.Blocks(title="RAG Compliance Checker") as interface:
+     gr.Markdown("# AI Compliance Assistant")
+     gr.Markdown("Upload documents and a requirements CSV for compliance analysis")
+
+     with gr.Row():
+         with gr.Column():
+             doc_upload = gr.File(label="Upload Documents Folder", file_count="directory")
+             csv_upload = gr.File(label="Upload Requirements CSV", file_types=[".csv"])
+             submit_btn = gr.Button("Analyze", variant="primary")
+
+         with gr.Column():
+             results_table = gr.DataFrame(
+                 label="Analysis Results",
+                 headers=["Requirement", "Response", "Status"],
+                 interactive=False
+             )
+             status = gr.Textbox(label="Processing Status")
+
+     submit_btn.click(
+         fn=lambda doc, csv: analyze_requirements(csv, process_documents(doc)),
+         inputs=[doc_upload, csv_upload],
+         outputs=results_table,
+         api_name="analyze"
+     )
+
+ if __name__ == "__main__":
+     interface.launch(server_name="0.0.0.0", server_port=7860, share=True)