Hypernova823 committed on
Commit
e39de11
·
verified ·
1 Parent(s): 550fd67

Delete train_ocr.py

Browse files
Files changed (1) hide show
  1. train_ocr.py +0 -85
train_ocr.py DELETED
@@ -1,85 +0,0 @@
1
import os

# NOTE(review): hard-coding a credential assignment in source is unsafe. The
# value here is clearly a placeholder, but prefer supplying HF_TOKEN via the
# shell environment / CI secret store rather than setting it in code.
os.environ["HF_TOKEN"] = "lol nooo"

import torch
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
from datasets import load_dataset
from torch.utils.data import Dataset
7
-
8
- # 1. Hardware Check
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- print(f"--- GPU STATUS: {device.upper()} ---")
11
-
12
- # 2. Download Dataset
13
- print("Downloading Handwriting Dataset (~260MB)...")
14
- dataset = load_dataset("Teklia/IAM-line", split="train").train_test_split(test_size=0.1)
15
-
16
- # 3. Download Model
17
- print("Downloading TrOCR Base Model (~1.5GB)...")
18
- model_id = "microsoft/trocr-base-handwritten"
19
- processor = TrOCRProcessor.from_pretrained(model_id)
20
- model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
21
-
22
- # 4. Prepare Dataset
23
class HandwritingDataset(Dataset):
    """Adapt a Hugging Face image/text dataset to a torch ``Dataset`` for TrOCR.

    Each item is a dict with:
      - ``pixel_values``: the processor's image tensor (leading batch dim squeezed)
      - ``labels``: token ids padded/truncated to 64, with padding positions set
        to -100 so the cross-entropy loss ignores them.
    """

    def __init__(self, hf_dataset, processor):
        # hf_dataset: any indexable of {"image": PIL-like, "text": str} items.
        self.dataset = hf_dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        rgb_image = sample["image"].convert("RGB")  # model expects 3 channels
        transcript = sample["text"]

        pixel_values = self.processor(rgb_image, return_tensors="pt").pixel_values.squeeze()

        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(
            transcript, padding="max_length", max_length=64, truncation=True
        ).input_ids
        # Mask pad tokens with -100 so the loss function skips them.
        pad_id = tokenizer.pad_token_id
        labels = [-100 if tid == pad_id else tid for tid in token_ids]

        return {"pixel_values": pixel_values, "labels": torch.tensor(labels)}
37
-
38
train_dataset = HandwritingDataset(dataset['train'], processor)
eval_dataset = HandwritingDataset(dataset['test'], processor)

# --- CRITICAL OPTIMIZATIONS FOR 8GB VRAM ---
# The encoder-decoder wrapper needs explicit start/pad token ids so that
# generation and loss computation use the tokenizer's conventions.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Gradient checkpointing: recompute activations in backward instead of storing
# them — large VRAM saving at the cost of extra compute.
model.gradient_checkpointing_enable()
48
-
49
# 5. Training configuration, sized for a single 8GB consumer GPU.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    per_device_train_batch_size=2,   # small batches are safer for VRAM
    per_device_eval_batch_size=1,    # extra safety during evaluation
    gradient_accumulation_steps=8,   # effectively a batch size of 16
    output_dir="./working_checkpoints",
    logging_steps=10,
    save_steps=400,                  # save frequently in case of a crash
    eval_steps=400,
    fp16=True,                       # mixed precision; needed to fit the card
    max_steps=2000,
    learning_rate=4e-5,
    save_total_limit=3,              # keep only the 3 most recent checkpoints
    dataloader_num_workers=0,        # prevents Windows multiprocessing errors
    report_to="none",                # stops WandB/etc from asking for a login
)
67
-
68
# 6. Build the trainer.
# NOTE(review): processing_class is given only the image processor; recent
# transformers examples pass the full processor (or tokenizer) here so the
# saved checkpoint carries complete preprocessing — confirm this is intended.
trainer = Seq2SeqTrainer(
    model=model,
    processing_class=processor.image_processor,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
77
-
78
print("Starting training! Once the progress bar starts, you can safely walk away.")
trainer.train()

# 7. Save final model. The processor is saved alongside the weights so the
# output directory can be reloaded end-to-end with from_pretrained().
print("Saving final output...")
trainer.save_model("./final_handwriting_model")
processor.save_pretrained("./final_handwriting_model")
print("DONE. See you tomorrow morning.")