# SurgerySort / app.py
# Last commit 84fe4f3 by noequal: "Update app.py" (2.09 kB).
import evaluate
import numpy as np
import streamlit as st
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
# Load the tokenizer. The classification model itself is loaded below via
# AutoModelForSequenceClassification, so no separate base AutoModel is loaded
# here (it would be discarded immediately and only cost an extra model load).
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
# Define label mappings; label2id is derived from id2label so the two maps
# can never disagree.
id2label = {0: "SURGERY", 1: "NON-SURGERY"}
label2id = {label: label_id for label_id, label in id2label.items()}
# Load evaluation metric
accuracy = evaluate.load("accuracy")
# Define preprocessing function
def preprocess_function(examples):
    """Tokenize *examples* with truncation and padding switched on.

    *examples* is handed straight to the module-level tokenizer, so it may be
    a single string or a list of strings; the tokenizer's output is returned
    unchanged.
    """
    tokenizer_options = {"truncation": True, "padding": True}
    return tokenizer(examples, **tokenizer_options)
# Load model for sequence classification.
# Streamlit re-executes this whole script on every widget interaction, so the
# model is cached with st.cache_resource rather than re-instantiated from disk
# on each click.
# NOTE(review): trainer.train() is never called in this file, so unless the
# checkpoint already ships a matching two-label head, the classification head
# is freshly initialized and predictions are not meaningful. Point this at a
# fine-tuned checkpoint. TODO confirm.
@st.cache_resource
def _load_classification_model(id_to_label, label_to_id):
    """Return Bio_ClinicalBERT with a sequence-classification head.

    Args:
        id_to_label: mapping from class index to label name.
        label_to_id: inverse mapping from label name to class index.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        "emilyalsentzer/Bio_ClinicalBERT",
        num_labels=len(id_to_label),
        id2label=id_to_label,
        label2id=label_to_id,
    )


model = _load_classification_model(id2label, label2id)
# Define compute_metrics function
def compute_metrics(eval_pred):
    """Score arg-max class predictions with the module-level accuracy metric.

    *eval_pred* is unpacked into ``(logits, labels)``; the predicted class for
    each row is the arg-max over axis 1, and the dict returned by
    ``accuracy.compute`` is passed back to the caller.
    """
    logits, references = eval_pred
    predicted_classes = np.argmax(logits, axis=1)
    return accuracy.compute(references=references, predictions=predicted_classes)
# Define data collator (pads each batch dynamically to its longest sequence)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Define training arguments.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer=` to `processing_class=`; update
# both if the installed version requires it.
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # This app only runs inference. With push_to_hub=True the Trainer sets up
    # a Hub repository as soon as it is constructed, which requires
    # credentials and touches the Hub on every app start. Enable it only in
    # an actual training script.
    push_to_hub=False,
)
# Initialize trainer (no train/eval datasets are attached in this app)
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# Streamlit UI
st.title("Clinical Text Classification")
text = st.text_area("Enter clinical text:", "")
if st.button("Classify"):
    if not text.strip():
        st.warning("Please enter some clinical text to classify.")
    else:
        # Trainer.predict expects a Dataset, not a single tokenized example,
        # so run the model directly on a PyTorch batch of size 1. Truncate to
        # the model's positional capacity so long notes cannot overflow it.
        inputs = tokenizer(
            text,
            truncation=True,
            max_length=model.config.max_position_embeddings,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        prediction = int(logits.argmax(dim=-1)[0])
        st.write("Predicted Label:", id2label[prediction])