Spaces:
Status: Sleeping
Update app.py (#14)
Browse files — Update app.py (commit ad1bd14c666771191de9e043dc62c8cd10c6c77f)
app.py
CHANGED
@@ -50,19 +50,51 @@ def log_event(event):
|
|
50 |
# =========================
|
51 |
# Training Pipeline
|
52 |
# =========================
|
53 |
-
def train_model(model_name, dataset_name, epochs
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
66 |
|
67 |
# Load model
|
68 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
|
|
50 |
# =========================
|
51 |
# Training Pipeline
|
52 |
# =========================
|
53 |
def train_model(model_name, dataset_name, epochs):
    """Fine-tune a sequence-classification model on a small dataset slice.

    Args:
        model_name: Hub id of a pretrained model (e.g. "bert-base-uncased").
        dataset_name: Dataset spec — either "repo" or "repo config"
            (whitespace-separated, e.g. "glue sst2").
        epochs: Number of training epochs (anything int()-convertible).

    Returns:
        str: "Training complete ✅" on success, or an
        "Error during training: ..." message on failure. This function
        never raises — it is a UI-callback boundary.
    """
    try:
        log.info(f"Loading dataset: {dataset_name}")
        # .split() with no argument tolerates repeated/leading/trailing
        # whitespace; .split(" ") would yield empty parts for "repo  config".
        parts = dataset_name.split()

        if len(parts) == 2:
            dataset_repo, dataset_config = parts
            dataset = load_dataset(dataset_repo, dataset_config, split="train[:200]")  # CPU-friendly subset
        else:
            dataset = load_dataset(dataset_name, split="train[:200]")

        log.info("Dataset loaded successfully")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

        def tokenize_fn(examples):
            # assumes the dataset exposes a "text" column — TODO confirm for
            # datasets whose text field is named differently.
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

        dataset = dataset.map(tokenize_fn, batched=True)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

        # BUG FIX: eval_strategy="epoch" requires an eval_dataset — without
        # one, Trainer raises at construction time, so every run previously
        # fell into the except-branch and training could never complete.
        # Carve a small held-out split from the subset so per-epoch
        # evaluation (and save_strategy="epoch") actually works.
        split = dataset.train_test_split(test_size=0.2, seed=42)

        training_args = TrainingArguments(
            output_dir="./results",
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=4,
            num_train_epochs=int(epochs),
            logging_dir="./logs",
            logging_steps=10,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split["train"],
            eval_dataset=split["test"],
            tokenizer=tokenizer,
        )

        trainer.train()
        return "Training complete ✅"
    except Exception as e:
        # Broad catch is deliberate: this is a top-level UI callback and
        # must report failures as a string rather than crash the app.
        log.error(f"Training failed: {e}")
        return f"Error during training: {e}"
|
98 |
|
99 |
# Load model
|
100 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
|