rararara9999 commited on
Commit
96581e3
·
verified ·
1 Parent(s): aec5168

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+
3
+ # Install the required packages
4
+ subprocess.check_call(["pip", "install", "--upgrade", "pip"])
5
+ subprocess.check_call(["pip", "install", "-U", "transformers"])
6
+ subprocess.check_call(["pip", "install", "-U", "accelerate"])
7
+ subprocess.check_call(["pip", "install", "datasets"])
8
+ subprocess.check_call(["pip", "install", "evaluate"])
9
+ subprocess.check_call(["pip", "install", "scikit-learn"])
10
+ subprocess.check_call(["pip", "install", "torchvision"])
11
+
12
+ model_checkpoint = "google/vit-base-patch16-224-in21k"
13
+ batch_size = 128
14
+
15
+ from datasets import load_dataset
16
+ from evaluate import load
17
+
18
+ metric = load("accuracy")
19
+
20
+ # Load the dataset directly from Hugging Face
21
+ dataset = load_dataset("DamarJati/Face-Mask-Detection")
22
+ labels = dataset["train"].features["label"].names
23
+ label2id, id2label = dict(), dict()
24
+ for i, label in enumerate(labels):
25
+ label2id[label] = i
26
+ id2label[i] = label
27
+
28
+ from transformers import AutoImageProcessor
29
+ image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
30
+ image_processor
31
+
32
+ from torchvision.transforms import (
33
+ CenterCrop,
34
+ Compose,
35
+ Normalize,
36
+ RandomHorizontalFlip,
37
+ RandomResizedCrop,
38
+ Resize,
39
+ ToTensor,
40
+ ColorJitter,
41
+ RandomRotation
42
+ )
43
+
44
+ normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
45
+ size = (image_processor.size["height"], image_processor.size["width"])
46
+
47
+ train_transforms = Compose(
48
+ [
49
+ RandomResizedCrop(size),
50
+ RandomHorizontalFlip(),
51
+ RandomRotation(degrees=15),
52
+ ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
53
+ ToTensor(),
54
+ normalize,
55
+ ]
56
+ )
57
+
58
+ val_transforms = Compose(
59
+ [
60
+ Resize(size),
61
+ CenterCrop(size),
62
+ RandomRotation(degrees=15),
63
+ ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
64
+ ToTensor(),
65
+ normalize,
66
+ ]
67
+ )
68
+
69
+ def preprocess_train(example_batch):
70
+ example_batch["pixel_values"] = [
71
+ train_transforms(image.convert("RGB")) for image in example_batch["image"]
72
+ ]
73
+ return example_batch
74
+
75
+ def preprocess_val(example_batch):
76
+ example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
77
+ return example_batch
78
+
79
+ splits = dataset["train"].train_test_split(test_size=0.3)
80
+ train_ds = splits['train']
81
+ val_ds = splits['test']
82
+
83
+ train_ds.set_transform(preprocess_train)
84
+ val_ds.set_transform(preprocess_val)
85
+
86
+ from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
87
+
88
+ model = AutoModelForImageClassification.from_pretrained(model_checkpoint,
89
+ label2id=label2id,
90
+ id2label=id2label,
91
+ ignore_mismatched_sizes=True)
92
+
93
+ model_name = model_checkpoint.split("/")[-1]
94
+
95
+ args = TrainingArguments(
96
+ f"{model_name}-finetuned",
97
+ remove_unused_columns=False,
98
+ eval_strategy="epoch", # Updated parameter
99
+ save_strategy="epoch",
100
+ save_total_limit=5,
101
+ learning_rate=1e-3,
102
+ per_device_train_batch_size=batch_size,
103
+ gradient_accumulation_steps=2,
104
+ per_device_eval_batch_size=batch_size,
105
+ num_train_epochs=2,
106
+ warmup_ratio=0.1,
107
+ weight_decay=0.01,
108
+ lr_scheduler_type="cosine",
109
+ logging_steps=10,
110
+ load_best_model_at_end=True,
111
+ metric_for_best_model="accuracy",
112
+ )
113
+
114
+ import numpy as np
115
+
116
+ def compute_metrics(eval_pred):
117
+ """Computes accuracy on a batch of predictions"""
118
+ predictions = np.argmax(eval_pred.predictions, axis=1)
119
+ return metric.compute(predictions=predictions, references=eval_pred.label_ids)
120
+
121
+ import torch
122
+
123
+ def collate_fn(examples):
124
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
125
+ labels = torch.tensor([example["label"] for example in examples])
126
+ return {"pixel_values": pixel_values, "labels": labels}
127
+
128
+ trainer = Trainer(
129
+ model=model,
130
+ args=args,
131
+ train_dataset=train_ds,
132
+ eval_dataset=val_ds,
133
+ tokenizer=image_processor,
134
+ compute_metrics=compute_metrics,
135
+ data_collator=collate_fn,
136
+ )
137
+
138
+ print("Starting training...")
139
+ train_results = trainer.train()
140
+ print("Training completed.")
141
+
142
+ # Save model
143
+ trainer.save_model()
144
+ trainer.log_metrics("train", train_results.metrics)
145
+ trainer.save_metrics("train", train_results.metrics)
146
+ trainer.save_state()
147
+
148
+ print("Starting evaluation...")
149
+ metrics = trainer.evaluate()
150
+ print("Evaluation completed.")
151
+
152
+ # Log and save metrics
153
+ trainer.log_metrics("eval", metrics)
154
+ trainer.save_metrics("eval", metrics)
155
+
156
+ # Print evaluation metrics
157
+ print("Evaluation Metrics:")
158
+ for key, value in metrics.items():
159
+ print(f"{key}: {value}")