{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Collect resources" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## From GitHub" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "!git clone -q https://github.com/mrok273/Qiita ../data/raw/mrok273/Qiita" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## From Kaggle" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!kaggle datasets download -d mikoajkolman/pokemon-images-first-generation17000-files -p \"../data/raw/\" -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## From Web" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Install firefox\n", "- Open web page\n", "- Bulk save image (See [How to Save All the Images on a Web Page in Firefox Browser](https://www.journeybytes.com/bulk-save-images-using-firefox/))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## From YouTube" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pal" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Official paldeck\n", "!yt-dlp --postprocessor-args \"-ss 00:00:00 -t 00:00:05\" -o \"../data/video/pocketpair/%(title)s-%(id)s-5s.%(ext)s\" -q https://www.youtube.com/playlist?list=PLptNv_Fxn9idzsTRulWNmLYKWgKhqKI5s" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import os\n", "import re\n", "\n", "directory = \"../data/video/pocketpair\"\n", "for filename in os.listdir(directory):\n", " match = re.search(r'[Pp]aldeck.*[Nn]o.(\\d+).*.webm', filename)\n", " paldeck_no, = match.groups() if match else [None]\n", " if paldeck_no is None:\n", " continue\n", " new_filename = f\"paldeck_no{paldeck_no.zfill(3)}.webm\"\n", " os.rename(os.path.join(directory, filename), os.path.join(directory, new_filename))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1920x1080x60/1\n" ] } ], "source": [ "!ffprobe -v error -select_streams v:0 -show_entries stream=width,height,r_frame_rate -of csv=s=x:p=0 \"../data/video/pocketpair/paldeck_no001.webm\"" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "import os\n", "import subprocess\n", "\n", "from datetime import datetime\n", "from src.pipeline import *\n", "\n", "input_dir = \"../data/video/pocketpair\"\n", "output_dir = \"../data/raw/pocketpair\"\n", "\n", "for root, dirs, files in os.walk(input_dir):\n", " for filename in files:\n", " full_input_path = os.path.join(root, filename)\n", " filename_without_ext, _ext = os.path.splitext(filename)\n", " output_subdir = os.path.join(output_dir, filename_without_ext)\n", " os.makedirs(output_subdir, exist_ok=True)\n", "\n", " output_pattern = os.path.join(output_subdir, \"frame_%05d.png\")\n", " command = ['ffmpeg', '-hwaccel', 'cuda', '-i', full_input_path, '-vf', 'fps=12', output_pattern]\n", " subprocess.run(command, check=True)\n", " \n", " for root_out, _, files_out in os.walk(output_subdir):\n", " for filename_out in files_out:\n", " full_output_path = os.path.join(root_out, filename_out)\n", " raw_dir = data_dir(Step.raw.value)\n", " metadata = Metadata(\n", " bucket=raw_dir,\n", " path=os.path.relpath(full_output_path, raw_dir),\n", " step=Step.raw,\n", " label=Label.pal,\n", " created_at=datetime.utcnow()\n", " 
)\n", " create_metadata(metadata)\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fan video\n", "!yt-dlp -o \"../data/video/palworld-fan/%(id)s.%(ext)s\" -q https://www.youtube.com/playlist?list=PLitsLuiXBQxtd0ThPaYMqsbxUMfmdxVHc" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "import os\n", "import subprocess\n", "\n", "from datetime import datetime\n", "from src.pipeline import *\n", "\n", "input_dir = \"../data/video/palworld-fan\"\n", "output_dir = \"../data/raw/palworld-fan\"\n", "\n", "for root, dirs, files in os.walk(input_dir):\n", " for filename in files:\n", " full_input_path = os.path.join(root, filename)\n", " filename_without_ext, _ext = os.path.splitext(filename)\n", " output_subdir = os.path.join(output_dir, filename_without_ext)\n", " os.makedirs(output_subdir, exist_ok=True)\n", "\n", " output_pattern = os.path.join(output_subdir, \"frame_%05d.png\")\n", " command = ['ffmpeg', '-hwaccel', 'cuda', '-i', full_input_path, '-vf', 'fps=12', output_pattern]\n", " subprocess.run(command, check=True)\n", " \n", " for root_out, _, files_out in os.walk(output_subdir):\n", " for filename_out in files_out:\n", " full_output_path = os.path.join(root_out, filename_out)\n", " raw_dir = data_dir(Step.raw.value)\n", " metadata = Metadata(\n", " bucket=raw_dir,\n", " path=os.path.relpath(full_output_path, raw_dir),\n", " step=Step.raw,\n", " label=Label.pal,\n", " created_at=datetime.utcnow()\n", " )\n", " create_metadata(metadata)\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pokemon" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!yt-dlp -o \"../data/video/pokemon-games/%(id)s.%(ext)s\" -q https://youtube.com/playlist?list=PLitsLuiXBQxvqH5Hv1R5ioFnCpIBMNvX3&si=nzehh3dDiU3k2Q7F" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "\n", "def video2img(video: str, output_dir: str, fps: int):\n", " filename_without_ext, _ext = os.path.splitext(os.path.basename(video))\n", " output_subdir = os.path.join(output_dir, filename_without_ext)\n", " os.makedirs(output_subdir, exist_ok=True)\n", " output_pattern = os.path.join(output_subdir, \"frame_%05d.png\")\n", " command = ['ffmpeg', '-hwaccel', 'cuda', '-i', video, '-vf', f\"fps={fps}\", output_pattern]\n", " subprocess.run(command, check=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for video in [\"0Loz61U6CuE.webm\", \"AObd6oPnlyg.webm\", \"cIi40yfs630.webm\", \"G9L0LK07lis.webm\", \"LG-LZKUUVZI.webm\", \"Q3-fCEL-JjE.webm\"]:\n", " video2img(f\"../data/video/pokemon-games/{video}\", \"../data/raw/pokemon-games\", 6)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "video2img(\"../data/video/pokemon-games/EEupjm0LwUQ.webm\", \"../data/raw/pokemon-games\", 1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# torchvision.dataset はフォルダ構造が`split`/`label`でないと使えない。前処理にはHuggingFace。\n", "import os\n", "from datasets import load_dataset\n", "from torchvision import transforms\n", "from typing import Tuple\n", "\n", "\n", "def center_crop_and_save(input_dir:str, output_dir:str, crop_size: Tuple[int, int]):\n", " dataset = load_dataset(\"imagefolder\", data_dir=input_dir)\n", " cropper = 
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fan video\n", "!yt-dlp -o \"../data/video/palworld-fan/%(id)s.%(ext)s\" -q https://www.youtube.com/playlist?list=PLitsLuiXBQxtd0ThPaYMqsbxUMfmdxVHc" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "import os\n", "import subprocess\n", "\n", "from datetime import datetime\n", "from src.pipeline import *\n", "\n", "input_dir = \"../data/video/palworld-fan\"\n", "output_dir = \"../data/raw/palworld-fan\"\n", "\n", "for root, dirs, files in os.walk(input_dir):\n", "    for filename in files:\n", "        full_input_path = os.path.join(root, filename)\n", "        filename_without_ext, _ext = os.path.splitext(filename)\n", "        output_subdir = os.path.join(output_dir, filename_without_ext)\n", "        os.makedirs(output_subdir, exist_ok=True)\n", "\n", "        output_pattern = os.path.join(output_subdir, \"frame_%05d.png\")\n", "        command = ['ffmpeg', '-hwaccel', 'cuda', '-i', full_input_path, '-vf', 'fps=12', output_pattern]\n", "        subprocess.run(command, check=True)\n", "\n", "        for root_out, _, files_out in os.walk(output_subdir):\n", "            for filename_out in files_out:\n", "                full_output_path = os.path.join(root_out, filename_out)\n", "                raw_dir = data_dir(Step.raw.value)\n", "                metadata = Metadata(\n", "                    bucket=raw_dir,\n", "                    path=os.path.relpath(full_output_path, raw_dir),\n", "                    step=Step.raw,\n", "                    label=Label.pal,\n", "                    created_at=datetime.utcnow()\n", "                )\n", "                create_metadata(metadata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pokemon" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Quote the URL: otherwise the shell treats \"&\" as the background operator and drops the \"si\" parameter.\n", "!yt-dlp -o \"../data/video/pokemon-games/%(id)s.%(ext)s\" -q \"https://youtube.com/playlist?list=PLitsLuiXBQxvqH5Hv1R5ioFnCpIBMNvX3&si=nzehh3dDiU3k2Q7F\"" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "\n", "def video2img(video: str, output_dir: str, fps: int):\n", "    filename_without_ext, _ext = os.path.splitext(os.path.basename(video))\n", "    output_subdir = os.path.join(output_dir, filename_without_ext)\n", "    os.makedirs(output_subdir, exist_ok=True)\n", "    output_pattern = os.path.join(output_subdir, \"frame_%05d.png\")\n", "    command = ['ffmpeg', '-hwaccel', 'cuda', '-i', video, '-vf', f\"fps={fps}\", output_pattern]\n", "    subprocess.run(command, check=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for video in [\"0Loz61U6CuE.webm\", \"AObd6oPnlyg.webm\", \"cIi40yfs630.webm\", \"G9L0LK07lis.webm\", \"LG-LZKUUVZI.webm\", \"Q3-fCEL-JjE.webm\"]:\n", "    video2img(f\"../data/video/pokemon-games/{video}\", \"../data/raw/pokemon-games\", 6)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "video2img(\"../data/video/pokemon-games/EEupjm0LwUQ.webm\", \"../data/raw/pokemon-games\", 1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# torchvision.datasets only works when the folder structure is `split`/`label`, so use HuggingFace datasets for preprocessing.\n", "import os\n", "from datasets import load_dataset\n", "from torchvision import transforms\n", "from typing import Tuple\n", "\n", "\n", "def center_crop_and_save(input_dir: str, output_dir: str, crop_size: Tuple[int, int]):\n", "    dataset = load_dataset(\"imagefolder\", data_dir=input_dir)\n", "    cropper = transforms.CenterCrop(crop_size)\n", "    os.makedirs(output_dir, exist_ok=True)\n", "\n", "    def _center_crop_and_save(example):\n", "        cropped = cropper(example[\"image\"])\n", "        # Rewrite the input path to point at output_dir, then save the crop there.\n", "        cropped.filename = os.path.abspath(example[\"image\"].filename).lower().replace(\n", "            os.path.abspath(input_dir).lower(),\n", "            os.path.abspath(output_dir).lower(),\n", "        )\n", "        cropped.save(cropped.filename)\n", "        # No need to return example, just save it.\n", "\n", "    dataset.map(_center_crop_and_save)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 11112/11112 [11:55<00:00, 15.52 examples/s]\n", "Generating train split: 7405 examples [00:00, 10062.70 examples/s]\n", "Map: 100%|██████████| 7405/7405 [18:05<00:00, 6.82 examples/s]\n", "Generating train split: 9862 examples [00:00, 10031.40 examples/s]\n", "Map: 100%|██████████| 9862/9862 [30:25<00:00, 5.40 examples/s]\n", "Generating train split: 12420 examples [00:01, 10197.81 examples/s]\n", "Map: 100%|██████████| 12420/12420 [30:12<00:00, 6.85 examples/s]\n" ] } ], "source": [ "center_crop_and_save(\"../data/raw/pokemon-games/0Loz61U6CuE\", \"../data/raw/pokemon-games/0Loz61U6CuE_cropped\", (1028, 1028))\n", "center_crop_and_save(\"../data/raw/pokemon-games/AObd6oPnlyg\", \"../data/raw/pokemon-games/AObd6oPnlyg_cropped\", (1028, 1028))\n", "center_crop_and_save(\"../data/raw/pokemon-games/LG-LZKUUVZI\", \"../data/raw/pokemon-games/LG-LZKUUVZI_cropped\", (1028, 1028))\n", "center_crop_and_save(\"../data/raw/pokemon-games/Q3-fCEL-JjE\", \"../data/raw/pokemon-games/Q3-fCEL-JjE_cropped\", (1028, 1028))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "import os\n", "from datasets import load_dataset\n", "from torchvision.transforms.functional import crop\n", "from typing import Tuple\n", "\n", "def left_crop_and_save(input_dir: str, output_dir: str, crop_size: Tuple[int, int]):\n", "    dataset = load_dataset(\"imagefolder\", data_dir=input_dir)\n", "    os.makedirs(output_dir, exist_ok=True)\n", "    # torchvision's functional crop takes (top, left, height, width), so this keeps the left edge.\n", "    cropper = lambda image: crop(image, 0, 0, crop_size[0], crop_size[1])\n", "\n", "    def _left_crop_and_save(example):\n", "        try:\n", "            if example[\"image\"].size == crop_size:\n", "                return\n", "            cropped = cropper(example[\"image\"])\n", "            cropped.filename = os.path.abspath(example[\"image\"].filename).lower().replace(\n", "                os.path.abspath(input_dir).lower(),\n", "                os.path.abspath(output_dir).lower(),\n", "            )\n", "            cropped.save(cropped.filename)\n", "            # No need to return example, just save it.\n", "\n", "        except Exception as e:\n", "            print(f\"Error occurred: {e}\")\n", "\n", "    dataset.map(_left_crop_and_save)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating train split: 27813 examples [00:00, 47790.54 examples/s]\n", "Map: 57%|█████▋ | 15990/27813 [03:43<02:23, 82.47 examples/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Error occurred: cannot identify image file 'C:\\\\Users\\\\hiroga\\\\Documents\\\\GitHub\\\\til\\\\computer-science\\\\machine-learning\\\\_src\\\\pokemon-palworld\\\\data\\\\raw\\\\pokemon-games\\\\cIi40yfs630\\\\frame_15991.png'\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 27813/27813 [15:52<00:00, 29.19 examples/s]\n" ] } ], "source": [ "left_crop_and_save(\"../data/raw/pokemon-games/cIi40yfs630\", \"../data/raw/pokemon-games/cIi40yfs630\", (1080, 1080))" ] },
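 { "cell_type": "markdown", "metadata": {}, "source": [ "The `cannot identify image file` error above comes from a frame that ffmpeg wrote but PIL cannot decode (the same file trips up the background-removal pass below). A sketch for finding such frames ahead of time, so they can be deleted or re-extracted; `find_corrupt_images` is a helper introduced here, not part of the pipeline:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A sketch that scans a directory for PNGs PIL cannot decode, so they can be\n", "# removed or re-extracted before dataset.map() hits them.\n", "import os\n", "from PIL import Image\n", "\n", "def find_corrupt_images(directory: str) -> list[str]:\n", "    corrupt = []\n", "    for root, _dirs, files in os.walk(directory):\n", "        for filename in files:\n", "            if not filename.lower().endswith(\".png\"):\n", "                continue\n", "            path = os.path.join(root, filename)\n", "            try:\n", "                with Image.open(path) as image:\n", "                    image.verify()  # cheap integrity check; does not decode all pixels\n", "            except Exception:\n", "                corrupt.append(path)\n", "    return corrupt\n", "\n", "# e.g. find_corrupt_images(\"../data/raw/pokemon-games/cIi40yfs630\")" ] },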
1080))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.append('../RMBG-1.4')\n", "\n", "from typing import Optional\n", "\n", "import numpy as np\n", "import torch\n", "from PIL.Image import Image\n", "from briarmbg import BriaRMBG\n", "from utilities import postprocess_image, preprocess_image\n", "\n", "net = BriaRMBG()\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "net = BriaRMBG.from_pretrained(\"briaai/RMBG-1.4\")\n", "net.to(device)\n", "net.eval() \n", "\n", "def remove_background(image: Image) -> Optional[Image]:\n", " try:\n", " # prepare input\n", " orig_im = np.array(image)\n", " orig_im = orig_im[:,:,:3] # remove alpha channel\n", " orig_im_size = orig_im.shape[0:2]\n", " model_input_size = [1024,1024]\n", " preprocessed = preprocess_image(orig_im, model_input_size).to(device)\n", "\n", " # inference \n", " result = net(preprocessed)\n", "\n", " # post process\n", " result_image = postprocess_image(result[0][0], orig_im_size)\n", "\n", " # save result\n", " pil_im = Image.fromarray(result_image)\n", " no_bg_image = Image.new(\"RGBA\", pil_im.size, (0,0,0,0))\n", " no_bg_image.paste(image, mask=pil_im)\n", " return no_bg_image\n", "\n", " except Exception as e:\n", " print(f\"{e, image}\")\n", " return None" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "import os\n", "from datasets import load_dataset\n", "\n", "def remove_bg_and_save(input_dir:str, output_dir:str):\n", " dataset = load_dataset(\"imagefolder\", data_dir=input_dir)\n", " os.makedirs(output_dir, exist_ok=True)\n", "\n", " def _remove_bg_and_save(example):\n", " try:\n", " nobg = remove_background(example[\"image\"])\n", " nobg.filename = os.path.abspath(example[\"image\"].filename).lower().replace(\n", " os.path.abspath(input_dir).lower(),\n", " os.path.abspath(output_dir).lower(),\n", " )\n", " nobg.save(nobg.filename)\n", "\n", " except Exception as e:\n", " print(f\"Error occurred: {e}\")\n", "\n", " dataset.map(_remove_bg_and_save)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating train split: 11112 examples [00:00, 43038.25 examples/s]\n", "Map: 100%|██████████| 11112/11112 [28:06<00:00, 6.59 examples/s]\n", "Generating train split: 7405 examples [00:00, 46350.91 examples/s]\n", "Map: 100%|██████████| 7405/7405 [15:08<00:00, 8.15 examples/s]\n", "Generating train split: 27813 examples [00:00, 41096.54 examples/s]\n", "Map: 57%|█████▋ | 15992/27813 [42:15<26:43, 7.37 examples/s] " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Error occurred: cannot identify image file 'C:\\\\Users\\\\hiroga\\\\Documents\\\\GitHub\\\\til\\\\computer-science\\\\machine-learning\\\\_src\\\\pokemon-palworld\\\\data\\\\raw\\\\pokemon-games\\\\cIi40yfs630\\\\frame_15991.png'\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 27813/27813 [1:13:12<00:00, 6.33 examples/s]\n", "Generating train split: 586 examples [00:00, 45085.15 examples/s]\n", "Map: 100%|██████████| 586/586 [02:00<00:00, 4.86 examples/s]\n", "Generating train split: 3069 examples [00:00, 45460.81 examples/s]\n", "Map: 100%|██████████| 3069/3069 [14:44<00:00, 3.47 examples/s]\n", "Generating train split: 9862 examples [00:00, 45277.86 examples/s]\n", "Map: 100%|██████████| 9862/9862 [24:11<00:00, 6.79 examples/s] \n", "Generating train split: 12420 examples 
 { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "import os\n", "from datasets import load_dataset\n", "\n", "def remove_bg_and_save(input_dir: str, output_dir: str):\n", "    dataset = load_dataset(\"imagefolder\", data_dir=input_dir)\n", "    os.makedirs(output_dir, exist_ok=True)\n", "\n", "    def _remove_bg_and_save(example):\n", "        try:\n", "            nobg = remove_background(example[\"image\"])\n", "            nobg.filename = os.path.abspath(example[\"image\"].filename).lower().replace(\n", "                os.path.abspath(input_dir).lower(),\n", "                os.path.abspath(output_dir).lower(),\n", "            )\n", "            nobg.save(nobg.filename)\n", "\n", "        except Exception as e:\n", "            print(f\"Error occurred: {e}\")\n", "\n", "    dataset.map(_remove_bg_and_save)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating train split: 11112 examples [00:00, 43038.25 examples/s]\n", "Map: 100%|██████████| 11112/11112 [28:06<00:00, 6.59 examples/s]\n", "Generating train split: 7405 examples [00:00, 46350.91 examples/s]\n", "Map: 100%|██████████| 7405/7405 [15:08<00:00, 8.15 examples/s]\n", "Generating train split: 27813 examples [00:00, 41096.54 examples/s]\n", "Map: 57%|█████▋ | 15992/27813 [42:15<26:43, 7.37 examples/s] " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Error occurred: cannot identify image file 'C:\\\\Users\\\\hiroga\\\\Documents\\\\GitHub\\\\til\\\\computer-science\\\\machine-learning\\\\_src\\\\pokemon-palworld\\\\data\\\\raw\\\\pokemon-games\\\\cIi40yfs630\\\\frame_15991.png'\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 27813/27813 [1:13:12<00:00, 6.33 examples/s]\n", "Generating train split: 586 examples [00:00, 45085.15 examples/s]\n", "Map: 100%|██████████| 586/586 [02:00<00:00, 4.86 examples/s]\n", "Generating train split: 3069 examples [00:00, 45460.81 examples/s]\n", "Map: 100%|██████████| 3069/3069 [14:44<00:00, 3.47 examples/s]\n", "Generating train split: 9862 examples [00:00, 45277.86 examples/s]\n", "Map: 100%|██████████| 9862/9862 [24:11<00:00, 6.79 examples/s] \n", "Generating train split: 12420 examples [00:00, 48989.20 examples/s]\n", "Map: 100%|██████████| 12420/12420 [24:59<00:00, 8.28 examples/s]\n" ] } ], "source": [ "remove_bg_and_save(\"../data/raw/pokemon-games/0Loz61U6CuE_cropped\", \"../data/nobg/pokemon-games/0Loz61U6CuE\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/AObd6oPnlyg_cropped\", \"../data/nobg/pokemon-games/AObd6oPnlyg\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/cIi40yfs630\", \"../data/nobg/pokemon-games/cIi40yfs630\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/EEupjm0LwUQ\", \"../data/nobg/pokemon-games/EEupjm0LwUQ\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/G9L0LK07lis\", \"../data/nobg/pokemon-games/G9L0LK07lis\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/LG-LZKUUVZI_cropped\", \"../data/nobg/pokemon-games/LG-LZKUUVZI\")\n", "remove_bg_and_save(\"../data/raw/pokemon-games/Q3-fCEL-JjE_cropped\", \"../data/nobg/pokemon-games/Q3-fCEL-JjE\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Review of the preprocessed images:\n", "- 0Loz61U6CuE: the title logo and the Pokémon selection screen are unwanted; the window at the bottom of the screen gets in the way\n", "- AObd6oPnlyg: the Pokémon selection screen is unwanted\n", "- cIi40yfs630: the screen shown while switching Pokémon is unwanted\n", "- EEupjm0LwUQ: five nearly identical frames per Pokémon may be too many; on top of that, Paldea-native Pokémon probably don't contribute much to telling Palworld apart...\n", "- G9L0LK07lis: nothing in particular\n", "- LG-LZKUUVZI: nothing in particular; frames cut from gameplay rather than the Pokédex have less noise\n", "- Q3-fCEL-JjE: the window at the bottom of the screen gets in the way" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 11112/11112 [14:14<00:00, 13.00 examples/s] \n", "Map: 100%|██████████| 11112/11112 [07:02<00:00, 26.32 examples/s]\n", "Generating train split: 12420 examples [00:01, 9464.56 examples/s]\n", "Map: 100%|██████████| 12420/12420 [19:25<00:00, 10.66 examples/s] \n", "Map: 100%|██████████| 12420/12420 [10:10<00:00, 20.34 examples/s]\n" ] } ], "source": [ "# Crop again; in hindsight this should have been tuned per video during the first crop...\n", "# This time, overwrite the existing images instead of writing to a separate folder.\n", "from datasets import load_dataset\n", "from PIL.Image import Image\n", "from torchvision.transforms import CenterCrop\n", "from torchvision.transforms.functional import crop\n", "from typing import Callable\n", "\n", "def crop_and_save(data_dir: str, cropper: Callable[[Image], Image]):\n", "    dataset = load_dataset(\"imagefolder\", data_dir=data_dir)\n", "    dataset = dataset.map(lambda data: {\"image\": cropper(data[\"image\"]), \"original_filename\": data[\"image\"].filename})\n", "    dataset.map(lambda data: data[\"image\"].save(data[\"original_filename\"]))\n", "\n", "crop_and_save(\"../data/nobg/pokemon-games/0Loz61U6CuE\", CenterCrop((540, 540)))\n", "crop_and_save(\"../data/nobg/pokemon-games/Q3-fCEL-JjE\", lambda image: crop(image, 0, 0, 750, 1080)) # type: ignore" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Detect objects and keep only those above a minimum size." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 11112/11112 [07:13<00:00, 25.63 examples/s]\n", "Generating train split: 7405 examples [00:00, 40654.13 examples/s]\n", "Map: 100%|██████████| 7405/7405 [06:03<00:00, 20.39 examples/s]\n", "Map: 100%|██████████| 27812/27812 [42:15<00:00, 10.97 examples/s] \n", "Generating train split: 586 examples [00:00, 46861.05 examples/s]\n", "Map: 100%|██████████| 586/586 [01:09<00:00, 8.42 examples/s]\n", "Generating train split: 3069 examples [00:00, 51572.00 examples/s]\n", "Map: 100%|██████████| 3069/3069 [09:30<00:00, 5.38 examples/s]\n", "Generating train split: 9862 examples [00:00, 43790.95 examples/s]\n", "Map: 100%|██████████| 9862/9862 [12:20<00:00, 13.32 examples/s]\n", "Generating train split: 12420 examples [00:00, 44092.09 examples/s]\n", "Map: 100%|██████████| 12420/12420 [07:52<00:00, 26.28 examples/s]\n" ] } ], "source": [ "from datasets import load_dataset\n", "from PIL.Image import Image\n", "import cv2\n", "import os\n", "import numpy as np\n", "\n", "def get_object_bounding_boxes(image: Image):\n", "    individual_channels = image.split()\n", "\n", "    alpha_channel: np.ndarray\n", "    if len(individual_channels) == 4:\n", "        alpha_channel = np.array(individual_channels[3])\n", "    else:\n", "        raise ValueError(\"Image does not have an alpha channel.\")\n", "\n", "    # Use cv2.threshold to map pixels whose alpha value is 1 or higher to 255 (white) and everything else to 0 (black).\n", "    # This yields a binary mask with the object in white and the background in black.\n", "    _, binary_mask = cv2.threshold(alpha_channel, 1, 255, cv2.THRESH_BINARY)\n", "\n", "    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n", "\n", "    return contours or []\n", "\n", "def image_to_objects(image: Image, to_dir: str, min_height: int, min_width: int) -> None:\n", "    contours = get_object_bounding_boxes(image)\n", "    filtered_contours = [contour for contour in contours if cv2.contourArea(contour) > min_height * min_width]\n", "    for index, contour in enumerate(filtered_contours):\n", "        x, y, w, h = cv2.boundingRect(contour)\n", "        cropped_image = image.crop((x, y, x + w, y + h))\n", "        filename = os.path.basename(image.filename)\n", "        filename_without_extension, _ = os.path.splitext(filename)\n", "        cropped_image.save(os.path.join(to_dir, f\"{filename_without_extension}_{index:03}.png\"))\n", "\n", "def detect_main_objects_and_save(data_dir: str, to_dir: str, min_height: int, min_width: int):\n", "    dataset = load_dataset(\"imagefolder\", data_dir=data_dir, split=\"train\")\n", "    os.makedirs(to_dir, exist_ok=True)\n", "    dataset.map(lambda example: {\"image\": image_to_objects(example[\"image\"], to_dir, min_height, min_width)}, batched=False)\n", "\n", "\n", "min_height, min_width = 256, 256  # Assumes frames saved from YouTube are 1920x1080; a normally framed character appears to exceed 256 px in height or width.\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/0Loz61U6CuE/\", \"../data/cropped/pokemon-games/0Loz61U6CuE/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/AObd6oPnlyg/\", \"../data/cropped/pokemon-games/AObd6oPnlyg/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/cIi40yfs630/\", \"../data/cropped/pokemon-games/cIi40yfs630/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/EEupjm0LwUQ/\", \"../data/cropped/pokemon-games/EEupjm0LwUQ/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/G9L0LK07lis/\", \"../data/cropped/pokemon-games/G9L0LK07lis/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/LG-LZKUUVZI/\", \"../data/cropped/pokemon-games/LG-LZKUUVZI/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/Q3-fCEL-JjE/\", \"../data/cropped/pokemon-games/Q3-fCEL-JjE/\", min_height, min_width)" ] },
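 { "cell_type": "markdown", "metadata": {}, "source": [ "Before cropping everything, a sketch to eyeball the size threshold: draw the detected boxes on one transparent frame (green = kept, red = dropped). The input path is a placeholder; `get_object_bounding_boxes`, `min_height`, and `min_width` come from the cell above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Visualize the contour bounding boxes on a single frame to sanity-check the\n", "# min_height/min_width threshold; the input path is a placeholder.\n", "import cv2\n", "import PIL.Image\n", "from PIL import ImageDraw\n", "\n", "frame = PIL.Image.open(\"../data/nobg/pokemon-games/G9L0LK07lis/frame_00001.png\")\n", "draw = ImageDraw.Draw(frame)\n", "for contour in get_object_bounding_boxes(frame):\n", "    x, y, w, h = cv2.boundingRect(contour)\n", "    kept = cv2.contourArea(contour) > min_height * min_width\n", "    draw.rectangle([x, y, x + w, y + h], outline=\"lime\" if kept else \"red\", width=3)\n", "frame" ] },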
examples/s]\n", "Map: 100%|██████████| 9862/9862 [12:20<00:00, 13.32 examples/s]\n", "Generating train split: 12420 examples [00:00, 44092.09 examples/s]\n", "Map: 100%|██████████| 12420/12420 [07:52<00:00, 26.28 examples/s]\n" ] } ], "source": [ "from datasets import load_dataset\n", "from PIL.Image import Image\n", "import cv2\n", "import os\n", "import numpy as np\n", "\n", "def get_object_bounding_boxes(image: Image):\n", " individual_channels = image.split()\n", "\n", " alpha_channel: np.array\n", " if len(individual_channels) == 4:\n", " alpha_channel = np.array(individual_channels[3])\n", " else:\n", " raise ValueError(\"Image does not have an alpha channel.\")\n", "\n", " # cv2.threshold関数を使用して、アルファチャンネルの値が1以上のピクセルを255(白)に、それ以外を0(黒)に変換します。\n", " # これにより、画像のオブジェクト部分を白、背景部分を黒としたバイナリマスクが作成されます。\n", " _, binary_mask = cv2.threshold(alpha_channel, 1, 255, cv2.THRESH_BINARY)\n", "\n", " contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n", "\n", " return contours or []\n", "\n", "def image_to_objects(image: Image, to_dir: str, min_height: int, min_width: int) -> list[Image]:\n", " contours = get_object_bounding_boxes(image)\n", " filtered_contours = [contour for contour in contours if cv2.contourArea(contour) > min_height * min_width]\n", " for index, contour in enumerate(filtered_contours):\n", " x, y, w, h = cv2.boundingRect(contour)\n", " cropped_image = image.crop((x, y, x + w, y + h))\n", " filename = os.path.basename(image.filename)\n", " filename_without_extension, _ = os.path.splitext(filename)\n", " cropped_image.save(os.path.join(to_dir, f\"{filename_without_extension}_{index:03}.png\"))\n", "\n", "def detect_main_objects_and_save(data_dir: str, to_dir: str, min_height: int, min_width: int):\n", " dataset = load_dataset(\"imagefolder\", data_dir=data_dir, split=\"train\")\n", " os.makedirs(to_dir, exist_ok=True)\n", " dataset.map(lambda example: {\"image\": image_to_objects(example[\"image\"], to_dir, min_height, min_width)}, batched=False)\n", "\n", "\n", "min_height, min_width = 256, 256 # YouTubeから保存した画像が1920x1080という前提。キャラクターが普通に写っている場合は高さか幅が256pxを超えているように見える。\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/0Loz61U6CuE/\", \"../data/cropped/pokemon-games/0Loz61U6CuE/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/AObd6oPnlyg/\", \"../data/cropped/pokemon-games/AObd6oPnlyg/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/cIi40yfs630/\", \"../data/cropped/pokemon-games/cIi40yfs630/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/EEupjm0LwUQ/\", \"../data/cropped/pokemon-games/EEupjm0LwUQ/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/G9L0LK07lis/\", \"../data/cropped/pokemon-games/G9L0LK07lis/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/LG-LZKUUVZI/\", \"../data/cropped/pokemon-games/LG-LZKUUVZI/\", min_height, min_width)\n", "detect_main_objects_and_save(\"../data/nobg/pokemon-games/Q3-fCEL-JjE/\", \"../data/cropped/pokemon-games/Q3-fCEL-JjE/\", min_height, min_width)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Filter images" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=3, bias=True)\n", ")" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from safetensors import safe_open\n", "from torchvision import models\n", "import torch\n", "\n", "labels = [\"etc\", \"pal\", \"pokemon\"]\n", "\n", "model = models.resnet18()\n", "model.fc = torch.nn.Linear(model.fc.in_features, len(labels))\n", "\n", "model_save_path = \"../models/snapshots/filter.safetensors\"\n", "tensors = {}\n", "with safe_open(model_save_path, framework=\"pt\", device=\"cpu\") as f:\n", " for key in f.keys():\n", " tensors[key] = f.get_tensor(key)\n", "\n", "model.load_state_dict(tensors, strict=False)\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms\n", "from PIL.Image import Image\n", "\n", "preprocess = transforms.Compose([\n", " transforms.Resize(224),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", "])\n", "\n", "def classify_image(input_image: Image):\n", " img_t = preprocess(input_image)\n", " batch_t = torch.unsqueeze(img_t, 0)\n", " \n", " with torch.no_grad():\n", " output = model(batch_t)\n", " _, max_index = torch.max(output, dim=1)\n", " return max_index.item()\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 362/362 [00:29<00:00, 12.15 examples/s]\n", "Generating train split: 1481 examples [00:00, 43102.23 examples/s]\n", "Map: 100%|██████████| 1481/1481 [04:01<00:00, 6.13 examples/s]\n", "Generating train split: 512 examples [00:00, 42667.22 examples/s]\n", "Map: 100%|██████████| 512/512 [00:40<00:00, 12.60 examples/s]\n", "Generating train split: 434 examples [00:00, 39452.27 examples/s]\n", "Map: 100%|██████████| 434/434 [00:37<00:00, 11.62 examples/s]\n", "Generating train split: 718 examples [00:00, 39889.14 examples/s]\n", "Map: 100%|██████████| 718/718 [00:48<00:00, 14.83 examples/s]\n", "Generating train split: 774 examples [00:00, 42979.21 examples/s]\n", "Map: 100%|██████████| 774/774 [00:54<00:00, 14.25 examples/s]\n", "Generating train split: 958 examples [00:00, 39906.87 examples/s]\n", "Map: 100%|██████████| 958/958 [02:40<00:00, 5.95 examples/s]\n", "Generating train split: 534 examples [00:00, 42679.14 examples/s]\n", "Map: 100%|██████████| 534/534 [01:42<00:00, 5.20 examples/s]\n", "Generating train split: 1938 examples [00:00, 43060.21 
examples/s]\n", "Map: 100%|██████████| 1938/1938 [02:18<00:00, 13.95 examples/s]\n", "Generating train split: 1377 examples [00:00, 44413.01 examples/s]\n", "Map: 100%|██████████| 1377/1377 [01:39<00:00, 13.87 examples/s]\n", "Generating train split: 3720 examples [00:00, 49043.23 examples/s]\n", "Map: 100%|██████████| 3720/3720 [08:16<00:00, 7.50 examples/s]\n", "Generating train split: 2503 examples [00:00, 45440.29 examples/s]\n", "Map: 100%|██████████| 2503/2503 [02:45<00:00, 15.16 examples/s]\n", "Generating train split: 568 examples [00:00, 40568.15 examples/s]\n", "Map: 100%|██████████| 568/568 [02:57<00:00, 3.20 examples/s]\n", "Generating train split: 3045 examples [00:00, 46843.36 examples/s]\n", "Map: 100%|██████████| 3045/3045 [03:42<00:00, 13.68 examples/s]\n", "Generating train split: 351 examples [00:00, 38999.73 examples/s]\n", "Map: 100%|██████████| 351/351 [00:30<00:00, 11.33 examples/s]\n", "Generating train split: 3923 examples [00:00, 47214.91 examples/s]\n", "Map: 100%|██████████| 3923/3923 [04:49<00:00, 13.57 examples/s]\n", "Generating train split: 994 examples [00:00, 42885.75 examples/s]\n", "Map: 100%|██████████| 994/994 [01:15<00:00, 13.12 examples/s]\n", "Generating train split: 6981 examples [00:00, 47512.35 examples/s]\n", "Map: 100%|██████████| 6981/6981 [12:53<00:00, 9.03 examples/s]\n", "Generating train split: 3875 examples [00:00, 45851.61 examples/s]\n", "Map: 100%|██████████| 3875/3875 [05:06<00:00, 12.63 examples/s]\n", "Map: 100%|██████████| 10359/10359 [08:12<00:00, 21.05 examples/s]\n", "Generating train split: 5165 examples [00:00, 47165.39 examples/s]\n", "Map: 100%|██████████| 5165/5165 [04:44<00:00, 18.17 examples/s]\n", "Generating train split: 30029 examples [00:00, 40537.67 examples/s]\n", "Map: 100%|██████████| 30029/30029 [30:26<00:00, 16.44 examples/s] \n", "Generating train split: 7485 examples [00:00, 47421.65 examples/s]\n", "Map: 100%|██████████| 7485/7485 [07:50<00:00, 15.92 examples/s]\n" ] } ], "source": [ "from datasets import load_dataset\n", "from PIL.Image import Image\n", "import os\n", "\n", "def save_if_pokemon_or_pal(example: dict[str, any], to_dir: str):\n", " etc = 0\n", " image = example[\"image\"]\n", " # convert('RGB') って不要だった気が...なぜ急に必要に?\n", " classified_label = classify_image(image.convert('RGB'))\n", " if classified_label != etc:\n", " filename = os.path.basename(image.filename)\n", " image.save(os.path.join(to_dir, filename))\n", " \n", "\n", "def filter_images_and_save(data_dir: str, to_dir: str):\n", " dataset = load_dataset(\"imagefolder\", data_dir=data_dir, split=\"train\")\n", " os.makedirs(to_dir, exist_ok=True)\n", " dataset.map(lambda example: {\"image\": save_if_pokemon_or_pal(example, to_dir)}, batched=False)\n", "\n", "filter_images_and_save(\"../data/cropped/palworld-fan/1JN5-jr5D_k/\", \"../data/filtered/palworld-fan_1JN5-jr5D_k/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/3qpPt0YLp0g/\", \"../data/filtered/palworld-fan_3qpPt0YLp0g/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/5FdIrKB1SUI/\", \"../data/filtered/palworld-fan_5FdIrKB1SUI/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/7dmUoLu14qs/\", \"../data/filtered/palworld-fan_7dmUoLu14qs/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/AAtemrMzo3s/\", \"../data/filtered/palworld-fan_AAtemrMzo3s/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/bckQkt8aUlo/\", \"../data/filtered/palworld-fan_bckQkt8aUlo/\")\n", 
"filter_images_and_save(\"../data/cropped/palworld-fan/GB7rGn3IDpI/\", \"../data/filtered/palworld-fan_GB7rGn3IDpI/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/gM-jmf28GEY/\", \"../data/filtered/palworld-fan_gM-jmf28GEY/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/HBJwXcKymOk/\", \"../data/filtered/palworld-fan_HBJwXcKymOk/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/iiGcw_gq53c/\", \"../data/filtered/palworld-fan_iiGcw_gq53c/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/rNmZXw4zCys/\", \"../data/filtered/palworld-fan_rNmZXw4zCys/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/S8-_o6CEI8M/\", \"../data/filtered/palworld-fan_S8-_o6CEI8M/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/SNCOkUE3A0A/\", \"../data/filtered/palworld-fan_SNCOkUE3A0A/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/utAT6L3Ea00/\", \"../data/filtered/palworld-fan_utAT6L3Ea00/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/v878zGYOGq8/\", \"../data/filtered/palworld-fan_v878zGYOGq8/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/XETrVLff13M/\", \"../data/filtered/palworld-fan_XETrVLff13M/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/YSpO6l5TglA/\", \"../data/filtered/palworld-fan_YSpO6l5TglA/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/zms3ORqAXiQ/\", \"../data/filtered/palworld-fan_zms3ORqAXiQ/\")\n", "filter_images_and_save(\"../data/cropped/palworld-fan/Zyqgp460xRo/\", \"../data/filtered/palworld-fan_Zyqgp460xRo/\")\n", "\n", "filter_images_and_save(\"../data/cropped/pokemon-games/0Loz61U6CuE/\", \"../data/filtered/pokemon-games_0Loz61U6CuE/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/AObd6oPnlyg/\", \"../data/filtered/pokemon-games_AObd6oPnlyg/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/cIi40yfs630/\", \"../data/filtered/pokemon-games_cIi40yfs630/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/EEupjm0LwUQ/\", \"../data/filtered/pokemon-games_EEupjm0LwUQ/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/G9L0LK07lis/\", \"../data/filtered/pokemon-games_G9L0LK07lis/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/LG-LZKUUVZI/\", \"../data/filtered/pokemon-games_LG-LZKUUVZI/\")\n", "filter_images_and_save(\"../data/cropped/pokemon-games/Q3-fCEL-JjE/\", \"../data/filtered/pokemon-games_Q3-fCEL-JjE/\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "til-machine-learning", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 2 }