File size: 5,304 Bytes
c7b1d79
 
02392d7
c7b1d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02392d7
c7b1d79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os

import streamlit as st
from langchain.schema import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langfuse import Langfuse
from langfuse.callback import CallbackHandler

from prompt_engineering.one_shot_prompt import SYS_ONE_SHOT_PROMPT
from tools.json_utils import extract_json_from_response

# --- CONFIGURATION ---
# Langfuse client (used by the feedback section to store dataset items) plus a
# LangChain callback handler that is attached to the model below, so every
# call made through `llm` is traced in Langfuse.
# NOTE(review): Streamlit re-executes this whole script on every interaction,
# so these objects are rebuilt on each rerun; consider `st.cache_resource`.
langfuse = Langfuse()
langfuse_handler = CallbackHandler()

# DeepSeek is reached through an OpenAI-compatible endpoint, so the base URL and
# key come from the OpenAI-style environment variables (a missing variable
# raises KeyError here).
_llm_settings = {
    "temperature": 0,
    "model_name": "deepseek-chat",
    "openai_api_base": os.environ["OPENAI_API_BASE"],
    "openai_api_key": os.environ["OPENAI_API_KEY"],
    "callbacks": [langfuse_handler],
}
llm = ChatOpenAI(**_llm_settings)

# --- SYSTEM PROMPT ---
# One-shot instructions sent ahead of the user's message on every request.
system_prompt = SystemMessage(content=SYS_ONE_SHOT_PROMPT)

# --- STREAMLIT UI SETUP ---
st.set_page_config(page_title="ACMG/ACGS Variant Interpreter")
st.title("🔬 Germline Variant Classifier")
st.write("Enter a variant name to get ACMG and ACGS-based classification.")

variant = st.text_input("Variant (e.g. BRCA1 c.68_69delAG)")

if st.button("Interpret Variant"):
    if not variant:
        st.warning("Please enter a variant name.")
    else:
        # A new request invalidates the previous answer. Clear it first so that
        # a failed run cannot leave the feedback form attached to an older
        # variant's response.
        st.session_state.pop("llm_response", None)
        st.session_state.pop("variant", None)

        with st.spinner("Analyzing with DeepSeek..."):
            user_prompt = HumanMessage(content=f"Variant to interpret: {variant}")
            # `invoke` is the LangChain Runnable entry point; calling the model
            # object directly (`llm([...])`) is deprecated.
            response = llm.invoke([system_prompt, user_prompt])

        # Keep the try narrow: only the parse step is expected to fail here.
        try:
            parsed = extract_json_from_response(response.content)
        except Exception:
            # The parser's failure modes are not visible from this file, so any
            # error is reported as an unparseable response.
            parsed = None

        # The rendering below calls `.get`, so only a dict counts as success.
        if isinstance(parsed, dict):
            # Persist the result so it survives the reruns triggered by the
            # feedback widgets.
            st.session_state["llm_response"] = parsed
            st.session_state["variant"] = variant
        else:
            st.error(
                "Failed to parse response. Please check the variant format or try again."
            )
            st.text(response.content)

# Streamlit re-executes the script on every widget interaction, and `st.button`
# is True only on the run triggered by the click. Render the stored result on
# every run so it stays visible while the user fills in the feedback form.
if "llm_response" in st.session_state:
    stored_result = st.session_state["llm_response"]
    st.caption(f"Showing results for: {st.session_state.get('variant', '')}")
    st.subheader("📄 ACMG Classification")
    st.json(stored_result.get("acmg", {}))
    st.subheader("📄 ACGS Scoring")
    st.json(stored_result.get("acgs", {}))
    st.subheader("✅ Final Consensus")
    st.json(stored_result.get("final_consensus", {}))

# --- FEEDBACK SECTION ---
# Only shown once an interpretation has been stored in session_state, so the
# submitted feedback is always tied to a concrete model response.
if "llm_response" in st.session_state:
    st.markdown("## Provide Your Feedback")
    st.write(
        "If you see any inaccuracies or have corrections to the predictions, "
        "please select your corrections below:"
    )

    # Corrected ACMG Classification (multi-select from the 28 criteria)
    acmg_criteria_options = [
        "PVS1",
        "PS1",
        "PS2",
        "PS3",
        "PS4",
        "PM1",
        "PM2",
        "PM3",
        "PM4",
        "PM5",
        "PM6",
        "PP1",
        "PP2",
        "PP3",
        "PP4",
        "PP5",
        "BA1",
        "BS1",
        "BS2",
        "BS3",
        "BS4",
        "BP1",
        "BP2",
        "BP3",
        "BP4",
        "BP5",
        "BP6",
        "BP7",
    ]
    corrected_acmg = st.multiselect(
        "Select Corrected ACMG Classification Criteria",
        options=acmg_criteria_options,
        help="Search and select the criteria you think apply.",
    )

    # Corrected ACGS Scoring (integer input in [-50, 50], adjusted in steps of 1)
    corrected_acgs_score = st.number_input(
        "Enter Corrected ACGS Score",
        min_value=-50,
        max_value=50,
        value=0,
        step=1,
        help="Use the arrow buttons to adjust the score.",
    )

    # Corrected Final Consensus (dropdown with 5 classifications).
    # NOTE(review): the selectbox defaults to the first option ("Pathogenic"),
    # so a reviewer who only corrects other fields still records "Pathogenic"
    # as the expected consensus; consider pre-selecting the model's answer.
    consensus_options = [
        "Pathogenic",
        "Likely Pathogenic",
        "Uncertain Significance",
        "Likely Benign",
        "Benign",
    ]
    corrected_final_consensus = st.selectbox(
        "Select Corrected Final Consensus Classification",
        options=consensus_options,
        help="Select the classification that you think is most appropriate.",
    )
    corrected_explanation = st.text_area(
        "Provide Explanation for Corrections",
        help="Explain why you made these corrections.",
        placeholder="Your explanation here...",
        height=100,
    )

    if st.button("Submit Feedback"):
        # The reviewer's corrections become the expected output of the item.
        expected_output = {
            "acmg_criteria": corrected_acmg,
            "acgs_score": corrected_acgs_score,
            "final_consensus": corrected_final_consensus,
            "explanation": corrected_explanation,
        }
        try:
            # Record the reviewed case as a Langfuse dataset item: the variant
            # and the model's original answer are the input.
            langfuse.create_dataset_item(
                dataset_name="deepva-dev-feedback-v1",
                input={
                    "variant": st.session_state.get("variant", variant),
                    "original_response": st.session_state["llm_response"],
                },
                expected_output=expected_output,
            )
            # Block until queued Langfuse events have been sent before
            # confirming success to the user.
            langfuse.flush()
        except Exception as exc:
            # UI boundary: report the failure instead of dumping a traceback,
            # and do not claim the feedback was saved.
            st.error(f"Failed to submit feedback: {exc}")
        else:
            st.success("Thank you for your feedback!")