noequal committed on
Commit 001896c
1 Parent(s): 972f109

Update app to set tensors to consistent size

Files changed (1)
  1. app.py +28 -11
app.py CHANGED
@@ -28,10 +28,11 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Create PyTorch Dataset object
 class ClinicalDataset(Dataset):
-    def __init__(self, texts, labels, tokenizer):
+    def __init__(self, texts, labels, tokenizer, max_seq_length):
         self.texts = texts
         self.labels = labels
         self.tokenizer = tokenizer
+        self.max_seq_length = max_seq_length
 
     def __len__(self):
         return len(self.texts)
@@ -39,15 +40,28 @@ class ClinicalDataset(Dataset):
     def __getitem__(self, idx):
         text = self.texts[idx]
         label = self.labels[idx]
-        encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-        return {"input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "labels": torch.tensor(label)}
+
+        encoding = self.tokenizer(
+            text,
+            return_tensors="pt",
+            padding='max_length',  # Pad sequences to the maximum sequence length
+            truncation=True,
+            max_length=self.max_seq_length
+        )
+
+        return {
+            "input_ids": encoding["input_ids"].squeeze(),
+            "attention_mask": encoding["attention_mask"].squeeze(),
+            "labels": torch.tensor(label)
+        }
 
 
 
 # Data Collator
 data_collator = default_data_collator
 
-dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
+seq_length = 128
+dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer, max_seq_length=seq_length)
 
 # Split dataset into training and validation sets
 train_size = int(0.8 * len(dataset))
@@ -66,13 +80,16 @@ training_args = TrainingArguments(
     logging_steps=10,)
 
 trainer = Trainer(
-    model=model,
-    args=training_args,
-    train_dataset=train_dataset,
-    eval_dataset=val_dataset,
-    data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
-                                'attention_mask': torch.stack([f['attention_mask'] for f in data]),
-                                'labels': torch.stack([f['labels'] for f in data])}, )
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=val_dataset,
+    data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
+                                'attention_mask': torch.stack([f['attention_mask'] for f in data]),
+                                'labels': torch.stack([f['labels'] for f in data])},
+    pad_to_max_length=True
+
+)
 
 
 st.write("Training started...")