npv2k1 committed
Commit c9e0c1d · verified · 1 Parent(s): 180bd6a

feat: ShapeClassifier
.gitignore ADDED
@@ -0,0 +1,165 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+
+ # ignore dataset but not the folder
+ data/raw/*
+ data/processed/*
Makefile ADDED
@@ -0,0 +1,2 @@
+ package:
+ 	pip freeze > requirements.txt
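With this target, `make package` snapshots the active environment into requirements.txt via pip freeze.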
main.py ADDED
@@ -0,0 +1,5 @@
+ from src.train import train
+
+ if __name__ == "__main__":
+     train()
+
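Because main.py imports src.train as a package, training is launched from the repository root, e.g. with `python main.py`.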
requirements.txt ADDED
@@ -0,0 +1,64 @@
+ aiofiles==23.2.1
+ altair==5.1.1
+ annotated-types==0.5.0
+ anyio==3.7.1
+ attrs==23.1.0
+ certifi==2022.12.7
+ charset-normalizer==2.1.1
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.1.1
+ cycler==0.12.0
+ exceptiongroup==1.1.3
+ fastapi==0.103.2
+ ffmpy==0.3.1
+ filelock==3.9.0
+ fonttools==4.43.0
+ fsspec==2023.9.2
+ gradio==3.45.2
+ gradio_client==0.5.3
+ h11==0.14.0
+ httpcore==0.18.0
+ httpx==0.25.0
+ huggingface-hub==0.17.3
+ idna==3.4
+ importlib-resources==6.1.0
+ Jinja2==3.1.2
+ jsonschema==4.19.1
+ jsonschema-specifications==2023.7.1
+ kiwisolver==1.4.5
+ MarkupSafe==2.1.2
+ matplotlib==3.8.0
+ mpmath==1.3.0
+ networkx==3.0
+ numpy==1.24.1
+ orjson==3.9.7
+ packaging==23.1
+ pandas==2.1.1
+ Pillow==9.3.0
+ pydantic==2.4.2
+ pydantic_core==2.10.1
+ pydub==0.25.1
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.30.2
+ requests==2.28.1
+ rpds-py==0.10.3
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ sympy==1.12
+ toolz==0.12.0
+ torch==2.0.1+cu117
+ torchaudio==2.0.2+cu117
+ torchvision==0.15.2+cu117
+ tqdm==4.66.1
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ urllib3==1.26.13
+ uvicorn==0.23.2
+ websockets==11.0.3
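Note the +cu117 pins on the torch packages: those wheels are served from the PyTorch CUDA index rather than plain PyPI, so installation typically needs an extra index URL, e.g. `pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117`.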
src/__init__.py ADDED
File without changes
src/configs/__init__.py ADDED
@@ -0,0 +1,17 @@
+ import importlib
+ import os
+ from inspect import isclass
+
+ # Import every module under configs/ and re-export its classes
+ configs_dir = os.path.dirname(__file__)
+ for file in os.listdir(configs_dir):
+     path = os.path.join(configs_dir, file)
+     if not file.startswith(("_", ".")) and (file.endswith(".py") or os.path.isdir(path)):
+         config_name = file[:-3] if file.endswith(".py") else file  # strip the ".py" extension
+         module = importlib.import_module("src.configs." + config_name)
+         for attribute_name in dir(module):
+             attribute = getattr(module, attribute_name)
+
+             if isclass(attribute):
+                 # Add the class to this package's variables
+                 globals()[attribute_name] = attribute
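Once the loop has run, the classes are re-exported at package level, so `from src.configs import ModelConfig` works without naming the module file.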
src/configs/model_config.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+
+
+ class ModelConfig:
+     def __init__(self):
+         self.learning_rate = 0.001
+         self.batch_size = 32
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.epochs = 20
+
+     def get_config(self):
+         return self
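Note that get_config() simply returns the instance, so ModelConfig().get_config() and ModelConfig() are interchangeable; the data loader and training code use the former.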
src/data/__init__.py ADDED
@@ -0,0 +1,17 @@
+ import importlib
+ import os
+ from inspect import isclass
+
+ # Import every module under data/ and re-export its classes
+ data_dir = os.path.dirname(__file__)
+ for file in os.listdir(data_dir):
+     path = os.path.join(data_dir, file)
+     if not file.startswith(("_", ".")) and (file.endswith(".py") or os.path.isdir(path)):
+         module_name = file[:-3] if file.endswith(".py") else file  # strip the ".py" extension
+         module = importlib.import_module("src.data." + module_name)
+         for attribute_name in dir(module):
+             attribute = getattr(module, attribute_name)
+
+             if isclass(attribute):
+                 # Add the class to this package's variables
+                 globals()[attribute_name] = attribute
src/data/data_loader.py ADDED
@@ -0,0 +1,21 @@
+ from .dataset import CustomDataset
+ from torch.utils.data import DataLoader
+ from src.configs.model_config import ModelConfig
+ from .transform import data_transform
+ import os
+
+ num_classes = 3
+ config = ModelConfig().get_config()
+
+ train_dataset = CustomDataset(data_folder=os.path.join("data", "raw"), transform=data_transform)
+
+ # # Calculate the split point
+ # split_index = int(0.8 * len(dataset))
+
+ # # Split the dataset into training and testing
+ # train_dataset = dataset[:split_index]
+ # test_dataset = dataset[split_index:]
+
+
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
+
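The commented-out slicing would not work as written: indexing a Dataset with a slice just passes the slice object into __getitem__. If the 80/20 split is revived, torch.utils.data.random_split is the usual tool; a minimal sketch, not wired into the loader above:

from torch.utils.data import random_split

n_train = int(0.8 * len(train_dataset))  # 80% of the samples for training
train_ds, test_ds = random_split(train_dataset, [n_train, len(train_dataset) - n_train])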
src/data/dataset.py ADDED
@@ -0,0 +1,35 @@
+ from torch.utils.data import Dataset
+ import os
+ from PIL import Image
+
+
+ class CustomDataset(Dataset):
+     def __init__(self, data_folder, transform=None):
+         self.data_folder = data_folder
+         self.image_files = os.listdir(data_folder)
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.image_files)
+
+     def __getitem__(self, idx):
+         image_name = self.image_files[idx]
+         label = image_name[:-8]  # Extract the label by dropping a fixed-width suffix such as "_001.png"
+
+         image_path = os.path.join(self.data_folder, image_name)
+         image = Image.open(image_path).convert("RGB")  # Ensure images are RGB
+
+         if self.transform:
+             image = self.transform(image)
+
+         if label == "circle":
+             label = 0
+         elif label == "square":
+             label = 1
+         elif label == "triangle":
+             label = 2
+         else:
+             raise ValueError(f"Unrecognized label '{label}' in filename {image_name}")
+
+         return image, label
+
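The [:-8] slice only yields a clean label if every filename ends in a fixed-width 8-character suffix. With a hypothetical name like triangle_042.png ("_042.png" is exactly 8 characters):

# hypothetical filename following the assumed convention
name = "triangle_042.png"
print(name[:-8])  # -> "triangle"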
src/data/transform.py ADDED
@@ -0,0 +1,7 @@
+ import torchvision.transforms as transforms
+
+ data_transform = transforms.Compose([
+     transforms.Resize((128, 128)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # standard ImageNet statistics
+ ])
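Applied to any PIL image, the pipeline yields a normalized 3x128x128 float tensor; a quick check:

from PIL import Image
from src.data.transform import data_transform

t = data_transform(Image.new("RGB", (300, 200)))  # any input size is resized to 128x128
print(t.shape)  # torch.Size([3, 128, 128])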
src/models/model.py ADDED
@@ -0,0 +1,17 @@
+ from torch import nn
+ import torch.nn.functional as F
+
+ class ShapeClassifier(nn.Module):
+     def __init__(self, num_classes):
+         super(ShapeClassifier, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.fc1 = nn.Linear(16 * 64 * 64, 128)  # 16 channels x 64 x 64 after one 2x2 pool on a 128x128 input
+         self.fc2 = nn.Linear(128, num_classes)
+
+     def forward(self, x):
+         x = self.pool(F.relu(self.conv1(x)))
+         x = x.view(-1, 16 * 64 * 64)  # flatten; dimensions assume a 128x128 input image
+         x = F.relu(self.fc1(x))
+         x = self.fc2(x)
+         return x
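Since fc1 hard-codes 16 * 64 * 64 features, the network only fits 128x128 inputs (one 2x2 pool halves 128 to 64). A quick shape check under that assumption:

import torch
from src.models.model import ShapeClassifier

model = ShapeClassifier(num_classes=3)
logits = model(torch.randn(1, 3, 128, 128))  # one 128x128 RGB image
print(logits.shape)  # torch.Size([1, 3])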
src/train.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from src.models.model import ShapeClassifier
+
+ from src.configs.model_config import ModelConfig
+ from src.data.data_loader import train_loader, num_classes
+
+
+ def train():
+     config = ModelConfig().get_config()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = ShapeClassifier(num_classes=num_classes).to(device)
+     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
+     log_interval = 20
+     for epoch in range(config.epochs):
+         model.train()
+         running_loss = 0.0
+
+         for batch_idx, (inputs, labels) in enumerate(train_loader):
+             inputs, labels = inputs.to(device), labels.to(device)
+             optimizer.zero_grad()
+
+             outputs = model(inputs)
+             loss = F.cross_entropy(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+
+             if (batch_idx + 1) % log_interval == 0:  # average the loss over the last log_interval batches
+                 current_loss = running_loss / log_interval
+                 print(
+                     f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
+                 running_loss = 0.0
+
+     # Calculate accuracy on the training set (no test split exists yet)
+
+     with torch.no_grad():
+         model.eval()
+         correct = 0
+         total = 0
+         for inputs, labels in train_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             outputs = model(inputs)
+             predicted = torch.argmax(outputs, 1)
+             total += labels.size(0)
+             correct += (predicted == labels).sum().item()
+         print(f"Accuracy of the model on the training images: {100 * correct / total} %")
+     # Save the trained weights
+     torch.save(model.state_dict(), "model.pth")
web.py ADDED
@@ -0,0 +1,47 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ from src.models.model import ShapeClassifier  # Import the model class
+ import os
+ from src.data.transform import data_transform
+
+
+ def classify_drawing(drawing_image):
+     # Return None if no drawing is provided
+     if drawing_image is None:
+         return None
+
+     # Load the trained model (reloaded on every call; could be cached at module level)
+     num_classes = 3  # Number of shape classes
+     model = ShapeClassifier(num_classes=num_classes)
+     model.load_state_dict(torch.load('model.pth', map_location="cpu"))  # map_location keeps CPU-only hosts working
+     model.eval()  # Set the model to evaluation mode
+
+     # Convert the drawing to an array, then to a normalized tensor
+     drawing = np.array(drawing_image)
+
+     drawing_tensor = data_transform(Image.fromarray(drawing)).unsqueeze(0)  # add a batch dimension
+
+     # Save every drawing to the draw/ folder, named by index
+     os.makedirs("draw", exist_ok=True)
+     Image.fromarray(drawing).save(f'draw/{len(os.listdir("draw"))}.png')
+
+     # Perform inference
+     with torch.no_grad():
+         output = model(drawing_tensor)
+
+     shape_classes = ["Circle", "Square", "Triangle"]
+     predicted_class = torch.argmax(output, dim=1).item()
+     predicted_label = shape_classes[predicted_class]
+
+     return predicted_label
+
+
+ iface = gr.Interface(
+     fn=classify_drawing,
+     inputs=gr.Image(type="pil"),  # PIL image input
+     outputs="text",
+     live=True,
+ )
+ iface.launch(server_port=8111)
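With live=True the classifier reruns on every change to the input image, and launch(server_port=8111) serves the demo at http://localhost:8111.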