# Streamlit app for contract specification comparison: upload ground-truth and
# proposal files, then run the index -> compare -> evaluate -> generate steps.

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Dict, Tuple
import streamlit as st
import time
from dotenv import load_dotenv
load_dotenv()

from step_1_index_documents import index_documents 
from step_2_compare_documents import compare_documents
from step_3_evaluate_documents import evaluate_documents
from step_4_generate_response import generate_response

class ProcessingStep(IntEnum):
    """Stages of the processing pipeline, numbered in execution order.

    The values are consecutive integers, so the stage following ``s`` is
    ``ProcessingStep(s + 1)``.
    """

    # Waiting for the user to start processing.
    IDLE = 0
    # The four working stages, run one after another.
    INDEXING = 1
    COMPARING = 2
    EVALUATING = 3
    GENERATING = 4
    # Terminal stage: every working stage has run.
    COMPLETED = 5

# Caption for each working stage, used for the spinner text and the
# completed-step expander titles. IDLE and COMPLETED have no entry.
step_messages = {
    ProcessingStep.INDEXING: "Step 1: Indexing documents",
    ProcessingStep.COMPARING: "Step 2: Comparing documents",
    ProcessingStep.EVALUATING: "Step 3: Evaluating documents",
    ProcessingStep.GENERATING: "Step 4: Generating response",
}

@dataclass
class SessionState:
    """Per-session pipeline state kept under ``st.session_state.state``.

    The container fields default to ``None`` and are swapped for fresh empty
    containers in ``__post_init__`` so instances never share a mutable default.
    """

    # Stage the pipeline is currently at.
    current_step: ProcessingStep = ProcessingStep.IDLE
    # Finished stages mapped to the (message, details) pair their step
    # returned; details is rendered with st.markdown, so presumably a
    # markdown string - TODO confirm against the step modules.
    completed_steps: Optional[Dict[ProcessingStep, Tuple[str, str]]] = None
    # Values returned by the "Ground Truth" and "Proposals" file uploaders.
    ground_truth_files: Optional[List] = None
    proposal_files: Optional[List] = None
    # Text rendered by render_completion_message once processing completes.
    # Declared with an empty default so that read cannot raise AttributeError;
    # presumably overwritten by a pipeline step - verify.
    final_text: str = ""

    def __post_init__(self):
        """Replace ``None`` container fields with new empty containers."""
        if self.completed_steps is None:
            self.completed_steps = {}
        if self.ground_truth_files is None:
            self.ground_truth_files = []
        if self.proposal_files is None:
            self.proposal_files = []

def initialize_session_state() -> None:
    """Create the session's ``SessionState`` once; later reruns keep it."""
    if 'state' in st.session_state:
        return
    st.session_state.state = SessionState()

def upload_files(label: str) -> List:
    """Render a multi-file uploader and return the user's selection.

    Args:
        label: Caption passed to the uploader widget.

    Returns:
        The value of ``st.file_uploader`` called with
        ``accept_multiple_files=True``: uploaded-file objects, not
        file-name strings.
    """
    return st.file_uploader(label, accept_multiple_files=True)

def render_file_upload_sections() -> Tuple[List, List]:
    """Render the two side-by-side upload sections and store their selections.

    Each uploader's value is written to the session state
    (``ground_truth_files`` / ``proposal_files``) on every run.

    Returns:
        ``(ground_truth_files, proposal_files)`` exactly as returned by
        ``upload_files`` - uploader values, not file-name strings.
    """
    state = st.session_state.state
    ground_truth_col, proposals_col = st.columns(2)

    with ground_truth_col:
        st.subheader("Ground Truth")
        ground_truth_files = upload_files("Upload Ground Truth Files")
        state.ground_truth_files = ground_truth_files

    with proposals_col:
        st.subheader("Proposals")
        proposal_files = upload_files("Upload Proposal Files")
        state.proposal_files = proposal_files

    return ground_truth_files, proposal_files

