# mvp / app.py
# Math
# Add functions
# 22ecb24
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Dict, Tuple
import streamlit as st
import time
from dotenv import load_dotenv
load_dotenv()
from step_1_index_documents import index_documents
from step_2_compare_documents import compare_documents
from step_3_evaluate_documents import evaluate_documents
from step_4_generate_response import generate_response
class ProcessingStep(IntEnum):
    """Ordered stages of the document-processing pipeline.

    The values are consecutive integers starting at 0, so the stage after a
    given one can be obtained with ``ProcessingStep(step + 1)``.
    """

    IDLE, INDEXING, COMPARING, EVALUATING, GENERATING, COMPLETED = range(6)
# Human-readable label for each active pipeline stage. It is used as spinner
# text and in the completed-stage expander headers. IDLE and COMPLETED
# intentionally have no entry.
step_messages = dict(
    zip(
        (
            ProcessingStep.INDEXING,
            ProcessingStep.COMPARING,
            ProcessingStep.EVALUATING,
            ProcessingStep.GENERATING,
        ),
        (
            "Step 1: Indexing documents",
            "Step 2: Comparing documents",
            "Step 3: Evaluating documents",
            "Step 4: Generating response",
        ),
    )
)
@dataclass
class SessionState:
    """Per-browser-session state of the processing pipeline.

    A single instance is kept in ``st.session_state.state``. Because Streamlit
    preserves ``st.session_state`` across reruns, progress is carried from one
    step to the next through this object.
    """

    # Stage to execute on the next script run; IDLE until the start button is pressed.
    current_step: ProcessingStep = ProcessingStep.IDLE
    # Finished stages -> (result message, details) as returned by each step
    # function. The details are passed to st.markdown, so they are presumably
    # markdown text. A None value is replaced with {} in __post_init__.
    completed_steps: Optional[Dict[ProcessingStep, Tuple[str, str]]] = None
    # Whatever st.file_uploader returned: UploadedFile objects, not file names.
    ground_truth_files: Optional[List] = None
    proposal_files: Optional[List] = None
    # Final report shown once all stages have completed. It is presumably
    # assigned by the response-generation step (TODO confirm). The default
    # avoids an AttributeError when rendering if nothing sets it.
    final_text: str = ""

    def __post_init__(self) -> None:
        """Replace ``None`` collection fields with fresh per-instance containers."""
        # Mutable containers cannot be shared class-level defaults in a
        # dataclass, so each instance gets its own here.
        if self.completed_steps is None:
            self.completed_steps = {}
        if self.ground_truth_files is None:
            self.ground_truth_files = []
        if self.proposal_files is None:
            self.proposal_files = []
def initialize_session_state() -> None:
    """Create the pipeline's SessionState the first time this session runs.

    On later reruns the existing object is found and left untouched, so any
    progress already recorded is preserved.
    """
    if "state" in st.session_state:
        return
    st.session_state.state = SessionState()
def upload_files(label: str) -> List:
    """Render a multi-file uploader and return what the user has uploaded.

    Args:
        label: Text shown above the uploader widget. Streamlit derives a
            widget's identity from its parameters, so each call site should
            use a distinct label.

    Returns:
        A list of the uploaded file objects returned by ``st.file_uploader``
        (file contents, not just names). The list is empty when nothing has
        been uploaded.
    """
    uploaded_files = st.file_uploader(label, accept_multiple_files=True)
    # Normalize a possible None to [] so callers can always iterate the result.
    return uploaded_files or []
def render_file_upload_sections() -> Tuple[List[str], List[str]]:
    """Show the ground-truth and proposal uploaders in two side-by-side columns.

    Each column's selection is also written onto the session's SessionState
    so later pipeline stages can read it.

    Returns:
        ``(ground_truth_files, proposal_files)``, exactly as returned by
        ``upload_files``.
    """
    columns = st.columns(2)
    sections = (
        ("Ground Truth", "Upload Ground Truth Files", "ground_truth_files"),
        ("Proposals", "Upload Proposal Files", "proposal_files"),
    )

    selections = []
    for column, (heading, label, attr_name) in zip(columns, sections):
        with column:
            st.subheader(heading)
            files = upload_files(label)
            setattr(st.session_state.state, attr_name, files)
        selections.append(files)

    ground_truth, proposals = selections
    return ground_truth, proposals
