thankrandomness committed on
Commit
e9af536
1 Parent(s): 4fcfab4
Files changed (2) hide show
  1. app.py +145 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from datasets import load_dataset, DatasetDict
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import chromadb
6
+ import gradio as gr
7
+ from sklearn.metrics import precision_score, recall_score, f1_score
8
+
9
+ # Mean Pooling - Take attention mask into account for correct averaging
10
def meanpooling(output, mask):
    """Average token embeddings, counting only positions enabled by the mask.

    Args:
        output: Model output whose first element holds the token embeddings.
        mask: Attention mask; it is broadcast over the embedding dimension
            before being applied.

    Returns:
        Tensor of masked sums over dimension 1 divided by the (clamped) mask
        totals over the same dimension.
    """
    token_embeddings = output[0]
    # Stretch the mask to the embeddings' shape so it can be multiplied in.
    weights = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = torch.sum(token_embeddings * weights, 1)
    # Clamp the denominator so an all-zero mask cannot divide by zero.
    token_counts = torch.clamp(weights.sum(1), min=1e-9)
    return masked_sum / token_counts
14
+
15
+ # Load the dataset
16
dataset = load_dataset("thankrandomness/mimic-iii-sample")

# Hold out 20% of the original 'train' split as a validation set; the fixed
# seed keeps the partition the same on every run.
split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
_splits = {
    'train': split_dataset['train'],
    'validation': split_dataset['test'],
}
dataset = DatasetDict(_splits)

# Tokenizer and encoder are loaded from the same checkpoint name.
_MODEL_NAME = "neuml/pubmedbert-base-embeddings-matryoshka"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
model = AutoModel.from_pretrained(_MODEL_NAME)
28
+
29
+ # Function to embed text using mean pooling
30
def embed_text(text):
    """Embed text with the module-level tokenizer and model via mean pooling.

    Args:
        text: Input handed directly to the tokenizer; callers in this file
            pass a list containing a single string.

    Returns:
        The pooled embeddings converted to nested Python lists.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        model_output = model(**encoded)
        pooled = meanpooling(model_output, encoded['attention_mask'])
    return pooled.numpy().tolist()
36
+
37
+ # Initialize ChromaDB client
38
client = chromadb.Client()
# get_or_create_collection reuses an existing collection of this name, so
# executing this module again in a process that already created it does not
# fail the way create_collection would.
collection = client.get_or_create_collection(name="pubmedbert_matryoshka_embeddings")
40
+
41
+ # Function to upsert data into ChromaDB
42
def _extract_annotations(note):
    """Return the note's annotations that have code, code_system and description."""
    complete = []
    for annotation in note.get('annotations', []):
        try:
            complete.append({
                "code": annotation['code'],
                "code_system": annotation['code_system'],
                "description": annotation['description'],
            })
        except KeyError as e:
            print(f"Skipping annotation due to missing key: {e}")
    return complete


def upsert_data(dataset_split):
    """Embed every usable note in a split and store it in the Chroma collection.

    Each row is expected to expose a 'notes' sequence of dict-like notes. A note
    is stored only when it has non-empty text, at least one complete annotation
    and a 'note_id'; otherwise a message is printed and the note is skipped.
    The note's embedding is written once per annotation under the id
    ``note_<note_id>_<annotation index>`` with the annotation as metadata.

    Args:
        dataset_split: Iterable of rows, each with a 'notes' entry.

    Returns:
        None. The module-level ``collection`` is updated as a side effect.
    """
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = _extract_annotations(note)
            # .get keeps the skip message from raising on notes without an id.
            note_id = note.get('note_id')

            if not (text and annotations_list):
                print(f"Skipping note {note_id} due to missing 'text' or 'annotations'")
                continue
            if note_id is None:
                # Without an id the entries could not be addressed uniquely.
                print("Skipping note due to missing 'note_id'")
                continue

            embedding = embed_text([text])[0]

            # One batched upsert per note instead of one call per annotation;
            # the ids, embeddings and metadata written are the same.
            collection.upsert(
                ids=[f"note_{note_id}_{j}" for j in range(len(annotations_list))],
                embeddings=[embedding] * len(annotations_list),
                metadatas=annotations_list,
            )
