|
|
|
from dataclasses import dataclass |
|
from enum import IntEnum |
|
from typing import List, Optional, Dict, Tuple |
|
import streamlit as st |
|
import time |
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
|
|
from step_1_index_documents import index_documents |
|
from step_2_compare_documents import compare_documents |
|
from step_3_evaluate_documents import evaluate_documents |
|
from step_4_generate_response import generate_response |
|
|
|
class ProcessingStep(IntEnum):
    """Ordered stages of the document-processing pipeline.

    The values are consecutive integers, so the stage that follows a
    member can be obtained by adding one to it.
    """

    IDLE = 0          # initial state; nothing has been run yet
    INDEXING = 1
    COMPARING = 2
    EVALUATING = 3
    GENERATING = 4
    COMPLETED = 5     # terminal state
|
|
|
# Human-readable label for each stage that performs work. IDLE and
# COMPLETED deliberately have no entry, so looking them up raises KeyError.
step_messages = {
    ProcessingStep.INDEXING: "Step 1: Indexing documents",
    ProcessingStep.COMPARING: "Step 2: Comparing documents",
    ProcessingStep.EVALUATING: "Step 3: Evaluating documents",
    ProcessingStep.GENERATING: "Step 4: Generating response",
}
|
|
|
@dataclass
class SessionState:
    """Pipeline state stored under ``st.session_state.state``.

    Attributes:
        current_step: Stage the pipeline is currently at.
        completed_steps: Maps each finished stage to the
            ``(result, details)`` pair returned by its step function.
        ground_truth_files: Value of the "Ground Truth" file uploader.
        proposal_files: Value of the "Proposals" file uploader.

    The container fields accept ``None`` at construction time and are
    replaced by fresh empty containers in ``__post_init__``, so no two
    instances ever share a mutable default.
    """

    current_step: ProcessingStep = ProcessingStep.IDLE
    # The second tuple element is rendered with st.markdown by
    # display_progress, so it is annotated as text.
    # NOTE(review): confirm against the step functions' return values.
    completed_steps: Optional[Dict[ProcessingStep, Tuple[str, str]]] = None
    # NOTE(review): elements are whatever st.file_uploader yields; the
    # List[str] annotation is kept from the original - confirm.
    ground_truth_files: Optional[List[str]] = None
    proposal_files: Optional[List[str]] = None

    def __post_init__(self):
        """Replace ``None`` placeholders with new empty containers."""
        if self.completed_steps is None:
            self.completed_steps = {}
        if self.ground_truth_files is None:
            self.ground_truth_files = []
        if self.proposal_files is None:
            self.proposal_files = []
|
|
|
def initialize_session_state() -> None:
    """Create the session's ``SessionState`` if it does not exist yet.

    When the ``'state'`` key is already present in ``st.session_state``
    the existing object is left untouched.
    """
    if 'state' in st.session_state:
        return
    st.session_state.state = SessionState()
|
|
|
def upload_files(label: str) -> List[str]:
    """Render a multi-file uploader captioned *label* and return its value.

    The return value is passed through unchanged from ``st.file_uploader``.
    """
    return st.file_uploader(label, accept_multiple_files=True)
|
|
|
|
|
def render_file_upload_sections() -> Tuple[List[str], List[str]]:
    """Render the two side-by-side uploaders and store their values.

    The left column holds the ground-truth uploader and the right column
    the proposals uploader. Both results are written to the session
    state and also returned as ``(ground_truth_files, proposal_files)``.
    """
    state = st.session_state.state
    left_column, right_column = st.columns(2)

    with left_column:
        st.subheader("Ground Truth")
        state.ground_truth_files = upload_files("Upload Ground Truth Files")

    with right_column:
        st.subheader("Proposals")
        state.proposal_files = upload_files("Upload Proposal Files")

    return state.ground_truth_files, state.proposal_files