def process_step(step: ProcessingStep) -> Optional[str]:
    """Run one pipeline stage under a spinner and record its outcome.

    Args:
        step: Stage to run; must be one of the four working stages
            (INDEXING, COMPARING, EVALUATING, GENERATING).

    Returns:
        The first element of the pair returned by the stage function
        (presumably a status message - verify against the step modules).
        When it is falsy, nothing is recorded in ``completed_steps``.

    Raises:
        KeyError: If ``step`` has no caption or stage function (IDLE, COMPLETED).
    """
    state = st.session_state.state
    # Lambdas defer the calls so only the requested stage actually runs.
    step_functions = {
        ProcessingStep.INDEXING: lambda: index_documents(state.ground_truth_files, state.proposal_files, st),
        ProcessingStep.COMPARING: lambda: compare_documents(st),
        ProcessingStep.EVALUATING: lambda: evaluate_documents(st),
        ProcessingStep.GENERATING: lambda: generate_response(st),
    }

    # Honour the ``step`` argument rather than re-reading state.current_step,
    # so the function runs the stage it was asked to run.
    with st.spinner(f"{step_messages[step]}..."):
        result, details = step_functions[step]()
        if result:
            # Kept so display_progress can re-render this stage on later reruns.
            state.completed_steps[step] = (result, details)
    return result

def get_next_step(current_step: ProcessingStep) -> ProcessingStep:
    """Return the stage after ``current_step``; COMPLETED maps to itself."""
    is_terminal = current_step == ProcessingStep.COMPLETED
    return ProcessingStep.COMPLETED if is_terminal else ProcessingStep(current_step + 1)

def render_processing_button() -> None:
    """Render the full-width primary button that starts the pipeline.

    The button is enabled only while the session is IDLE; clicking it moves
    the session to INDEXING. In any other stage it is disabled and labelled
    "Processing...".
    """
    state = st.session_state.state
    is_idle = state.current_step == ProcessingStep.IDLE

    def start_processing():
        # Click callback: advance the session to the first working stage.
        state.current_step = ProcessingStep.INDEXING

    st.button(
        "Start Processing" if is_idle else "Processing...",
        on_click=start_processing,
        disabled=not is_idle,
        use_container_width=True,
        type="primary",
    )

def display_progress():
    """Render one expander per recorded stage with its message and details."""
    completed = st.session_state.state.completed_steps
    for step, outcome in completed.items():
        message, details = outcome
        title = f"{step_messages[step.value]} Details - Completed"
        with st.expander(title):
            st.success(message)
            st.markdown(details)  # details go inside the expander body
        
def handle_processing_steps():
    """Show finished stages, then run the current stage and advance.

    While the session is in a working stage, each call runs exactly one stage
    via ``process_step``, moves ``current_step`` forward, and calls
    ``st.rerun()``. In IDLE or COMPLETED only the recorded progress is
    displayed.
    """
    state = st.session_state.state
    display_progress()

    if state.current_step in (ProcessingStep.IDLE, ProcessingStep.COMPLETED):
        return

    # process_step records the outcome in state.completed_steps; the
    # expanders themselves are drawn by display_progress on the next run.
    process_step(state.current_step)

    state.current_step = get_next_step(state.current_step)
    st.rerun()

def render_completion_message() -> None:
    """Render the completion expander once the pipeline reaches COMPLETED.

    The expander shows ``state.final_text`` as markdown. If that attribute is
    absent, a warning is shown instead of raising ``AttributeError``.
    """
    state = st.session_state.state
    if state.current_step != ProcessingStep.COMPLETED:
        return

    # Nothing in this file stores a response into final_text (presumably a
    # pipeline step does - verify), so tolerate it being absent.
    final_text = getattr(state, "final_text", None)

    with st.expander("All processing steps completed successfully!"):
        if final_text is None:
            st.warning("No final response text is available.")
        else:
            st.markdown(final_text)

def main() -> None:
    """Lay out the page: title, uploaders, start button, pipeline output."""
    st.title("Contract Specification Comparison")

    initialize_session_state()

    # The uploaders store their selections on the session state themselves,
    # so the returned pair is not needed here.
    render_file_upload_sections()

    with st.container():
        render_processing_button()

    with st.container():
        handle_processing_steps()
        render_completion_message()


if __name__ == "__main__":
    main()