| import asyncio |
| import json |
| import os |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import streamlit as st |
| from acp_sdk.client import Client |
| from acp_sdk.models import Message, MessagePart |
| from rich.console import Console |
|
|
| from gaf_guard.clients.stream_adaptors import get_adapter |
| from gaf_guard.core.models import WorkflowMessage |
| from gaf_guard.toolkit.enums import MessageType, Role, StreamStatus, UserInputType |
| from gaf_guard.toolkit.file_utils import resolve_file_paths |
|
|
|
|
# Directory two levels above this file; the chain-of-thought JSON paths in
# ``run_configs`` are built relative to it.
GAF_GUARD_ROOT = Path(__file__).parent.parent.absolute()


# Page-wide CSS overrides injected as raw HTML.
#
# A former ``.stTextInput {{ position: fixed; bottom: 3rem; }}`` rule was
# removed here: the doubled braces are f-string escapes, but this is a plain
# string, so the braces reached the browser literally and the rule was
# discarded as invalid CSS. Dropping it keeps the rendered page unchanged.
# NOTE(review): if pinning text inputs was actually intended, re-add the rule
# with single braces and check the connect dialog, which uses st.text_input.
st.markdown(
    """
    <style>
    .header {
        padding: 1rem;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        text-align: center;
        border-radius: 10px;
        margin-bottom: 1rem;
    }
    .message-card {
        padding: 1rem;
        border-left: 4px solid #667eea;
        background-color: #f8f9fa;
        border-radius: 5px;
        margin: 0.5rem 0;
    }
    .stApp[data-teststate=running] .stChatInput textarea,
    .stApp[data-test-script-state=running] .stChatInput textarea {
        display: none !important;
    }
    .block-container {
        padding-top: 1rem;
        padding-bottom: 0rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
|
|
| |
# Option lists read by the "Initial risks" dialog's select boxes.
st.session_state["priority"] = ["low", "medium", "high"]
st.session_state["initial_risks_master"] = ["Toxic output", "Hallucination"]

st.set_page_config(
    layout="wide",
    page_title="GAF Guard - A real-time monitoring system for risk assessment and drift monitoring.",
)

# Rich console used for timestamped server-side log lines.
console = Console(log_time=True)

# Per-agent settings attached to every WorkflowMessage this client creates.
# The drift threshold falls back to 8 until the session holds a value.
run_configs = dict(
    RiskGeneratorAgent=dict(
        risk_questionnaire_cot=os.path.join(
            GAF_GUARD_ROOT, "chain_of_thought", "risk_questionnaire.json"
        ),
    ),
    DriftMonitoringAgent=dict(
        drift_monitoring_cot=os.path.join(
            GAF_GUARD_ROOT, "chain_of_thought", "drift_monitoring.json"
        ),
        drift_threshold=st.session_state.get("drift_threshold", 8),
    ),
)
# NOTE(review): presumably resolves the file paths above in place — confirm
# against gaf_guard.toolkit.file_utils.
resolve_file_paths(run_configs)
|
|
|
|
def file_uploaded():
    """Handle a change of the sidebar's JSON prompt-file uploader.

    Stores the uploaded bytes in ``st.session_state.prompt_file``, then appends
    and renders a system message confirming the upload.

    The uploader's ``on_change`` callback also fires when the user removes the
    file; the widget value is then ``None``. In that case the stale
    ``prompt_file`` entry is dropped (which disables the Start button again)
    and no message is emitted.
    """
    uploaded = st.session_state.prompt_file_uploader
    if uploaded is None:
        if "prompt_file" in st.session_state:
            del st.session_state["prompt_file"]
        return

    st.session_state.prompt_file = uploaded.getvalue()
    message = WorkflowMessage(
        name="GAF Guard Client",
        type=MessageType.GAF_GUARD_QUERY,
        role=Role.SYSTEM,
        content=f"**File uploaded successfully:** {uploaded.name}",
        accept=UserInputType.INPUT_PROMPT,
        run_configs=run_configs,
    )
    st.session_state.messages.append(message)
    render(message, simulate=True)
|
|
|
|
def play_button(adapter_type):
    """Start (or resume) streaming prompts from the selected source.

    Reuses the adapter already held in ``st.session_state.stream_adaptor`` when
    there is one; otherwise builds one from the uploaded file bytes. An adapter
    is only stored in the session when it is usable, because the mere presence
    of the ``stream_adaptor`` key disables the source selector and uploader.

    Args:
        adapter_type: Source name chosen in the sidebar select box; passed to
            ``get_adapter``.
    """
    if "stream_adaptor" in st.session_state:
        # Resuming after a pause: keep the existing adapter rather than
        # constructing a throwaway one.
        adapter = st.session_state.stream_adaptor
    else:
        adapter = get_adapter(
            adapter_type,
            config={"byte_data": st.session_state.prompt_file},
        )
        if adapter:
            st.session_state.stream_adaptor = adapter

    if adapter:
        st.session_state.stream_status = StreamStatus.ACTIVE
    else:
        st.write("Selected adaptor is not available.")
|
|
|
|
def pause_button():
    """Mark the input stream as paused and queue a system notice about it."""
    st.session_state.stream_status = StreamStatus.PAUSED

    paused_notice = WorkflowMessage(
        name="GAF Guard Client",
        type=MessageType.GAF_GUARD_QUERY,
        role=Role.SYSTEM,
        content="**:red[Alert:]** Current input stream is paused. Please click on **Start** to resume.",
        accept=UserInputType.INPUT_PROMPT,
        run_configs=run_configs,
    )
    st.session_state.messages.append(paused_notice)
|
|
|
|
@st.fragment
def pause_fragment(adapter_type):
    """Draw the Pause button as a Streamlit fragment.

    The button is disabled while no source is selected or while the stream is
    already stopped or paused.

    Args:
        adapter_type: Current value of the source select box.
    """
    # Short-circuit order matters: stream_status is only read when a source
    # has been selected.
    cannot_pause = adapter_type == "Select" or st.session_state.stream_status in (
        StreamStatus.STOPPED,
        StreamStatus.PAUSED,
    )
    st.button(
        "⏸️ Pause",
        on_click=pause_button,
        disabled=cannot_pause,
        use_container_width=True,
    )
|
|
|
|
def add_sidebar():
    """Build the sidebar for the current ``sidebar_display`` mode.

    Modes, read from ``st.session_state.sidebar_display``:

    * ``"settings_edit"``: editable taxonomy select box and drift-threshold
      slider, written back to the session state.
    * ``"settings_view"``: read-only taxonomy and drift threshold.
    * ``"input_prompt_source"``: the read-only settings plus the streaming
      source selector, JSON uploader and Start/Pause controls.

    A footer with the project link and the connection status is always drawn.
    """
    with st.sidebar:
        st.sidebar.title("⚙️ Settings")

        if st.session_state.sidebar_display in ["settings_view", "input_prompt_source"]:
            st.subheader(f":blue[Taxonomy:] {st.session_state.taxonomy}")
            st.subheader(f":blue[Drift Threshold:] {st.session_state.drift_threshold}")

        if st.session_state.sidebar_display == "settings_edit":
            st.session_state.taxonomy = st.selectbox(
                "Risk Taxonomy",
                # Trailing comma makes this a one-element tuple; without it
                # the parentheses are redundant and the value is a bare string.
                ("IBM Risk Atlas",),
            )
            st.session_state.drift_threshold = st.slider(
                "Drift Threshold",
                value=st.session_state.drift_threshold,
                min_value=2,
                max_value=10,
                step=1,
            )

        if st.session_state.sidebar_display == "input_prompt_source":
            st.sidebar.title("⚙️ Streaming Source")
            # Source widgets are locked once an adapter exists in the session.
            source_locked = "stream_adaptor" in st.session_state
            adapter_type = st.selectbox(
                "Select Input Prompt Source",
                ["Select", "JSON"],
                help="Choose your streaming source",
                index=0,
                disabled=source_locked,
            )
            if adapter_type == "JSON":
                st.subheader("JSON File Source")
                st.file_uploader(
                    "OK",
                    accept_multiple_files=False,
                    type="json",
                    label_visibility="collapsed",
                    on_change=file_uploaded,
                    key="prompt_file_uploader",
                    disabled=source_locked,
                )

            # Stream controls.
            col1, col2 = st.columns(2)
            with col1:
                st.button(
                    "▶️ Start",
                    use_container_width=True,
                    disabled=(
                        adapter_type == "Select"
                        or "prompt_file" not in st.session_state
                        or st.session_state.stream_status == StreamStatus.ACTIVE
                    ),
                    on_click=play_button,
                    args=(adapter_type,),
                )
            with col2:
                pause_fragment(adapter_type)
            st.markdown(
                "**Note:** Pause button will temporarily halt the stream after processing the current prompt."
            )

        st.divider()
        ai_atlas_button = st.container(
            horizontal_alignment="center", vertical_alignment="bottom", height="stretch"
        )
        ai_atlas_button.markdown(":blue[Powered by:]", text_alignment="center")
        ai_atlas_button.link_button(
            "AI Atlas Nexus",
            "https://github.com/IBM/ai-atlas-nexus",
            icon=":material/thumb_up:",
            type="secondary",
        )
        if hasattr(st.session_state, "client_session"):
            # Only the first 13 characters of the session id are shown.
            ai_atlas_button.markdown(
                f"Client Id: {str(st.session_state.client_session._session.id)[0:13]} \n :violet-badge[:material/rocket_launch: Connected to :yellow[GAF Guard] Server:] :orange-badge[:material/check: {st.session_state.host}:{st.session_state.port}]",
                text_alignment="center",
            )
        else:
            ai_atlas_button.markdown(
                ":red-badge[:material/mimo_disconnect: Client Disconnected]",
                text_alignment="center",
            )
|
|
|
|
| |
def render(message: WorkflowMessage, simulate=False):
    """Render one workflow message in the chat area.

    Args:
        message: The workflow message to display.
        simulate: When true, text is revealed word by word with a trailing
            cursor glyph instead of being written in one go.

    Returns:
        False when nothing is drawn because ``message.display`` is falsy or the
        message is a workflow-started, workflow-completed or step-completed
        marker; True in every other case.
    """

    def simulate_agent_response(
        role,
        text: str,
        json_data=None,
        simulate: bool = False,
        accept=None,
    ):
        """Write one chat bubble and apply the follow-up for ``accept``.

        ``role`` is handed to ``st.chat_message``. ``json_data``, when truthy,
        is shown as a JSON tree and suppresses the ``accept`` handling.
        """
        with st.chat_message(role):
            if simulate:
                message_placeholder = st.empty()
                full_response = ""
                # split() collapses all whitespace, so simulated output is a
                # single space-joined line.
                for chunk in text.split():
                    full_response += chunk + " "
                    time.sleep(0.05)
                    message_placeholder.markdown(full_response + "▌")
                message_placeholder.markdown(full_response)
            else:
                st.markdown(text)

            if json_data:
                st.json(json_data, expanded=4)
            elif accept == UserInputType.INITIAL_RISKS:
                st.button(
                    "Add Initial Risks",
                    on_click=initial_risks_selector,
                    disabled=hasattr(st.session_state, "initial_risks"),
                )
                st.session_state.disabled_input = False
            elif accept == UserInputType.INPUT_PROMPT:
                # Prompts come from the sidebar stream controls, not the chat box.
                st.session_state.sidebar_display = "input_prompt_source"
                st.session_state.disabled_input = True

    def show(text, json_data=None):
        """Render ``text`` using this message's role, accept type and mode."""
        simulate_agent_response(
            role=message.role.value,
            text=text,
            json_data=json_data,
            simulate=simulate,
            accept=message.accept,
        )

    if not message.display:
        return False
    if message.type in (
        MessageType.GAF_GUARD_WF_STARTED,
        MessageType.GAF_GUARD_WF_COMPLETED,
        MessageType.GAF_GUARD_STEP_COMPLETED,
    ):
        return False

    if message.type == MessageType.GAF_GUARD_STEP_STARTED:
        show(f"##### :blue[Workflow Step:] **{message.name}**")
    elif message.type == MessageType.GAF_GUARD_STEP_DATA:
        # Step data with non-dict content draws nothing but still counts as
        # rendered.
        if isinstance(message.content, dict):
            if message.name == "Input Prompt":
                # Single quotes inside the f-string keep this line valid on
                # Python versions before 3.12.
                show(
                    f"###### :yellow[**Prompt {message.content['prompt_index']}**]: {message.content['prompt']}"
                )
            elif len(message.content) > 2:
                # More than two entries: one combined report, each entry as a
                # single-key dict with a title-cased key.
                show(
                    "###### :yellow[Risk Report]",
                    json_data=[
                        {key.title(): value} for key, value in message.content.items()
                    ],
                )
            else:
                for key, value in message.content.items():
                    if key == "identified_risks":
                        st.session_state.risks = value
                    label = f"###### :yellow[{key.replace('_', ' ').title()}]"
                    if isinstance(value, (list, dict)):
                        show(label, json_data=value)
                    elif isinstance(value, str) and key.endswith("alert"):
                        show(f"{label}: :red[{value}]")
                    else:
                        show(f"{label}: {value}")
    elif message.type == MessageType.GAF_GUARD_QUERY:
        show(f":blue[{message.content}]")
    elif message.content:
        show(message.content)

    return True
|
|
|
|
@st.dialog("Initial risks", width="medium")
def initial_risks_selector():
    """Dialog for editing the list of initial risks.

    Rows live in ``st.session_state.initial_risks`` keyed by their index as a
    string. Each row has a risk, a priority and a threshold. On submit the rows
    are JSON-encoded into ``st.session_state.user_input`` and the script is
    rerun.
    """

    def add_row():
        """Append a default row, creating the row store on first use."""
        rows = st.session_state.setdefault("initial_risks", {})
        rows[str(len(rows))] = {
            "risk": st.session_state.initial_risks_master[0],
            "priority": "low",
            "threshold": 0.01,
        }

    if "initial_risks" not in st.session_state:
        add_row()

    st.button("Add New Row", type="primary", on_click=add_row)

    with st.form("input_form"):
        risk_col, priority_col, threshold_col = st.columns(3)

        for row_id, row in st.session_state.initial_risks.items():
            # Only the first row carries visible column labels.
            is_first = row_id == "0"

            with risk_col:
                row["risk"] = st.selectbox(
                    "Risk" if is_first else " ",
                    tuple(st.session_state.initial_risks_master),
                    key=f"col1{row_id}",
                    index=st.session_state.initial_risks_master.index(row["risk"]),
                )
            with priority_col:
                row["priority"] = st.selectbox(
                    "Priority" if is_first else " ",
                    tuple(st.session_state.priority),
                    key=f"col2{row_id}",
                    index=st.session_state.priority.index(row["priority"]),
                )
            with threshold_col:
                row["threshold"] = st.number_input(
                    "Threshold" if is_first else " ",
                    key=f"col3{row_id}",
                    value=row["threshold"],
                )

        submitted = st.form_submit_button("Submit")

    if not submitted:
        return

    st.session_state.user_input = json.dumps(
        list(st.session_state.initial_risks.values())
    )
    st.rerun()
|
|
|
|
@st.dialog(
    "GAF Guard Connect",
    width="medium",
    dismissible=False,
    icon=":material/login:",
)
def connect_screen_dialog():
    """Ask for the server host and port.

    Shows any error left in the session state. On submit it clears that error,
    stores ``host`` and ``port`` in the session state and reruns the script.
    """
    if "error" in st.session_state:
        st.error(st.session_state["error"], icon="🚨")

    with st.form("login_form"):
        host_value = st.text_input("GAF Guard Host", value="localhost")
        port_value = st.number_input("GAF Guard Port", value=8000)
        connect_clicked = st.form_submit_button("Connect", type="primary")

    if not connect_clicked:
        return

    if "error" in st.session_state:
        del st.session_state["error"]
    st.session_state.host = host_value
    st.session_state.port = port_value
    st.rerun()
|
|
|
|
@st.dialog(
    "GAF Guard Connect",
    width="medium",
    dismissible=False,
    icon=":material/login:",
)
def connect():
    """Create the client session and initialise the per-session UI state.

    Progress is shown inside a status box. If the client cannot be created,
    the exception is logged to the console, a user-facing error is stored in
    ``st.session_state.error`` and the script is rerun. On success the session
    state is seeded (client session, drift threshold, stream status, sidebar
    mode, first message) and the script is rerun.
    """

    async def ping_server(client):
        """Ping the server; not invoked anywhere in this function at present."""
        await client.ping()

    with st.status(
        f"Connecting to GAF Guard using host: :blue[**{st.session_state.host}**] and port: :blue[**{st.session_state.port}**]",
        expanded=True,
    ) as status:
        try:
            client = Client(
                base_url=f"http://{st.session_state.host}:{st.session_state.port}",
                verify=True,
            )
            st.write("Client created...")
        except Exception:
            # Keep the real cause in the server log; the UI only gets a
            # generic message.
            console.print_exception()
            st.session_state.error = "Failed to connect. Check hostname and port."
            st.rerun()
            # Reached only if rerun does not interrupt the script; `client`
            # is unbound on this path.
            return

        st.session_state.client_session = client.session()
        st.write("Client session created...")

        st.session_state.drift_threshold = 8
        st.session_state.disabled_input = False
        st.session_state.stream_status = StreamStatus.STOPPED
        st.session_state.sidebar_display = "settings_edit"
        st.session_state.messages = [
            WorkflowMessage(
                name="GAF Guard Client",
                type=MessageType.CLIENT_INPUT,
                role=Role.USER,
                accept=UserInputType.USER_INTENT,
                run_configs=run_configs,
            )
        ]
        st.write("Client initialisation done...")

        console.print(
            f"[[bold white]{datetime.now().strftime('%d-%m-%Y %H:%M:%S')}[/]] [italic bold white] :rocket: Connected to GAF Guard Server at[/italic bold white] [bold white]{st.session_state.host}:{st.session_state.port}[/bold white]"
        )
        console.print(
            f"[[bold white]{datetime.now().strftime('%d-%m-%Y %H:%M:%S')}[/]] Client Id: {st.session_state.client_session._session.id}"
        )

        status.update(
            label=f":material/rocket_launch: Connected to :yellow[**GAF Guard**] Server: :orange-badge[:material/check: {st.session_state.host}:{st.session_state.port}]",
            state="complete",
            expanded=True,
        )
        # Leave the success state visible briefly before the rerun.
        time.sleep(1)

    st.rerun()
|
|
|
|
def submit_input():
    """Chat-input submit callback: switch the sidebar to the read-only view."""
    st.session_state["sidebar_display"] = "settings_view"
|
|
|
|
async def app():
    """Draw the main page and run one request/response cycle with the server.

    Steps per script run:

    1. Draw the title, the sidebar and the message history.
    2. Get the next input: from the stream adapter while the stream is active,
       otherwise from the chat box. Stop when there is none.
    3. Send it to the ``orchestrator`` agent and render streamed events.
    4. When the run awaits further input, store the awaiting message, disable
       the chat box and rerun the script.
    """
    st.title(":yellow[GAF Guard]", text_alignment="center")
    st.subheader(
        "A real-time monitoring system for risk assessment and drift monitoring",
        text_alignment="center",
        divider=True,
    )

    add_sidebar()

    # Replay the conversation so far.
    for message in st.session_state.messages:
        render(message)

    last_message: WorkflowMessage = st.session_state.messages[-1]

    if st.session_state.stream_status == StreamStatus.ACTIVE:
        user_input = st.session_state.stream_adaptor.next()
        if not user_input:
            # Stream exhausted: drop the adapter so a new source can be chosen.
            del st.session_state["stream_adaptor"]
            st.session_state.stream_status = StreamStatus.STOPPED
            st.session_state.messages.append(
                WorkflowMessage(
                    name="GAF Guard Client",
                    type=MessageType.GAF_GUARD_QUERY,
                    role=Role.SYSTEM,
                    content="**The streaming input has ended. Please choose a streaming source and start again.**",
                    accept=UserInputType.INPUT_PROMPT,
                    run_configs=run_configs,
                )
            )
            st.rerun()
    else:
        user_input = st.chat_input(
            placeholder="Enter your response here",
            key="user_input",
            disabled=st.session_state.disabled_input,
            on_submit=submit_input,
        )

    if not user_input:
        st.stop()
        return

    # A reply to a server query is a CLIENT_RESPONSE; anything else is a fresh
    # CLIENT_INPUT. The payload is keyed by what the last message asked for.
    request = WorkflowMessage(
        name="GAF Guard Client",
        type=(
            MessageType.CLIENT_RESPONSE
            if last_message.type == MessageType.GAF_GUARD_QUERY
            else MessageType.CLIENT_INPUT
        ),
        role=Role.USER,
        content={last_message.accept: user_input},
        run_configs=run_configs,
    )

    async for event in st.session_state.client_session.run_stream(
        agent="orchestrator",
        input=[
            Message(
                parts=[
                    MessagePart(
                        content=request.model_dump_json(),
                        content_type="text/plain",
                    )
                ]
            )
        ],
    ):
        if event.type == "message.part":
            message = WorkflowMessage(**json.loads(event.part.content))
            # Only messages that were actually drawn go into the history.
            if render(message, simulate=True):
                st.session_state.messages.append(message)
        elif event.type == "run.awaiting":
            if hasattr(event, "run"):
                message = WorkflowMessage(
                    **json.loads(event.run.await_request.message.parts[0].content)
                )
                if message.accept == UserInputType.INPUT_PROMPT:
                    if st.session_state.stream_status == StreamStatus.STOPPED:
                        render(message, simulate=True)
                    else:
                        # While streaming or paused, keep the prompt request
                        # in the history but never display it.
                        message.display = False
                else:
                    render(message, simulate=True)

                st.session_state.messages.append(message)
                st.session_state.disabled_input = True
                st.rerun()
|
|
|
|
# Entry dispatch, evaluated on every script run:
#   - a client session exists          -> run the main app
#   - host and port known, no error    -> attempt the connection
#   - otherwise                        -> ask for host and port
if "client_session" in st.session_state:
    asyncio.run(app())
elif (
    "error" in st.session_state
    or "host" not in st.session_state
    or "port" not in st.session_state
):
    connect_screen_dialog()
else:
    connect()
|
|