xhiroga commited on
Commit
738801e
1 Parent(s): d4a1725

Upload folder using huggingface_hub

Browse files
models/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ee6f0b0f6d957c868e1fb383c627d0de3095f5fac670c084e81cb81b29b43b73
3
  size 1074051192
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d5e8017d418b955cc16dd7a3f4ae86cc15f9be6df35bdbd2ab275583bf05b60
3
  size 1074051192
notebooks/nobg.ipynb CHANGED
@@ -2,9 +2,18 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 2,
6
  "metadata": {},
7
- "outputs": [],
 
 
 
 
 
 
 
 
 
8
  "source": [
9
  "import torch\n",
10
  "from carvekit.api.high import HiInterface\n",
@@ -24,8 +33,8 @@
24
  "\n",
25
  "# input_dir = \"../data/raw\"\n",
26
  "# output_dir = \"../data/nobg\"\n",
27
- "input_dir = \"../data/raw/ポケットモンスターシールド\"\n",
28
- "output_dir = \"../data/nobg/ポケットモンスターシールド\"\n",
29
  "\n",
30
  "# Create output directory if it doesn't exist\n",
31
  "os.makedirs(output_dir, exist_ok=True)\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 11,
6
  "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-pal\\Lib\\site-packages\\torchvision\\transforms\\functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
13
+ " warnings.warn(\n"
14
+ ]
15
+ }
16
+ ],
17
  "source": [
18
  "import torch\n",
19
  "from carvekit.api.high import HiInterface\n",
 
33
  "\n",
34
  "# input_dir = \"../data/raw\"\n",
35
  "# output_dir = \"../data/nobg\"\n",
36
+ "input_dir = \"../data/raw/paldex.io\"\n",
37
+ "output_dir = \"../data/nobg/paldex.io\"\n",
38
  "\n",
39
  "# Create output directory if it doesn't exist\n",
40
  "os.makedirs(output_dir, exist_ok=True)\n",
notebooks/train.ipynb CHANGED
@@ -47,7 +47,8 @@
47
  "]\n",
48
  "pal_dir: list[str] = [\n",
49
  " f\"{data_dir}/#パルワールド/\", \n",
50
- " f\"{data_dir}/every-pal-in-palworld-a-complete-paldeck-list/\"\n",
 
51
  "]\n",
52
  "\n",
53
  "pokemon_images: list[str] = [os.path.join(dir, file) for dir in pokemon_dir for file in os.listdir(dir)]\n",
@@ -198,10 +199,11 @@
198
  " \n",
199
  " # Save the model\n",
200
  " model_dir = '../models/'\n",
201
- " if not os.path.exists(model_dir):\n",
202
- " os.makedirs(model_dir)\n",
 
203
  " tensors = {name: param for name, param in model.named_parameters()}\n",
204
- " save_file(tensors, f\"{model_dir}SimpleCNN_{epoch+1}_{data_dir.replace('.', '').replace('/', '_')}_{image_size}x{image_size}.safetensors\")\n",
205
  " save_file(tensors, f\"{model_dir}model.safetensors\")\n",
206
  "\n",
207
  " # Plotting the losses\n",
 
47
  "]\n",
48
  "pal_dir: list[str] = [\n",
49
  " f\"{data_dir}/#パルワールド/\", \n",
50
+ " f\"{data_dir}/every-pal-in-palworld-a-complete-paldeck-list/\",\n",
51
+ " f\"{data_dir}/paldex.io/\"\n",
52
  "]\n",
53
  "\n",
54
  "pokemon_images: list[str] = [os.path.join(dir, file) for dir in pokemon_dir for file in os.listdir(dir)]\n",
 
199
  " \n",
200
  " # Save the model\n",
201
  " model_dir = '../models/'\n",
202
+ " snapshots_dir = f\"{model_dir}snapshots/\"\n",
203
+ " if not os.path.exists(snapshots_dir):\n",
204
+ " os.makedirs(snapshots_dir)\n",
205
  " tensors = {name: param for name, param in model.named_parameters()}\n",
206
+ " save_file(tensors, f\"{snapshots_dir}SimpleCNN_{epoch+1}_{data_dir.replace('.', '').replace('/', '_')}_{image_size}x{image_size}.safetensors\")\n",
207
  " save_file(tensors, f\"{model_dir}model.safetensors\")\n",
208
  "\n",
209
  " # Plotting the losses\n",