69
+
70
+ # Upsert training data
71
# Only the training split is indexed; the validation split is used later for
# evaluation.
train_split = dataset['train']
upsert_data(train_split)
72
+
73
+ # Define retrieval function
74
def retrieve_relevant_text(input_text, n_results=5):
    """Query the collection for the stored annotations nearest to a text.

    Args:
        input_text: Text to embed and use as the query.
        n_results: Maximum number of hits to request from the collection.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        A list of dicts with the keys 'similarity_score', 'code',
        'code_system' and 'description'. 'similarity_score' carries the
        distance value reported by the collection query.
        NOTE(review): presumably a smaller value means a closer match;
        confirm against the collection's distance metric.
    """
    query_embedding = embed_text([input_text])[0]
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
        # Documents are never stored or read in this file, so they are not requested.
        include=["metadatas", "distances"],
    )

    # A single query was sent, so only the first result list is relevant.
    return [
        {
            "similarity_score": distance,
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description'],
        }
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0])
    ]
92
+
93
+ # Evaluate retrieval efficiency on the validation/test set
94
def evaluate_efficiency(dataset_split):
    """Score retrieval against the annotated codes of a dataset split.

    For every note with non-empty text and at least one annotation code, the
    note's codes and the codes returned by ``retrieve_relevant_text`` are
    compared as sets. Annotation order and retrieval rank are unrelated, so a
    positional comparison would not be meaningful. Overlap counts are
    accumulated over all notes (micro-averaging).

    Args:
        dataset_split: Iterable of rows, each with a 'notes' entry of
            dict-like notes.

    Returns:
        Tuple ``(precision, recall, f1)`` of floats. All three are 0.0 when
        the split contains no usable note.
    """
    matched = 0          # codes that are both annotated and retrieved
    retrieved_total = 0  # distinct retrieved codes, summed over notes
    relevant_total = 0   # distinct annotated codes, summed over notes

    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            true_codes = {
                annotation['code']
                for annotation in note.get('annotations', [])
                if 'code' in annotation
            }
            if not (text and true_codes):
                continue

            retrieved_codes = {result['code'] for result in retrieve_relevant_text(text)}

            matched += len(true_codes & retrieved_codes)
            retrieved_total += len(retrieved_codes)
            relevant_total += len(true_codes)

    # Guard each ratio so an empty split or zero retrievals yields 0.0.
    precision = matched / retrieved_total if retrieved_total else 0.0
    recall = matched / relevant_total if relevant_total else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return precision, recall, f1
121
+
122
+ # Calculate retrieval efficiency metrics
123
# Evaluate once at startup on the validation split; the three values are
# unpacked into module-level names.
validation_scores = evaluate_efficiency(dataset['validation'])
precision, recall, f1 = validation_scores
124
+
125
+ # Gradio interface
126
def gradio_interface(input_text):
    """Run retrieval for the input text and format each hit as one line.

    Args:
        input_text: Text to search with.

    Returns:
        A list of strings, one per retrieved result, each showing the score
        (two decimals), the code and the description.
    """
    formatted_results = []
    for result in retrieve_relevant_text(input_text):
        formatted_results.append(
            f"Similarity Score: {result['similarity_score']:.2f}, Code: {result['code']}, Description: {result['description']}"
        )
    return formatted_results
133
+
134
+ # Display retrieval efficiency metrics
135
metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"

# Page layout: title, the precomputed metrics line, an input box, an output
# box and a button that runs the retrieval callback.
with gr.Blocks() as interface:
    for markdown_text in ("# Text Retrieval with Efficiency Metrics", metrics):
        gr.Markdown(markdown_text)
    text_input = gr.Textbox(label="Input Text")
    text_output = gr.Textbox(label="Retrieved Results")
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=gradio_interface,
        inputs=text_input,
        outputs=text_output,
    )

interface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ datasets
4
+ chromadb
5
+ gradio
6
+ numpy
7
+ scikit-learn