# File size: 5,713 Bytes (blame commits: e9af536, 1e4f0c7)
import os
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
import chromadb
import gradio as gr
from sklearn.metrics import precision_score, recall_score, f1_score
def meanpooling(output, mask):
    """Mask-aware mean pooling of token embeddings.

    ``output[0]`` holds the per-token embeddings (the first element of a
    transformers model output). Positions where ``mask`` is 0 contribute
    nothing to the sum or to the count. The count is clamped to 1e-9, so a
    fully-masked row yields zeros instead of dividing by zero.
    """
    token_embeddings = output[0]
    # Broadcast the (batch, seq) mask over the embedding dimension.
    weights = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
# Load the MIMIC-III sample and carve out a reproducible validation set:
# 20% of the original "train" split (seed 42) becomes "validation"; the rest stays "train".
dataset = load_dataset("thankrandomness/mimic-iii-sample")
split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    "train": split_dataset["train"],
    # train_test_split names its held-out half "test"; expose it as "validation".
    "validation": split_dataset["test"],
})
# Load the embedding model and its tokenizer from one shared model id, so the
# tokenizer and the weights can never silently come from different checkpoints.
MODEL_NAME = "neuml/pubmedbert-base-embeddings-matryoshka"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
def embed_text(text):
    """Embed ``text`` with the module-level ``tokenizer``/``model`` pair.

    Every caller in this file passes a list of strings. Inputs are padded
    together and truncated to 512 tokens, then run through ``model`` with
    autograd disabled. The token states are collapsed with ``meanpooling``
    over the attention mask.

    Returns:
        A nested Python list: one embedding (list of floats) per input text.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    with torch.no_grad():
        pooled = meanpooling(model(**encoded), encoded['attention_mask'])
    return pooled.numpy().tolist()
# Default ChromaDB client (no persistence settings are passed here).
client = chromadb.Client()
# get_or_create_collection instead of create_collection: re-executing this
# module in the same process (notebook re-run, app reload) would otherwise
# fail because the collection already exists.
collection = client.get_or_create_collection(name="pubmedbert_matryoshka_embeddings")
def upsert_data(dataset_split):
    """Index every annotated note in ``dataset_split`` into ``collection``.

    Expected row shape, based on the fields read below:
    - Each row has a ``'notes'`` list.
    - Each note has ``'note_id'`` and optionally ``'text'`` and ``'annotations'``.
    - Each annotation has ``'code'``, ``'code_system'`` and ``'description'``.

    A note's text is embedded once. One record per annotation is then stored
    under id ``note_<note_id>_<j>``, all sharing that embedding, with the
    annotation dict as metadata.

    Annotations missing any of the three keys are skipped with a message.
    Notes with no text, or with no usable annotation, are skipped with a
    message as well.
    """
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = []
            for annotation in note.get('annotations', []):
                try:
                    annotations_list.append({
                        "code": annotation['code'],
                        "code_system": annotation['code_system'],
                        "description": annotation['description'],
                    })
                except KeyError as e:
                    print(f"Skipping annotation due to missing key: {e}")

            if not (text and annotations_list):
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
                continue

            embeddings = embed_text([text])[0]
            # One batched upsert per note instead of one round-trip per annotation.
            # Chroma expects parallel lists, so the shared embedding is repeated.
            count = len(annotations_list)
            collection.upsert(
                ids=[f"note_{note['note_id']}_{j}" for j in range(count)],
                embeddings=[embeddings] * count,
                metadatas=annotations_list,
            )
# Build the search index from the training split only; validation notes are
# kept out of the collection so evaluation queries never match themselves.
upsert_data(dataset_split=dataset["train"])
def retrieve_relevant_text(input_text, n_results=5):
    """Return the annotation records whose note embeddings lie nearest ``input_text``.

    Args:
        input_text: Free text to embed and search with.
        n_results: Number of nearest records to request from ``collection``.
            Defaults to 5, the previously hard-coded value.

    Returns:
        A list of dicts in the order Chroma returns them. Each dict has
        ``code``, ``code_system``, ``description``, ``similarity_score`` and
        ``distance``.

        NOTE: ``similarity_score`` actually holds Chroma's *distance*, where
        lower means more similar. The key is kept for existing callers, and
        the same value is also exposed under the accurate name ``distance``.
    """
    # embed_text takes a batch; this is a batch of one.
    query_embedding = embed_text([input_text])[0]
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
        # Records are upserted without documents, so only metadata and distances are useful.
        include=["metadatas", "distances"],
    )
    # Chroma returns one result list per query embedding; there is exactly one query.
    return [
        {
            "similarity_score": distance,
            "distance": distance,
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description'],
        }
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0])
    ]
def evaluate_efficiency(dataset_split, *, retriever=None):
    """Score code retrieval on ``dataset_split`` with set-based, micro-averaged metrics.

    Only notes with text and at least one annotation carrying a ``'code'`` are
    evaluated. For each such note, the text is passed to the retriever, and
    the set of distinct retrieved codes is compared with the set of annotated
    codes. Hits (TP), extra retrieved codes (FP) and missed codes (FN) are
    summed over all notes:

        precision = TP / (TP + FP)
        recall    = TP / (TP + FN)
        f1        = 2 * precision * recall / (precision + recall)

    Each ratio is 0.0 when its denominator is 0 (e.g. nothing to evaluate).

    Args:
        dataset_split: Iterable of rows, each with a ``'notes'`` list of dicts.
        retriever: Callable that takes the note text and returns result dicts
            with a ``'code'`` key. Defaults to ``retrieve_relevant_text``.

    Returns:
        ``(precision, recall, f1)`` as floats.
    """
    # Codes are compared as per-note sets. Concatenating all notes into two
    # flat lists and pairing them by position (the previous approach) drifts
    # out of alignment for every later note whenever one note has more
    # annotations than retrieved results.
    retrieve = retrieve_relevant_text if retriever is None else retriever
    true_positives = false_positives = false_negatives = 0
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            true_codes = {
                annotation['code']
                for annotation in note.get('annotations', [])
                if 'code' in annotation
            }
            if not (text and true_codes):
                continue
            predicted_codes = {result['code'] for result in retrieve(text)}
            hits = len(true_codes & predicted_codes)
            true_positives += hits
            false_positives += len(predicted_codes) - hits
            false_negatives += len(true_codes) - hits

    predicted_total = true_positives + false_positives
    actual_total = true_positives + false_negatives
    precision = true_positives / predicted_total if predicted_total else 0.0
    recall = true_positives / actual_total if actual_total else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
# Score retrieval against the held-out validation split once at startup;
# the numbers are shown in the UI header below.
precision, recall, f1 = evaluate_efficiency(dataset_split=dataset["validation"])
def gradio_interface(input_text):
    """Format the nearest annotation records for ``input_text`` as display text.

    Returns one line per retrieved record, joined with newlines. The output
    component is a ``gr.Textbox``, which expects a string; a list would be
    shown as its ``str()`` form.

    The number shown is Chroma's distance, where lower means more similar,
    so it is labelled "Distance" rather than "Similarity Score".
    """
    results = retrieve_relevant_text(input_text)
    return "\n".join(
        f"Distance: {result['similarity_score']:.2f}, Code: {result['code']}, Description: {result['description']}"
        for result in results
    )
# Header line with the validation metrics computed at startup.
metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"

# Minimal UI: title, metrics, a query box, a results box and a submit button
# that routes the query through gradio_interface.
with gr.Blocks() as interface:
    gr.Markdown("# Text Retrieval with Efficiency Metrics")
    gr.Markdown(metrics)
    text_input = gr.Textbox(label="Input Text")
    text_output = gr.Textbox(label="Retrieved Results")
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=gradio_interface,
        inputs=text_input,
        outputs=text_output,
    )

interface.launch()
|