|
|
|
def process_step(step: ProcessingStep) -> None:
    """Run the pipeline stage *step* and record its outcome.

    The stage's function is called while a spinner labelled with the
    stage's message is shown. It must return a ``(result, details)``
    pair; when ``result`` is truthy the pair is stored in
    ``state.completed_steps`` under *step*. A falsy ``result`` is not
    recorded.

    Args:
        step: The stage to execute. Must be one of the working stages
            (INDEXING, COMPARING, EVALUATING, GENERATING).

    Raises:
        KeyError: If *step* has no entry in ``step_messages``
            (i.e. IDLE or COMPLETED).
    """
    state = st.session_state.state

    # Lambdas defer the call so only the selected stage actually runs.
    step_functions = {
        ProcessingStep.INDEXING: lambda: index_documents(
            state.ground_truth_files, state.proposal_files, st
        ),
        ProcessingStep.COMPARING: lambda: compare_documents(st),
        ProcessingStep.EVALUATING: lambda: evaluate_documents(st),
        ProcessingStep.GENERATING: lambda: generate_response(st),
    }

    # Honour the argument instead of re-reading state.current_step, so the
    # function does what its signature says.
    with st.spinner(f"{step_messages[step]}..."):
        result, details = step_functions[step]()
        if result:
            state.completed_steps[step] = (result, details)
|
|
|
def get_next_step(current_step: ProcessingStep) -> ProcessingStep:
    """Return the stage that follows *current_step*.

    COMPLETED is terminal and maps to itself; every other stage maps to
    the member whose value is one greater.
    """
    if current_step != ProcessingStep.COMPLETED:
        return ProcessingStep(current_step + 1)
    return ProcessingStep.COMPLETED
|
|
|
def render_processing_button() -> None:
    """Render the full-width primary button that starts the pipeline.

    The button is clickable only while the pipeline is IDLE; clicking it
    moves the state to INDEXING. In every other state it is disabled and
    labelled "Processing...".
    """
    state = st.session_state.state
    is_idle = state.current_step == ProcessingStep.IDLE

    def start_pipeline() -> None:
        # Click callback: kick off the first working stage.
        state.current_step = ProcessingStep.INDEXING

    st.button(
        "Start Processing" if is_idle else "Processing...",
        on_click=start_pipeline,
        disabled=not is_idle,
        use_container_width=True,
        type="primary",
    )
|
|
|
def display_progress():
    """Render one expander per completed stage.

    Each expander shows the stage's result as a success box followed by
    its details rendered as markdown.
    """
    finished = st.session_state.state.completed_steps
    for step, outcome in finished.items():
        message, details = outcome
        # An IntEnum member hashes and compares equal to its integer
        # value, so the lookup by ``step.value`` finds the member key.
        title = f"{step_messages[step.value]} Details - Completed"
        with st.expander(title):
            st.success(message)
            st.markdown(details)
|
|
|
def handle_processing_steps():
    """Show finished stages, then run the current stage if one is pending.

    When the pipeline is IDLE or COMPLETED only the progress display is
    rendered. Otherwise the current stage is executed via
    ``process_step``, the state advances to the next stage, and a rerun
    is requested so the following stage runs on the next script pass.
    """
    state = st.session_state.state
    display_progress()

    if state.current_step in (ProcessingStep.IDLE, ProcessingStep.COMPLETED):
        return

    # The dispatch table lives in process_step; the unused copy that was
    # built here has been removed.
    process_step(state.current_step)
    state.current_step = get_next_step(state.current_step)
    st.rerun()
|
|
|
def render_completion_message() -> None:
    """Render the completion expander once the pipeline is COMPLETED.

    The expander body shows ``state.final_text`` as markdown. That
    attribute is not declared on ``SessionState``; it is presumably
    attached by one of the step functions (NOTE(review): confirm which).
    If it is missing, the expander is rendered empty instead of raising
    ``AttributeError``.
    """
    state = st.session_state.state
    if state.current_step != ProcessingStep.COMPLETED:
        return

    with st.expander("All processing steps completed successfully!"):
        if hasattr(state, "final_text"):
            st.markdown(state.final_text)
|
|
|
def main() -> None:
    """Lay out the page: title, uploaders, start button, pipeline output."""
    st.title("Contract Specification Comparison")
    initialize_session_state()

    # The uploaders store their values in the session state; the returned
    # pair is not needed here.
    render_file_upload_sections()

    with st.container():
        render_processing_button()

    with st.container():
        handle_processing_steps()
        render_completion_message()
|
|
|
# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()