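# Clinical code retrieval demo: embed MIMIC-III notes with PubMedBERT, index them in
# ChromaDB, and expose nearest-neighbor code lookup plus simple retrieval metrics
# through a Gradio interface.
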
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
import chromadb
import gradio as gr
from sklearn.metrics import precision_score, recall_score, f1_score

# Mean pooling: average the token embeddings, using the attention mask so that
# padding tokens do not contribute to the sentence embedding
def meanpooling(output, mask):
    embeddings = output[0]  # First element of the model output holds all token embeddings
    mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
    # Sum only the real tokens, then divide by their count; the clamp guards
    # against division by zero for all-padding inputs
    return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

# Load the dataset
dataset = load_dataset("thankrandomness/mimic-iii-sample")

# Split the dataset into train and validation sets
split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': split_dataset['test']
})

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
model = AutoModel.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
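# Note: matryoshka-style embedding models are trained so that the leading dimensions
# of a vector remain useful on their own, so the full 768-dim embeddings used here
# could be truncated to smaller sizes if storage or speed becomes a concern.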

# Function to embed text using mean pooling
def embed_text(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    embeddings = meanpooling(output, inputs['attention_mask'])
    return embeddings.numpy().tolist()
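
# Example (illustrative): embed_text(["chest pain on exertion"]) returns a list
# containing one embedding vector as a plain Python list of floats.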

# Initialize an in-memory ChromaDB client (embeddings are not persisted across runs)
client = chromadb.Client()
collection = client.create_collection(name="pubmedbert_matryoshka_embeddings")
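# By default ChromaDB compares vectors with (squared) L2 distance; cosine distance
# can be requested instead, e.g. create_collection(name=..., metadata={"hnsw:space": "cosine"}).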

# Function to upsert data into ChromaDB
def upsert_data(dataset_split):
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = []
            
            for annotation in note.get('annotations', []):
                try:
                    code = annotation['code']
                    code_system = annotation['code_system']
                    description = annotation['description']
                    annotations_list.append({"code": code, "code_system": code_system, "description": description})
                except KeyError as e:
                    print(f"Skipping annotation due to missing key: {e}")

            if text and annotations_list:
                embeddings = embed_text([text])[0]

                # Upsert one entry per annotation; each shares the note's embedding,
                # and the note text is stored as the document so queries can return it
                for j, annotation in enumerate(annotations_list):
                    collection.upsert(
                        ids=[f"note_{note['note_id']}_{j}"],
                        embeddings=[embeddings],
                        documents=[text],
                        metadatas=[annotation]
                    )
            else:
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")

# Upsert training data
upsert_data(dataset['train'])
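# Optional sanity check: collection.count() reports how many entries were indexed.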

# Define retrieval function
def retrieve_relevant_text(input_text):
    input_embedding = embed_text([input_text])[0]  # Get the embedding for the single input text
    results = collection.query(
        query_embeddings=[input_embedding],
        n_results=5,
        include=["metadatas", "documents", "distances"]
    )
    
    # Extract codes and distances (ChromaDB returns distances, where lower means
    # more similar, not similarity scores)
    output = []
    for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
        output.append({
            "distance": distance,
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description']
        })
    return output
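
# Illustrative call and output shape:
#   retrieve_relevant_text("acute kidney injury")
#   -> [{"distance": 0.31, "code": "...", "code_system": "...", "description": "..."}, ...]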

# Evaluate retrieval efficiency on the validation/test set
def evaluate_efficiency(dataset_split):
    y_true = []
    y_pred = []
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
            
            if text and annotations_list:
                retrieved_results = retrieve_relevant_text(text)
                retrieved_codes = [result['code'] for result in retrieved_results]
                
                # Ground-truth codes for this note
                y_true.extend(annotations_list)
                # Predictions: truncate the retrieved codes to the number of true codes
                # so the position-wise lists below stay the same length
                y_pred.extend(retrieved_codes[:len(annotations_list)])
                
    if len(y_true) != len(y_pred):
        min_length = min(len(y_true), len(y_pred))
        y_true = y_true[:min_length]
        y_pred = y_pred[:min_length]
    
    # Macro-averaged, position-wise comparison of retrieved vs. true codes;
    # zero_division=0 suppresses warnings for codes that are never predicted
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    
    return precision, recall, f1

# Calculate retrieval efficiency metrics
precision, recall, f1 = evaluate_efficiency(dataset['validation'])
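# Note: this position-wise comparison is a rough proxy for retrieval quality;
# set-based precision/recall over retrieved vs. true code sets (precision@k)
# would be a more conventional measure.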

# Gradio interface
def gradio_interface(input_text):
    results = retrieve_relevant_text(input_text)
    formatted_results = [
        f"Result {i + 1}:\n"
        f"Similarity Score: {result['similarity_score']:.2f}\n"
        f"Code: {result['code']}\n"
        f"Code System: {result['code_system']}\n"
        f"Description: {result['description']}\n"
        "-------------------"
        for i, result in enumerate(results)
    ]
    return "\n".join(formatted_results)

# Display retrieval efficiency metrics
metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"

with gr.Blocks() as interface:
    gr.Markdown("# Text Retrieval with Efficiency Metrics")
    gr.Markdown(metrics)
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Input Text")
            submit_button = gr.Button("Submit")
        with gr.Column():
            text_output = gr.Textbox(label="Retrieved Results", lines=10)
    submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)

interface.launch()
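# Tip: interface.launch(share=True) provides a temporary public URL when running locally.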