xhiroga committed on
Commit
c0fe349
1 Parent(s): f8bd42c

Upload folder using huggingface_hub

app.py CHANGED
@@ -32,6 +32,7 @@ def classify_image(input_image: Image):
 
     # Forward pass the input through the model
     output = model(input_tensor)
+    print(output)
 
     probabilities = torch.nn.functional.softmax(output, dim=1)
 
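The one-line change above prints the raw logits the model returns before they are normalized. For context, a minimal sketch of the logits-to-probabilities step that `classify_image` performs; the stand-in linear model and tensor shapes here are illustrative, not taken from app.py:

```python
import torch

# Hypothetical stand-in for the classifier loaded in app.py.
model = torch.nn.Linear(8, 3)
model.eval()

input_tensor = torch.randn(1, 8)  # one preprocessed sample

with torch.no_grad():
    output = model(input_tensor)
    print(output)  # raw logits: unbounded scores, one per class

    # softmax over dim=1 (the class dimension) rescales the logits
    # into probabilities that sum to 1, as the following line of app.py does.
    probabilities = torch.nn.functional.softmax(output, dim=1)
    print(probabilities)
```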
models/model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3342f4ace31995d4255d3b4a7017955fa83014ed9c9304b1aaa4f333a3c54271
+oid sha256:ee6f0b0f6d957c868e1fb383c627d0de3095f5fac670c084e81cb81b29b43b73
 size 1074051192
notebooks/crop.ipynb ADDED
@@ -0,0 +1,221 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Red channel:\n",
+      "tensor([[0.0000],\n",
+      "        [0.9294]])\n",
+      "Green channel:\n",
+      "tensor([[0.0000],\n",
+      "        [0.1098]])\n",
+      "Blue channel:\n",
+      "tensor([[0.0000],\n",
+      "        [0.1412]])\n",
+      "Alpha channel:\n",
+      "tensor([[0.],\n",
+      "        [1.]])\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMwAAAGFCAYAAACxAhziAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEO0lEQVR4nO3VsQ3CUBAFQYyowDkh/RdESE4LRwl4A+vL0kx8wUtWt83M3IBD7qsHwJUIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEgoFAMBA8jh5+n68zd8By++f998aHgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDASCgUAwEAgGAsFAIBgIBAOBYCAQDATbzMzqEXAVPgwEgoFAMBAIBgLBQCAYCAQDgWAgEAwEP8rSDgOaVu6AAAAAAElFTkSuQmCC",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from PIL import Image\n",
+    "from torchvision import transforms\n",
+    "\n",
+    "\n",
+    "def show_rgba(image_path):\n",
+    "    # See RGBA data\n",
+    "    image = Image.open(image_path)\n",
+    "    to_tensor = transforms.ToTensor()\n",
+    "    tensor = to_tensor(image)\n",
+    "\n",
+    "    for i, color in enumerate(['Red', 'Green', 'Blue', 'Alpha']):\n",
+    "        print(f\"{color} channel:\")\n",
+    "        print(tensor[i])\n",
+    "    plt.imshow(tensor.permute(1, 2, 0))\n",
+    "    plt.axis('off')\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "show_rgba('../data/samples/transparent_indonesia_flag.png')\n",
+    "# Alpha channel: tensor([[0.], [1.]]), i.e. an alpha of 0 is fully transparent.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import cv2\n",
+    "import numpy as np\n",
+    "import os\n",
+    "\n",
+    "from PIL import Image\n",
+    "\n",
+    "\n",
+    "def get_object_bounding_boxes(image):\n",
+    "    # Get the alpha channel and build a binary mask from it\n",
+    "    alpha_channel = image[:, :, 3]\n",
+    "\n",
+    "    # cv2.threshold maps pixels whose alpha value is 1 or greater to 255 (white) and all others to 0 (black).\n",
+    "    # This yields a binary mask with the object in white and the background in black.\n",
+    "    _, binary_mask = cv2.threshold(alpha_channel, 1, 255, cv2.THRESH_BINARY)\n",
+    "\n",
+    "    # Detect contours\n",
+    "    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
+    "\n",
+    "    return contours or []\n",
+    "\n",
+    "\n",
+    "def show_bounding_boxes(image_path):\n",
+    "    # Load the transparent-background image in RGBA format\n",
+    "    image_pil = Image.open(image_path)\n",
+    "\n",
+    "    # Convert the PIL image to OpenCV format\n",
+    "    image = np.array(image_pil)\n",
+    "\n",
+    "    # Get the bounding boxes\n",
+    "    contours = get_object_bounding_boxes(image)\n",
+    "    rects = [cv2.boundingRect(c) for c in contours]\n",
+    "\n",
+    "    # Draw the bounding boxes\n",
+    "    image_bgr = cv2.cvtColor(image[:, :, :3], cv2.COLOR_RGB2BGR)\n",
+    "    [\n",
+    "        cv2.rectangle(image_bgr, (x, y), (x + w, y + h), (0, 255, 0), 2)\n",
+    "        for [x, y, w, h] in rects\n",
+    "    ]\n",
+    "\n",
+    "    # Display the image with the bounding boxes applied\n",
+    "    cv2.imshow('Bounding Box', image_bgr)\n",
+    "    cv2.waitKey(0)\n",
+    "    cv2.destroyAllWindows()\n",
+    "\n",
+    "\n",
+    "show_bounding_boxes('../data/nobg/ポケットモンスターシールド/2020022922273500_s.png')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[]\n"
+     ]
+    }
+   ],
+   "source": [
+    "file_path = '../data/nobg/every-pal-in-palworld-a-complete-paldeck-list/016 Palworld Teafant.png.png'\n",
+    "\n",
+    "image_pil = Image.open(file_path)\n",
+    "image = np.array(image_pil)\n",
+    "\n",
+    "# Get the bounding boxes of the objects in the image\n",
+    "contours = get_object_bounding_boxes(image)\n",
+    "print(contours)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "input_dir = \"../data/nobg\"\n",
+    "output_dir = \"../data/cropped\"\n",
+    "\n",
+    "def get_max_bounding_rect(contours):\n",
+    "    if len(contours) == 0:\n",
+    "        return 0, 0, 0, 0\n",
+    "\n",
+    "    c = max(contours, key=cv2.contourArea)\n",
+    "    x, y, w, h = cv2.boundingRect(c)\n",
+    "    return x, y, w, h\n",
+    "\n",
+    "# Loop over all files and subdirectories in the input directory\n",
+    "for root, dirs, files in os.walk(input_dir):\n",
+    "    for filename in files:\n",
+    "        # Construct full file path\n",
+    "        file_path = os.path.join(root, filename)\n",
+    "\n",
+    "        # Open the image and convert it to a numpy array\n",
+    "        image_pil = Image.open(file_path)\n",
+    "        image = np.array(image_pil)\n",
+    "\n",
+    "        # Get the bounding boxes of the objects in the image\n",
+    "        contours = get_object_bounding_boxes(image)\n",
+    "\n",
+    "        if len(contours) == 0:\n",
+    "            continue\n",
+    "\n",
+    "        # Get the maximum bounding rectangle\n",
+    "        x, y, w, h = get_max_bounding_rect(contours)\n",
+    "\n",
+    "        # Crop the image\n",
+    "        cropped_image = image[y:y+h, x:x+w]\n",
+    "\n",
+    "        cropped_image_pil = Image.fromarray(cropped_image)\n",
+    "\n",
+    "        # Create output subdirectory if it doesn't exist\n",
+    "        output_subdir = os.path.join(output_dir, os.path.relpath(root, input_dir))\n",
+    "        os.makedirs(output_subdir, exist_ok=True)\n",
+    "\n",
+    "        # Save the cropped image to the output directory\n",
+    "        output_file_path = os.path.join(output_subdir, filename)\n",
+    "        cropped_image_pil.save(output_file_path)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pokemon-pal",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
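Condensed, the recipe in crop.ipynb is: threshold the alpha channel into a binary mask, find external contours, take the bounding rectangle of the largest contour, and slice the array. A self-contained sketch of that pipeline; the function name is mine, but it mirrors the notebook's get_object_bounding_boxes and get_max_bounding_rect:

```python
import cv2
import numpy as np
from PIL import Image


def crop_to_largest_object(image_path):
    """Crop an RGBA image to the bounding box of its largest opaque region."""
    image = np.array(Image.open(image_path))  # H x W x 4 array

    # Foreground = any pixel with alpha >= 1, matching the notebook's threshold.
    _, mask = cv2.threshold(image[:, :, 3], 1, 255, cv2.THRESH_BINARY)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return Image.fromarray(image)  # fully transparent image: return unchanged

    # Largest contour by area, then its axis-aligned bounding rectangle.
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    return Image.fromarray(image[y:y + h, x:x + w])
```

The empty-contours case matters in practice: the Teafant example above prints [], and the batch loop skips such files with `continue` rather than writing a zero-size crop.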
notebooks/nobg.ipynb ADDED
@@ -0,0 +1,94 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-pal\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-pal\\Lib\\site-packages\\torchvision\\transforms\\functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from carvekit.api.high import HiInterface\n",
+    "\n",
+    "# Check doc strings for more information\n",
+    "interface = HiInterface(object_type=\"object\",  # Can be \"object\" or \"hairs-like\".\n",
+    "                        batch_size_seg=5,\n",
+    "                        batch_size_matting=1,\n",
+    "                        device='cuda' if torch.cuda.is_available() else 'cpu',\n",
+    "                        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net\n",
+    "                        matting_mask_size=2048,\n",
+    "                        trimap_prob_threshold=231,\n",
+    "                        trimap_dilation=30,\n",
+    "                        trimap_erosion_iters=5,\n",
+    "                        fp16=False)\n",
+    "import os\n",
+    "\n",
+    "# input_dir = \"../data/raw\"\n",
+    "# output_dir = \"../data/nobg\"\n",
+    "input_dir = \"../data/raw/ポケットモンスターシールド\"\n",
+    "output_dir = \"../data/nobg/ポケットモンスターシールド\"\n",
+    "\n",
+    "# Create output directory if it doesn't exist\n",
+    "os.makedirs(output_dir, exist_ok=True)\n",
+    "\n",
+    "# Loop over all files and subdirectories in the input directory\n",
+    "for root, dirs, files in os.walk(input_dir):\n",
+    "    for filename in files:\n",
+    "        # Construct full file path\n",
+    "        file_path = os.path.join(root, filename)\n",
+    "\n",
+    "        # Process the image and remove the background\n",
+    "        images_without_background = interface([file_path])\n",
+    "        image_wo_bg = images_without_background[0]\n",
+    "\n",
+    "        # Create output subdirectory if it doesn't exist\n",
+    "        output_subdir = os.path.join(output_dir, os.path.relpath(root, input_dir))\n",
+    "        os.makedirs(output_subdir, exist_ok=True)\n",
+    "\n",
+    "        # Save the processed image to the output directory\n",
+    "        # Since the image format is RGBA, we save it as PNG\n",
+    "        filename = os.path.splitext(filename)[0] + \".png\"\n",
+    "        output_file_path = os.path.join(output_subdir, filename)\n",
+    "        image_wo_bg.save(output_file_path)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pokemon-pal",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
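nobg.ipynb calls `interface([file_path])` once per file, so the `batch_size_seg=5` setting configured above never sees more than one image at a time. Since HiInterface accepts a list of paths and returns one image per input, collecting the paths first and passing them in a single call should let carvekit batch the segmentation pass. A sketch, assuming the `interface`, `input_dir`, and `output_dir` defined in the notebook, and assuming outputs preserve input order as the carvekit usage example suggests:

```python
import os

# Gather every file under input_dir up front.
paths = [
    os.path.join(root, name)
    for root, _dirs, files in os.walk(input_dir)
    for name in files
]

# One call over all paths, instead of one call per file.
images_without_background = interface(paths)

for src, image_wo_bg in zip(paths, images_without_background):
    # Mirror the input directory layout and force a .png extension,
    # since the background-removed images are RGBA.
    rel = os.path.relpath(src, input_dir)
    out_path = os.path.join(output_dir, os.path.splitext(rel)[0] + ".png")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    image_wo_bg.save(out_path)
```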
notebooks/train.ipynb CHANGED
The diff for this file is too large to render. See raw diff