File size: 2,379 Bytes
196ee88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Load the checkpoint. map_location=\"cpu\" keeps every tensor on the CPU, so the\n",
"# conversion also runs on machines without the device the file was saved from.\n",
"# NOTE(review): torch.load unpickles the file - only open checkpoints you trust.\n",
"input_state_dict = torch.load(\"vodka_v5_4_768-ep60-gs146640.ckpt\", map_location=\"cpu\")\n",
"\n",
"# The weights either sit under a \"state_dict\" key or the loaded object is the\n",
"# state dict itself; remember which, so the output can be saved in the same layout.\n",
"root_is_state_dict = \"state_dict\" in input_state_dict\n",
"state_dict = input_state_dict[\"state_dict\"] if root_is_state_dict else input_state_dict\n",
"\n",
"# Substring renames applied to state-dict keys: the first item of each pair is\n",
"# the fragment to look for, the second item is what it is replaced with.\n",
"mappings = dict([\n",
"    (\"attn_1.to_out.0\", \"attn_1.proj_out\"),\n",
"    (\"attn_1.to_k\", \"attn_1.k\"),\n",
"    (\"attn_1.to_q\", \"attn_1.q\"),\n",
"    (\"attn_1.to_v\", \"attn_1.v\"),\n",
"])\n",
"\n",
"def replace_all(key):\n",
"    \"\"\"Return ``key`` with every fragment listed in ``mappings`` swapped for its replacement.\n",
"\n",
"    The fragments are applied one after another in the dict's iteration order,\n",
"    each on the result of the previous replacement.\n",
"    \"\"\"\n",
"    renamed = key\n",
"    for old_fragment, new_fragment in mappings.items():\n",
"        renamed = renamed.replace(old_fragment, new_fragment)\n",
"    return renamed\n",
"\n",
"# Build a fresh dict whose keys have gone through replace_all; the values are\n",
"# carried over untouched\n",
"renamed_state_dict = {\n",
"    replace_all(name): tensor for name, tensor in state_dict.items()\n",
"}\n",
"\n",
"# Then, reshape the tensors in the renamed state_dict\n",
"def reshape_tensors(sd):\n",
"    \"\"\"Append two size-1 dimensions to every 512x512 ``attn_1`` entry of ``sd``.\n",
"\n",
"    Only entries whose key contains ``attn_1`` and whose shape is exactly\n",
"    (512, 512) are touched; they become (512, 512, 1, 1). ``sd`` is modified\n",
"    in place and the same dict object is returned.\n",
"    \"\"\"\n",
"    target_shape = torch.Size([512, 512])\n",
"    matching_keys = [\n",
"        name for name, tensor in sd.items()\n",
"        if \"attn_1\" in name and tensor.shape == target_shape\n",
"    ]\n",
"    for name in matching_keys:\n",
"        # On a 2-D tensor, two trailing unsqueezes add dimensions 2 and 3\n",
"        sd[name] = sd[name].unsqueeze(-1).unsqueeze(-1)\n",
"    return sd\n",
"\n",
"# Reshape the renamed attention weights (reshape_tensors updates the dict in\n",
"# place and hands the same object back)\n",
"output_state_dict = reshape_tensors(renamed_state_dict)\n",
"\n",
"# Save in the same layout as the input: put the weights back under\n",
"# \"state_dict\" when the loaded checkpoint had that key, otherwise save them bare\n",
"if root_is_state_dict:\n",
"    input_state_dict[\"state_dict\"] = output_state_dict\n",
"    checkpoint_to_save = input_state_dict\n",
"else:\n",
"    checkpoint_to_save = output_state_dict\n",
"torch.save(checkpoint_to_save, \"6_vodka_v5_768_adamw8bit_ep60.ckpt\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "fastai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|