Asseh commited on
Commit
e9e3006
1 Parent(s): 51b3d5b

Create App.py

Browse files
Files changed (1) hide show
  1. App.py +116 -0
App.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+
5
+ from datasets import load_dataset
6
+ dataset = load_dataset("Asseh/Ball_Classification")
7
+
8
+ from sklearn.model_selection import train_test_split
9
+
10
+ dataset = dataset["train"].train_test_split(test_size=0.2)
11
+
12
+ from transformers import AutoImageProcessor
13
+
14
+ checkpoint = "google/vit-base-patch16-224-in21k"
15
+ image_processor = AutoImageProcessor.from_pretrained(checkpoint)
16
+
17
+ from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
18
+
19
+ normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
20
+ size = (
21
+ image_processor.size["shortest_edge"]
22
+ if "shortest_edge" in image_processor.size
23
+ else (image_processor.size["height"], image_processor.size["width"])
24
+ )
25
+ _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
26
+
27
+ def transforms(examples):
28
+ examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
29
+ del examples["image"]
30
+ return examples
31
+
32
+ dataset = dataset.with_transform(transforms)
33
+
34
+ from transformers import DefaultDataCollator
35
+
36
+ data_collator = DefaultDataCollator()
37
+
38
+ def compute_metrics(eval_pred):
39
+ predictions, labels = eval_pred
40
+ predictions = np.argmax(predictions, axis=1)
41
+ return accuracy.compute(predictions=predictions, references=labels)
42
+
43
+ from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
44
+
45
+ model = AutoModelForImageClassification.from_pretrained(
46
+ checkpoint,
47
+ num_labels=len(labels),
48
+ id2label=id2label,
49
+ label2id=label2id,
50
+ )
51
+
52
+ training_args = TrainingArguments(
53
+ output_dir="Ball_Classification",
54
+ remove_unused_columns=False,
55
+ evaluation_strategy="epoch",
56
+ save_strategy="epoch",
57
+ learning_rate=5e-5,
58
+ per_device_train_batch_size=16,
59
+ gradient_accumulation_steps=4,
60
+ per_device_eval_batch_size=16,
61
+ num_train_epochs=3,
62
+ warmup_ratio=0.1,
63
+ logging_steps=10,
64
+ load_best_model_at_end=True,
65
+ metric_for_best_model="accuracy",
66
+ push_to_hub=False,
67
+ )
68
+
69
+ trainer = Trainer(
70
+ model=model,
71
+ args=training_args,
72
+ data_collator=data_collator,
73
+ train_dataset=dataset["train"],
74
+ eval_dataset=dataset["test"],
75
+ tokenizer=image_processor,
76
+ compute_metrics=compute_metrics,
77
+ )
78
+
79
+ trainer.train()
80
+
81
+ from huggingface_hub import notebook_login
82
+ notebook_login()
83
+
84
+ trainer.push_to_hub()
85
+
86
+ from transformers import pipeline
87
+
88
+ classifier = pipeline("image-classification", model="Ball_Classification")
89
+ classifier(image)
90
+
91
+ from transformers import pipeline
92
+
93
+ classifier = pipeline("image-classification", model="Asseh/Ball_Classification")
94
+
95
+ from transformers import pipeline
96
+
97
+ classifier = pipeline("image-classification", model="Asseh/Ball_Classification")
98
+
99
+ # Function to classify images into 7 classes
100
+ def image_classifier(inp):
101
+ # Dummy classification logic
102
+ # Generating random confidence scores for each class
103
+ confidence_scores = np.random.rand(7)
104
+ # Normalizing confidence scores to sum up to 1
105
+ confidence_scores /= np.sum(confidence_scores)
106
+ # Creating a dictionary with class labels and corresponding confidence scores
107
+ classes = ['american_football', 'baseball', 'basketball', 'billiard_ball', 'bowling_ball','cricket_ball', 'football', 'golf_ball', 'hockey_ball', 'hockey_puck', 'rugby_ball', 'shuttlecock', 'table_tennis_ball', 'tennis_ball', 'volleyball']
108
+ result = {classes[i]: confidence_scores[i] for i in range(15)}
109
+ return result
110
+
111
+ # Creating Gradio interface
112
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
113
+
114
+ if __name__ == "__main__":
115
+ demo.launch()
116
+