def process_step(step: ProcessingStep) -> Optional[str]:
    """Run one pipeline stage and record its outcome on the session state.

    Args:
        step: The stage to execute: INDEXING, COMPARING, EVALUATING or
            GENERATING. IDLE and COMPLETED have no label in ``step_messages``
            and raise ``KeyError``.

    Returns:
        The result message produced by the stage. When it is truthy, it is
        stored together with the stage's details in
        ``state.completed_steps[step]``.
    """
    state = st.session_state.state
    # Lambdas defer each call so that only the requested stage actually runs.
    step_functions = {
        ProcessingStep.INDEXING: lambda: index_documents(
            state.ground_truth_files, state.proposal_files, st
        ),
        ProcessingStep.COMPARING: lambda: compare_documents(st),
        ProcessingStep.EVALUATING: lambda: evaluate_documents(st),
        ProcessingStep.GENERATING: lambda: generate_response(st),
    }
    # Use the requested step rather than re-reading state.current_step, so the
    # function does what its parameter says.
    with st.spinner(f"{step_messages[step]}..."):
        result, details = step_functions[step]()
    if result:
        state.completed_steps[step] = (result, details)
    return result
def get_next_step(current_step: ProcessingStep) -> ProcessingStep:
    """Return the stage that follows ``current_step``.

    COMPLETED is terminal and maps to itself. Every other stage advances to the
    member whose value is one higher.
    """
    at_end = current_step == ProcessingStep.COMPLETED
    return ProcessingStep.COMPLETED if at_end else ProcessingStep(current_step + 1)
def render_processing_button() -> None:
    """Render the primary button that starts the pipeline.

    The button is enabled only while the pipeline is IDLE. Clicking it moves
    the session to INDEXING. Streamlit then reruns the script and
    ``handle_processing_steps`` executes that stage.
    """
    state = st.session_state.state

    def start_processing() -> None:
        # on_click callbacks run before the script reruns, so the new stage
        # is already set when the next script execution starts.
        state.current_step = ProcessingStep.INDEXING

    if state.current_step == ProcessingStep.IDLE:
        button_text = "Start Processing"
    elif state.current_step == ProcessingStep.COMPLETED:
        # A finished run should not keep advertising "Processing...".
        button_text = "Completed"
    else:
        button_text = "Processing..."

    st.button(
        button_text,
        on_click=start_processing,
        disabled=state.current_step != ProcessingStep.IDLE,
        use_container_width=True,
        type="primary",
    )
def display_progress():
    """Render one expander per finished stage, in the order they were recorded.

    Each expander shows the stage's result message as a success banner,
    followed by its details rendered as Markdown.
    """
    completed = st.session_state.state.completed_steps
    for finished_step in completed:
        message, details = completed[finished_step]
        header = f"{step_messages[finished_step]} Details - Completed"
        with st.expander(header):
            st.success(message)
            st.markdown(details)
def handle_processing_steps():
    """Show finished stages and, if the pipeline is mid-run, execute one stage.

    Exactly one stage runs per script execution. After it finishes, the
    session advances to the next stage and ``st.rerun()`` restarts the script,
    so the UI is redrawn between stages.
    """
    state = st.session_state.state
    display_progress()

    # Nothing to run before the user starts processing or after the last stage.
    if state.current_step in (ProcessingStep.IDLE, ProcessingStep.COMPLETED):
        return

    process_step(state.current_step)
    state.current_step = get_next_step(state.current_step)
    st.rerun()
def render_completion_message() -> None:
    """Show the final report once every pipeline stage has finished."""
    state = st.session_state.state
    if state.current_step != ProcessingStep.COMPLETED:
        return
    # final_text is not assigned anywhere in this function's module-level code;
    # presumably generate_response sets it (TODO confirm). Fall back to "" so
    # a missing attribute does not crash the page with AttributeError.
    final_text = getattr(state, "final_text", "")
    with st.expander("All processing steps completed successfully!"):
        st.markdown(final_text)
def main() -> None:
    """Build the page: title, uploaders, start button, progress and final result."""
    st.title("Contract Specification Comparison")
    initialize_session_state()
    # The upload sections store their selections on the session state
    # themselves, so the returned lists are not needed here.
    render_file_upload_sections()

    with st.container():
        render_processing_button()

    with st.container():
        handle_processing_steps()

    render_completion_message()


if __name__ == "__main__":
    main()