Spaces:
Paused
Paused
Update train.py
Browse files
train.py
CHANGED
|
@@ -57,6 +57,9 @@ def load_data():
|
|
| 57 |
return dataset
|
| 58 |
|
| 59 |
def encode_decode(texts, tok):
|
|
|
|
|
|
|
|
|
|
| 60 |
tokenized_texts = tok(
|
| 61 |
texts,
|
| 62 |
padding="max_length",
|
|
@@ -69,7 +72,7 @@ def encode_decode(texts, tok):
|
|
| 69 |
decoded_texts = tok.batch_decode(tokenized_texts)
|
| 70 |
else:
|
| 71 |
print('Found invalid entry in examples. Returning dummy..')
|
| 72 |
-
decoded_texts = [
|
| 73 |
|
| 74 |
islist = not len(decoded_texts) == 1
|
| 75 |
|
|
@@ -97,7 +100,7 @@ def get_training_corpus(dataset):
|
|
| 97 |
def format_prompts(examples, tokenizer, isinst):
|
| 98 |
texts = []
|
| 99 |
for text in examples['text']:
|
| 100 |
-
if text:
|
| 101 |
if isinst:
|
| 102 |
conversation = []
|
| 103 |
parts = text.split('<|end|>')
|
|
@@ -115,6 +118,9 @@ def format_prompts(examples, tokenizer, isinst):
|
|
| 115 |
print('Found empty entry in examples. Moving on..')
|
| 116 |
continue
|
| 117 |
|
|
|
|
|
|
|
|
|
|
| 118 |
coded_texts = tokenizer.code(texts)
|
| 119 |
return {'text': coded_texts}
|
| 120 |
|
|
@@ -208,7 +214,24 @@ def train_model(model, tokenizer, dataset, push, isinst):
|
|
| 208 |
)
|
| 209 |
|
| 210 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
print("Mapped dataset sample length:", len(dataset[0]['text']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
trainer = trl.SFTTrainer(
|
| 214 |
model=model,
|
|
@@ -270,8 +293,14 @@ def main(push_to_hub=True, is_inst_finetune=False):
|
|
| 270 |
model = create_model(tokenizer)
|
| 271 |
print("Created Model.")
|
| 272 |
|
|
|
|
|
|
|
|
|
|
| 273 |
print("Resizing Token Embeddings..")
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
| 275 |
print("Resized Embeddings.")
|
| 276 |
|
| 277 |
print("Training Model..")
|
|
|
|
| 57 |
return dataset
|
| 58 |
|
| 59 |
def encode_decode(texts, tok):
|
| 60 |
+
if tok.pad_token is None:
|
| 61 |
+
tok.pad_token = tok.eos_token
|
| 62 |
+
|
| 63 |
tokenized_texts = tok(
|
| 64 |
texts,
|
| 65 |
padding="max_length",
|
|
|
|
| 72 |
decoded_texts = tok.batch_decode(tokenized_texts)
|
| 73 |
else:
|
| 74 |
print('Found invalid entry in examples. Returning dummy..')
|
| 75 |
+
# Build the dummy entry from the tokenizer passed in as `tok`; the name
# `tokenizer` is not a parameter of encode_decode(texts, tok).
decoded_texts = [tok.pad_token * MAX_SEQ_LENGTH]
|
| 76 |
|
| 77 |
islist = not len(decoded_texts) == 1
|
| 78 |
|
|
|
|
| 100 |
def format_prompts(examples, tokenizer, isinst):
|
| 101 |
texts = []
|
| 102 |
for text in examples['text']:
|
| 103 |
+
if text and len(text.strip()) > 0:
|
| 104 |
if isinst:
|
| 105 |
conversation = []
|
| 106 |
parts = text.split('<|end|>')
|
|
|
|
| 118 |
print('Found empty entry in examples. Moving on..')
|
| 119 |
continue
|
| 120 |
|
| 121 |
+
if len(texts) == 0:
|
| 122 |
+
raise ValueError("No valid texts found in examples for formatting.")
|
| 123 |
+
|
| 124 |
coded_texts = tokenizer.code(texts)
|
| 125 |
return {'text': coded_texts}
|
| 126 |
|
|
|
|
| 214 |
)
|
| 215 |
|
| 216 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
|
| 217 |
+
|
| 218 |
+
if 'text' not in dataset.column_names:
|
| 219 |
+
raise ValueError("Dataset transformation failed: 'text' column missing after mapping.")
|
| 220 |
+
|
| 221 |
print("Mapped dataset sample length:", len(dataset[0]['text']))
|
| 222 |
+
|
| 223 |
+
try:
|
| 224 |
+
test_input = tokenizer(
|
| 225 |
+
["This is a test input."],
|
| 226 |
+
return_tensors="pt",
|
| 227 |
+
padding="max_length",
|
| 228 |
+
truncation=True,
|
| 229 |
+
max_length=MAX_SEQ_LENGTH
|
| 230 |
+
)
|
| 231 |
+
test_output = model(**test_input)
|
| 232 |
+
print("Model test output shape:", test_output.logits.shape)
|
| 233 |
+
except RuntimeError as e:
|
| 234 |
+
print(f"Error processing test batch: {e}")
|
| 235 |
|
| 236 |
trainer = trl.SFTTrainer(
|
| 237 |
model=model,
|
|
|
|
| 293 |
model = create_model(tokenizer)
|
| 294 |
print("Created Model.")
|
| 295 |
|
| 296 |
+
print(f"Tokenizer vocabulary size: {len(tokenizer)}")
|
| 297 |
+
print(f"Special tokens: {tokenizer.special_tokens_map}")
|
| 298 |
+
|
| 299 |
print("Resizing Token Embeddings..")
|
| 300 |
+
try:
|
| 301 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 302 |
+
except RuntimeError as e:
|
| 303 |
+
raise RuntimeError(f"Error resizing token embeddings: {e}")
|
| 304 |
print("Resized Embeddings.")
|
| 305 |
|
| 306 |
print("Training Model..")
|