noequal committed on
Commit
5c0eadd
1 Parent(s): 51c1886

Create app.py


Entry point for the app. This file handles input processing, including tokenization.

Files changed (1)
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import streamlit as st  # app framework for the entry point; UI wiring comes later
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
+
+ # Load a pre-trained clinical model (Bio_ClinicalBERT). A sequence-classification
+ # head fits the labeled dataset below; a causal-LM head does not suit a BERT
+ # encoder fine-tuned on labeled examples.
+ model = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+ # Tokenize the clinical text data with the matching tokenizer.
+ tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+ # Convert the tokenized data into PyTorch tensors via a PyTorch Dataset.
+ class ClinicalDataset(Dataset):
+     def __init__(self, texts, labels, tokenizer):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = self.texts[idx]
+         label = self.labels[idx]
+         # Pad to a fixed length so the collator below can stack tensors across examples.
+         encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
+         return {"input_ids": encoding["input_ids"].squeeze(0),
+                 "attention_mask": encoding["attention_mask"].squeeze(0),
+                 "labels": torch.tensor(label)}
+
+ # train_texts, train_labels, and val_dataset must be defined before this point.
+ dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
+
+ # Fine-tune the pre-trained model on the clinical dataset.
+ training_args = TrainingArguments(
+     output_dir='./results',          # output directory
+     num_train_epochs=3,              # total number of training epochs
+     per_device_train_batch_size=16,  # batch size per device during training
+     per_device_eval_batch_size=64,   # batch size for evaluation
+     warmup_steps=500,                # number of warmup steps for the learning rate scheduler
+     weight_decay=0.01,               # strength of weight decay
+     logging_dir='./logs',            # directory for storing logs
+     logging_steps=10,
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=dataset,
+     eval_dataset=val_dataset,
+     data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
+                                 'attention_mask': torch.stack([f['attention_mask'] for f in data]),
+                                 'labels': torch.stack([f['labels'] for f in data])},
+ )
+ trainer.train()
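
The commit imports streamlit but does not yet wire up any UI. A minimal sketch of how the entry point might serve the fine-tuned classifier, assuming the model and tokenizer objects defined in app.py above; the widget labels and the class-index output are illustrative, not part of this commit:

import streamlit as st
import torch

st.title("Clinical text classifier")  # hypothetical UI, not in this commit

text = st.text_area("Enter a clinical note")
if st.button("Classify") and text:
    # Tokenize exactly as during training so the input matches the model.
    inputs = tokenizer(text, return_tensors="pt", padding="max_length",
                       truncation=True, max_length=512)
    # Inference only: no_grad avoids building autograd graphs.
    with torch.no_grad():
        logits = model(**inputs).logits
    st.write(f"Predicted class: {logits.argmax(dim=-1).item()}")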