# dsm5_diagnosis / app.py
# pohjie — "add most common dsm5 disorders" (commit 0afa1d3)
# (scraped page metadata: raw / history / blame views; file size 2.37 kB)
import gradio as gr
import json
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
from InstructorEmbedding import INSTRUCTOR
# INSTRUCTOR sentence-embedding model, loaded once at import time and used to
# encode incoming case notes.
model = INSTRUCTOR("hkunlp/instructor-large")

# JSON file, located next to this script, holding the diagnoses and the
# embeddings of their criteria.
EMBED_FILE_PATH = Path(__file__).parent.joinpath("top15_disorders_emb.json")

# Instruction text paired with every case note when it is encoded.
case_note_instruction = "Represent the case note for a possible DSM-5 mental health diagnosis:"
def read_json_file(file_path):
    """Load and return the parsed contents of a UTF-8 encoded JSON file.

    Args:
        file_path: Path (``str`` or ``pathlib.Path``) of the JSON file to read.

    Returns:
        The deserialized JSON value.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        return json.load(handle)
def get_case_note_embedding(case_note):
    """Encode one case note with the module-level INSTRUCTOR ``model``.

    The note is sent as a single ``[instruction, text]`` pair, using
    ``case_note_instruction`` as the instruction.

    Args:
        case_note: Free-text case note to embed.

    Returns:
        The value ``model.encode`` produces for a one-item batch (presumably a
        single-row 2-D array — confirm against the InstructorEmbedding docs).
    """
    batch = [[case_note_instruction, case_note]]
    return model.encode(batch)
def _score_criterion(criterion, case_note_emb):
    """Return the JSON-ready score record for one criterion against a case note."""
    # cosine_similarity returns a (rows_a, rows_b) matrix. The stored embedding
    # is presumably a single-row nested list ([[...]]) — confirm against the
    # JSON file — so the result is 1x1. .item() unwraps it to a Python float
    # (float() on a 1x1 array is deprecated since NumPy 1.25) and still raises
    # if the matrix unexpectedly holds more than one value.
    similarity = cosine_similarity(criterion["embedding"], case_note_emb).item()
    record = {"description": criterion["description"], "score": similarity}
    if "symptom_list" in criterion:
        record["symptom_list"] = criterion["symptom_list"]
    return record


def get_top_n_diagnoses(case_note, top_n=3):
    """Rank the stored DSM-5 diagnoses by similarity to a case note.

    Each diagnosis in ``EMBED_FILE_PATH`` is scored by the mean cosine
    similarity between the case-note embedding and the embeddings of its
    criteria. Diagnoses with an empty criteria list are skipped, because they
    cannot be scored.

    Args:
        case_note: Free-text case note to match against the diagnoses.
        top_n: How many diagnoses to return (default 3). ``None`` returns all
            scored diagnoses.

    Returns:
        A JSON string (indent=2) holding a list of at most ``top_n`` objects,
        ordered by score descending. Each object has the keys:
        ``name``; ``sum_score``, which despite its name is the *mean*
        criterion score; and ``criteria_scores``, a per-criterion
        ``description``, ``score`` and, when present in the file,
        ``symptom_list``.
    """
    case_note_emb = get_case_note_embedding(case_note)
    diagnoses_embed = read_json_file(EMBED_FILE_PATH)["diagnoses"]

    diagnosis_list = []
    for diagnosis in diagnoses_embed:
        criteria_scores = [
            _score_criterion(criterion, case_note_emb)
            for criterion in diagnosis["criteria"]
        ]
        if not criteria_scores:
            # Averaging over zero criteria used to raise ZeroDivisionError.
            continue
        mean_score = sum(c["score"] for c in criteria_scores) / len(criteria_scores)
        diagnosis_list.append(
            {
                "name": diagnosis["name"],
                # Key kept as "sum_score" for output compatibility; value is the mean.
                "sum_score": mean_score,
                "criteria_scores": criteria_scores,
            }
        )

    # Highest score first; sorted() is stable, so ties keep their file order.
    ranked = sorted(diagnosis_list, key=lambda d: d["sum_score"], reverse=True)
    return json.dumps(ranked[:top_n], indent=2)
def greet(name):
    """Return ``"Hello <name>!!"``.

    Leftover helper; the Gradio interface in this file does not use it.
    """
    return "".join(("Hello ", name, "!!"))
# Expose get_top_n_diagnoses as a text-in / text-out Gradio web UI and start
# serving it as soon as this module runs.
iface = gr.Interface(
    fn=get_top_n_diagnoses,
    inputs="text",
    outputs="text",
)
iface.launch()