{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7d2e9a0",
   "metadata": {},
   "source": [
    "# Convert YOLOv9 checkpoints to GELAN models\n",
    "\n",
    "For each variant this notebook builds the GELAN model from its YAML config, copies the matching tensors out of the YOLOv9 checkpoint (shifting layer indices and renaming the detection-head branches), and saves the result as a half-precision checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4beac401",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from models.yolo import Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3c51f7a",
   "metadata": {},
   "source": [
    "## Conversion helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9c0d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detection-head branch names: converted model -> source checkpoint.\n",
    "HEAD_RENAMES = {\"cv2\": \"cv4\", \"cv3\": \"cv5\", \"dfl\": \"dfl2\"}\n",
    "\n",
    "\n",
    "def source_key(key, layer_offsets, head_offset):\n",
    "    \"\"\"Return the checkpoint key that holds the weights for ``key``.\n",
    "\n",
    "    Args:\n",
    "        key: state-dict key of the converted model, shaped ``model.<idx>.<rest>``.\n",
    "        layer_offsets: ``(upper_bound, offset)`` pairs tried in order; the first\n",
    "            pair with ``idx < upper_bound`` maps layer ``idx`` to ``idx + offset``.\n",
    "        head_offset: index shift for layers at or beyond every ``upper_bound``;\n",
    "            for those only the ``HEAD_RENAMES`` branches are mapped.\n",
    "\n",
    "    Returns:\n",
    "        The source key, or ``None`` when no rule covers ``key``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``key`` is not of the form ``model.<idx>.<rest>``.\n",
    "    \"\"\"\n",
    "    parts = key.split(\".\", 2)\n",
    "    if len(parts) != 3 or parts[0] != \"model\" or not parts[1].isdecimal():\n",
    "        raise ValueError(f\"unexpected state-dict key: {key!r}\")\n",
    "    idx, rest = int(parts[1]), parts[2]\n",
    "    for upper_bound, offset in layer_offsets:\n",
    "        if idx < upper_bound:\n",
    "            return f\"model.{idx + offset}.{rest}\"\n",
    "    branch, dot, tail = rest.partition(\".\")\n",
    "    if dot and branch in HEAD_RENAMES:\n",
    "        return f\"model.{idx + head_offset}.{HEAD_RENAMES[branch]}.{tail}\"\n",
    "    return None\n",
    "\n",
    "\n",
    "def copy_weights(model, source_model, layer_offsets, head_offset, verbose=False):\n",
    "    \"\"\"Copy every mappable tensor from ``source_model`` into ``model`` in place.\n",
    "\n",
    "    Args:\n",
    "        model: model receiving the weights.\n",
    "        source_model: model taken from the loaded checkpoint.\n",
    "        layer_offsets, head_offset: mapping rules, see ``source_key``.\n",
    "        verbose: print one line per copied key.\n",
    "\n",
    "    Returns:\n",
    "        Keys of ``model`` that no rule covers; those tensors are left untouched.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a mapped key is missing from the source state dict.\n",
    "        ValueError: if source and target tensors differ in shape.\n",
    "    \"\"\"\n",
    "    # Build each state dict once instead of once per access.\n",
    "    target_state = model.state_dict()\n",
    "    source_state = source_model.state_dict()\n",
    "    unmapped = []\n",
    "    with torch.no_grad():\n",
    "        for key, target in target_state.items():\n",
    "            src_key = source_key(key, layer_offsets, head_offset)\n",
    "            if src_key is None:\n",
    "                unmapped.append(key)\n",
    "                continue\n",
    "            source = source_state[src_key]\n",
    "            if source.shape != target.shape:\n",
    "                raise ValueError(\n",
    "                    f\"shape mismatch: {key} {tuple(target.shape)} vs \"\n",
    "                    f\"{src_key} {tuple(source.shape)}\"\n",
    "                )\n",
    "            # copy_ overwrites outright and converts dtype; zeroing with\n",
    "            # \"x -= x\" and then adding would propagate NaN/inf already in x.\n",
    "            target.copy_(source)\n",
    "            if verbose:\n",
    "                print(key, \"perfectly matched!!\")\n",
    "    return unmapped\n",
    "\n",
    "\n",
    "def build_model(cfg, ckpt_path, device=\"cpu\"):\n",
    "    \"\"\"Load the checkpoint at ``ckpt_path`` and build the model from ``cfg`` to receive it.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, ckpt)``: the new model in eval mode on ``device``, carrying the\n",
    "        checkpoint's ``names`` and ``nc``, and the loaded checkpoint dict.\n",
    "    \"\"\"\n",
    "    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted files.\n",
    "    # PyTorch >= 2.6 defaults to weights_only=True, which rejects full-model\n",
    "    # checkpoints like this one; pass weights_only=False there.\n",
    "    ckpt = torch.load(ckpt_path, map_location='cpu')\n",
    "    source = ckpt['model']\n",
    "    # Use the checkpoint's class count so the head shapes match its tensors.\n",
    "    model = Model(cfg, ch=3, nc=source.nc, anchors=3).to(device)\n",
    "    model.eval()\n",
    "    model.names = source.names\n",
    "    model.nc = source.nc\n",
    "    return model, ckpt\n",
    "\n",
    "\n",
    "def save_converted(model, out_path):\n",
    "    \"\"\"Cast ``model`` to half precision (in place) and save it to ``out_path``.\"\"\"\n",
    "    m_ckpt = {'model': model.half(),\n",
    "              'optimizer': None,\n",
    "              'best_fitness': None,\n",
    "              'ema': None,\n",
    "              'updates': None,\n",
    "              'opt': None,\n",
    "              'git': None,\n",
    "              'date': None,\n",
    "              'epoch': -1}\n",
    "    torch.save(m_ckpt, out_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8680f822",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f0198d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, ckpt = build_model(\"./models/detect/gelan-c.yaml\", \"./yolov9-c.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de7e1be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers below 22 sit one index later in the checkpoint; from 22 up, the\n",
    "# cv2/cv3/dfl branches come from cv4/cv5/dfl2 sixteen indices later.\n",
    "unmapped = copy_weights(model, ckpt['model'], layer_offsets=[(22, 1)], head_offset=16)\n",
    "model.eval()\n",
    "print(f\"copied {len(model.state_dict()) - len(unmapped)} tensors; unmapped keys: {unmapped}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "960796e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_converted(model, \"./yolov9-c-converted.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c6e6ae",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "801a1b7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, ckpt = build_model(\"./models/detect/gelan-e.yaml\", \"./yolov9-e.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ef4fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers below 29 keep their index, layers 29-41 sit seven indices later; from\n",
    "# 42 up, the cv2/cv3/dfl branches come from cv4/cv5/dfl2 seven indices later.\n",
    "unmapped = copy_weights(model, ckpt['model'], layer_offsets=[(29, 0), (42, 7)], head_offset=7)\n",
    "model.eval()\n",
    "print(f\"copied {len(model.state_dict()) - len(unmapped)} tensors; unmapped keys: {unmapped}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27bc1869",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_converted(model, \"./yolov9-e-converted.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}