noequal committed on
Commit 84fe4f3
1 Parent(s): 9128ec6

Update app.py

Files changed (1)
  1. app.py +49 -83
app.py CHANGED
@@ -1,101 +1,67 @@
 import streamlit as st
-import torch
-from torch.utils.data import Dataset, random_split
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
 
-# Generate sample clinical text and labels
-sample_data = [
-    ("Had successful surgery today. Feeling relieved.", "surgery"),
-    ("Started new medication for pain management.", "non-surgery"),
-    ("Scheduled for surgery next week. Nervous but hopeful.", "surgery"),
-    ("Attended a seminar on non-surgical treatments.", "non-surgery"),
-]
 
-# Map labels to integers
-label_mapping = {"surgery": 1, "non-surgery": 0}
-train_texts, train_labels = zip(*sample_data)
-train_labels = [label_mapping[label] for label in train_labels]
 
-# Logging and Outputs
-st.write("Sample data:")
-for text, label in zip(train_texts, train_labels):
-    st.write(f"Text: {text}\nLabel: {label}\n")
 
-# Load pre-trained model and tokenizer
-model_name = "distilbert-base-uncased"  # You can use any suitable classification model
-model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Create PyTorch Dataset object
-class ClinicalDataset(Dataset):
-    def __init__(self, texts, labels, tokenizer, max_seq_length):
-        self.texts = texts
-        self.labels = labels
-        self.tokenizer = tokenizer
-        self.max_seq_length = max_seq_length
-
-    def __len__(self):
-        return len(self.texts)
-
-    def __getitem__(self, idx):
-        text = self.texts[idx]
-        label = self.labels[idx]
-
-        encoding = self.tokenizer(
-            text,
-            return_tensors="pt",
-            padding='max_length',  # Pad sequences to the maximum sequence length
-            truncation=True,
-            max_length=self.max_seq_length
-        )
-
-        return {
-            "input_ids": encoding["input_ids"].squeeze(),
-            "attention_mask": encoding["attention_mask"].squeeze(),
-            "labels": torch.tensor(label)
-        }
-
-
-# Data Collator
-data_collator = DataCollatorForLanguageModeling(
-    tokenizer=tokenizer,
-    mlm_probability=0.15
 )
 
-seq_length = 128
-dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer, max_seq_length=seq_length)
 
-# Split dataset into training and validation sets
-train_size = int(0.8 * len(dataset))
-val_size = len(dataset) - train_size
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
 
-# Fine-tune pre-trained model on clinical dataset
 training_args = TrainingArguments(
-    output_dir='./results',          # output directory
-    num_train_epochs=3,              # total number of training epochs
-    per_device_train_batch_size=16,  # batch size per device during training
-    per_device_eval_batch_size=64,   # batch size for evaluation
-    warmup_steps=500,                # number of warmup steps for learning rate scheduler
-    weight_decay=0.01,               # strength of weight decay
-    logging_dir='./logs',            # directory for storing logs
-    logging_steps=10,)
 
 trainer = Trainer(
     model=model,
     args=training_args,
-    train_dataset=train_dataset,
-    eval_dataset=val_dataset,
-    data_collator=data_collator,
 )
 
-st.write("Training started...")
-trainer.train()
-st.write("Training completed.")
-
-# Logging Training Output
-st.write("Training logs:")
-with open('./logs/train.log', 'r') as log_file:
-    st.code(log_file.read())
 import streamlit as st
+import numpy as np
+import evaluate
+from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
 
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
 
+# Define label mappings
+id2label = {0: "SURGERY", 1: "NON-SURGERY"}
+label2id = {"SURGERY": 0, "NON-SURGERY": 1}
 
+# Load evaluation metric
+accuracy = evaluate.load("accuracy")
 
+# Define preprocessing function
+def preprocess_function(examples):
+    return tokenizer(examples, truncation=True, padding=True)
 
+# Load model for sequence classification
+model = AutoModelForSequenceClassification.from_pretrained(
+    "emilyalsentzer/Bio_ClinicalBERT", num_labels=2, id2label=id2label, label2id=label2id
 )
 
+# Define compute_metrics function
+def compute_metrics(eval_pred):
+    predictions, labels = eval_pred
+    predictions = np.argmax(predictions, axis=1)
+    return accuracy.compute(predictions=predictions, references=labels)
 
+# Define data collator
+data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
+# Define training arguments
 training_args = TrainingArguments(
+    output_dir="my_awesome_model",
+    learning_rate=2e-5,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    num_train_epochs=2,
+    weight_decay=0.01,
+    evaluation_strategy="epoch",
+    save_strategy="epoch",
+    load_best_model_at_end=True,
+    push_to_hub=True,
+)
 
+# Initialize trainer
 trainer = Trainer(
     model=model,
     args=training_args,
+    tokenizer=tokenizer,
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
 )
 
+# Streamlit UI
+st.title("Clinical Text Classification")
+text = st.text_area("Enter clinical text:", "")
 
+if st.button("Classify"):
+    # Tokenize user input and predict
+    tokenized_text = preprocess_function(text)
+    result = trainer.predict(tokenized_text)
+    prediction = np.argmax(result.predictions, axis=1)[0]
+    st.write("Predicted Label:", id2label[prediction])
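
For context, a minimal sketch of how the classification step in the new app.py can be expressed as a direct forward pass through the sequence-classification model, since Trainer.predict is normally called with a Dataset rather than a raw tokenizer encoding. It assumes the same Bio_ClinicalBERT checkpoint and label mapping as above and that the classification head has already been fine-tuned (e.g., via trainer.train()); classify_text is an illustrative helper, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Illustrative sketch only: same checkpoint and label mapping as app.py above.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=2,
    id2label={0: "SURGERY", 1: "NON-SURGERY"},
    label2id={"SURGERY": 0, "NON-SURGERY": 1},
)

def classify_text(text: str) -> str:
    # Tokenize a single string and run the model without tracking gradients.
    inputs = tokenizer(text, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the two class logits gives the predicted label id.
    return model.config.id2label[int(logits.argmax(dim=-1))]

print(classify_text("Scheduled for surgery next week. Nervous but hopeful."))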