codeShare committed on
Commit
ecd5d22
1 Parent(s): b8bbedf

Upload pt_to_safetensors_converter.ipynb

Browse files
Text/pt_to_safetensors_converter.ipynb ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ },
15
+ "widgets": {
16
+ "application/vnd.jupyter.widget-state+json": {
17
+ "a44dd6024769456a8262a17b0ce6a2ed": {
18
+ "model_module": "@jupyter-widgets/controls",
19
+ "model_name": "ButtonModel",
20
+ "model_module_version": "1.5.0",
21
+ "state": {
22
+ "_dom_classes": [],
23
+ "_model_module": "@jupyter-widgets/controls",
24
+ "_model_module_version": "1.5.0",
25
+ "_model_name": "ButtonModel",
26
+ "_view_count": null,
27
+ "_view_module": "@jupyter-widgets/controls",
28
+ "_view_module_version": "1.5.0",
29
+ "_view_name": "ButtonView",
30
+ "button_style": "success",
31
+ "description": "✔ Done",
32
+ "disabled": true,
33
+ "icon": "",
34
+ "layout": "IPY_MODEL_49441085d85a4f219a6ccbf2a197f527",
35
+ "style": "IPY_MODEL_f084b7dfcae445a58d36a9c21971793c",
36
+ "tooltip": ""
37
+ }
38
+ },
39
+ "49441085d85a4f219a6ccbf2a197f527": {
40
+ "model_module": "@jupyter-widgets/base",
41
+ "model_name": "LayoutModel",
42
+ "model_module_version": "1.2.0",
43
+ "state": {
44
+ "_model_module": "@jupyter-widgets/base",
45
+ "_model_module_version": "1.2.0",
46
+ "_model_name": "LayoutModel",
47
+ "_view_count": null,
48
+ "_view_module": "@jupyter-widgets/base",
49
+ "_view_module_version": "1.2.0",
50
+ "_view_name": "LayoutView",
51
+ "align_content": null,
52
+ "align_items": null,
53
+ "align_self": null,
54
+ "border": null,
55
+ "bottom": null,
56
+ "display": null,
57
+ "flex": null,
58
+ "flex_flow": null,
59
+ "grid_area": null,
60
+ "grid_auto_columns": null,
61
+ "grid_auto_flow": null,
62
+ "grid_auto_rows": null,
63
+ "grid_column": null,
64
+ "grid_gap": null,
65
+ "grid_row": null,
66
+ "grid_template_areas": null,
67
+ "grid_template_columns": null,
68
+ "grid_template_rows": null,
69
+ "height": null,
70
+ "justify_content": null,
71
+ "justify_items": null,
72
+ "left": null,
73
+ "margin": null,
74
+ "max_height": null,
75
+ "max_width": null,
76
+ "min_height": null,
77
+ "min_width": "50px",
78
+ "object_fit": null,
79
+ "object_position": null,
80
+ "order": null,
81
+ "overflow": null,
82
+ "overflow_x": null,
83
+ "overflow_y": null,
84
+ "padding": null,
85
+ "right": null,
86
+ "top": null,
87
+ "visibility": null,
88
+ "width": null
89
+ }
90
+ },
91
+ "f084b7dfcae445a58d36a9c21971793c": {
92
+ "model_module": "@jupyter-widgets/controls",
93
+ "model_name": "ButtonStyleModel",
94
+ "model_module_version": "1.5.0",
95
+ "state": {
96
+ "_model_module": "@jupyter-widgets/controls",
97
+ "_model_module_version": "1.5.0",
98
+ "_model_name": "ButtonStyleModel",
99
+ "_view_count": null,
100
+ "_view_module": "@jupyter-widgets/base",
101
+ "_view_module_version": "1.2.0",
102
+ "_view_name": "StyleView",
103
+ "button_color": null,
104
+ "font_weight": ""
105
+ }
106
+ }
107
+ }
108
+ }
109
+ },
110
+ "cells": [
111
+ {
112
+ "cell_type": "code",
113
+ "source": [
114
+ "#@title Mount Google Drive\n",
115
+ "from google.colab import drive\n",
116
+ "from IPython.display import clear_output\n",
117
+ "from IPython.display import display\n",
118
+ "import ipywidgets as widgets\n",
119
+ "import os\n",
120
+ "\n",
121
+ "def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n",
122
+ "Shared_Drive = \"\" #@param {type:\"string\"}\n",
123
+ "#@markdown - If you're not using a shared drive, leave this empty\n",
124
+ "\n",
125
+ "print(\"\u001b[0;33mConnecting...\")\n",
126
+ "drive.mount('/content/gdrive')\n",
127
+ "\n",
128
+ "if Shared_Drive!=\"\" and os.path.exists(\"/content/gdrive/Shareddrives\"):\n",
129
+ " mainpth=\"Shareddrives/\"+Shared_Drive\n",
130
+ "else:\n",
131
+ " mainpth=\"MyDrive\"\n",
132
+ "\n",
133
+ "clear_output()\n",
134
+ "inf('\\u2714 Done','success', '50px')"
135
+ ],
136
+ "metadata": {
137
+ "id": "fCR2boKCTx0z",
138
+ "cellView": "form",
139
+ "outputId": "baf6303f-9850-4dd2-a6d3-86871ac8aef5",
140
+ "colab": {
141
+ "base_uri": "https://localhost:8080/",
142
+ "height": 49,
143
+ "referenced_widgets": [
144
+ "a44dd6024769456a8262a17b0ce6a2ed",
145
+ "49441085d85a4f219a6ccbf2a197f527",
146
+ "f084b7dfcae445a58d36a9c21971793c"
147
+ ]
148
+ }
149
+ },
150
+ "execution_count": null,
151
+ "outputs": [
152
+ {
153
+ "output_type": "display_data",
154
+ "data": {
155
+ "text/plain": [
156
+ "Button(button_style='success', description='✔ Done', disabled=True, layout=Layout(min_width='50px'), style=But…"
157
+ ],
158
+ "application/vnd.jupyter.widget-view+json": {
159
+ "version_major": 2,
160
+ "version_minor": 0,
161
+ "model_id": "a44dd6024769456a8262a17b0ce6a2ed"
162
+ }
163
+ },
164
+ "metadata": {}
165
+ }
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "source": [
171
+ "#@title Install Required Dependencies\n",
172
+ "!pip install torch\n",
173
+ "!pip install safetensors\n",
174
+ "!pip install pytorch-lightning"
175
+ ],
176
+ "metadata": {
177
+ "id": "5S88gkUJzeqG"
178
+ },
179
+ "execution_count": null,
180
+ "outputs": []
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "source": [
185
+ "def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n",
186
+ "file_path = \"\" #@param {type:\"string\"}\n",
187
+ "#@markdown - Paste the path to the embedding or VAE file you want to convert, or to a directory containing several such files (only `.pt` files directly inside that directory are converted; subdirectories are not searched)\n",
188
+ "#@markdown - For example: /content/gdrive/MyDrive/myembedding.pt or /content/gdrive/MyDrive/my_directory\n",
189
+ "#@markdown - Input files must be pickle checkpoints with a `.pt` extension; files with any other extension are not converted\n",
190
+ "verbose=True"
191
+ ],
192
+ "metadata": {
193
+ "id": "7aLFC6c4O5EW"
194
+ },
195
+ "execution_count": null,
196
+ "outputs": []
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "source": [
201
+ "#@title Define Converter Functions\n",
202
+ "import os\n",
203
+ "from typing import Any, Dict\n",
204
+ "\n",
205
+ "import torch\n",
206
+ "from safetensors.torch import save_file\n",
207
+ "\n",
208
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
209
+ "\n",
210
+ "def process_pt_files(path: str, model_type: str, verbose=True) -> None:\n",
211
+ " if os.path.isdir(path):\n",
212
+ " # Path is a directory, process all .pt files in the directory\n",
213
+ " for file_name in os.listdir(path):\n",
214
+ " if file_name.endswith('.pt'):\n",
215
+ " process_file(os.path.join(path, file_name), model_type, verbose)\n",
216
+ " elif os.path.isfile(path) and path.endswith('.pt'):\n",
217
+ " # Path is a .pt file, process this file\n",
218
+ " process_file(path, model_type, verbose)\n",
219
+ " else:\n",
220
+ " print(f\"{path} is not a valid directory or .pt file.\")\n",
221
+ "\n",
222
+ "def process_file(file_path: str, model_type: str, verbose: bool) -> None:\n",
223
+ " # Load the PyTorch model\n",
224
+ " model = torch.load(file_path, map_location=device)\n",
225
+ "\n",
226
+ " if verbose:\n",
227
+ " print(file_path)\n",
228
+ "\n",
229
+ " if model_type == 'embedding':\n",
230
+ " s_model = process_embedding_file(model, verbose)\n",
231
+ " elif model_type == 'vae':\n",
232
+ " s_model = process_vae_file(model, verbose)\n",
233
+ " else:\n",
234
+ " raise Exception(f\"model_type `{model_type}` is not supported!\")\n",
235
+ "\n",
236
+ " # Save the model with the new extension\n",
237
+ " if file_path.endswith('.pt'):\n",
238
+ " new_file_path = file_path[:-3] + '.safetensors'\n",
239
+ " else:\n",
240
+ " new_file_path = file_path + '.safetensors'\n",
241
+ " save_file(s_model, new_file_path)\n",
242
+ "\n",
243
+ "def process_embedding_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
244
+ " # Extract the embedding tensors\n",
245
+ " model_tensors = model.get('string_to_param').get('*')\n",
246
+ " s_model = {\n",
247
+ " 'emb_params': model_tensors\n",
248
+ " }\n",
249
+ "\n",
250
+ " if verbose:\n",
251
+ " # Print the requested training information, if it exists\n",
252
+ " if ('sd_checkpoint_name' in model) and (model['sd_checkpoint_name'] is not None):\n",
253
+ " print(f\"Trained on {model['sd_checkpoint_name']}.\")\n",
254
+ " else:\n",
255
+ " print(\"Checkpoint name not found in the model.\")\n",
256
+ "\n",
257
+ " if ('step' in model) and (model['step'] is not None):\n",
258
+ " print(f\"Trained for {model['step']} steps.\")\n",
259
+ " else:\n",
260
+ " print(\"Step not found in the model.\")\n",
261
+ " # Display the tensor's shape\n",
262
+ " print(f\"Dimensions of embedding tensor: {model_tensors.shape}\")\n",
263
+ " print()\n",
264
+ "\n",
265
+ " return s_model\n",
266
+ "\n",
267
+ "def process_vae_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
268
+ " # Extract the state dictionary\n",
269
+ " s_model = model[\"state_dict\"]\n",
270
+ " if verbose:\n",
271
+ " # Print the requested training information, if it exists\n",
272
+ " step = model.get('step', model.get('global_step'))\n",
273
+ " if step is not None:\n",
274
+ " print(f\"Trained for {step} steps.\")\n",
275
+ " else:\n",
276
+ " print(\"Step not found in the model.\")\n",
277
+ " print()\n",
278
+ " return s_model"
279
+ ],
280
+ "metadata": {
281
+ "id": "UwH1lXmGw9XP"
282
+ },
283
+ "execution_count": null,
284
+ "outputs": []
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "source": [
289
+ "## Convert the file(s)\n",
290
+ "\n",
291
+ "Run whichever of the two following code blocks corresponds to the type of file you are converting.\n",
292
+ "\n",
293
+ "Each converted safetensors file will be saved in the same directory as the original, with the `.pt` extension replaced by `.safetensors`."
294
+ ],
295
+ "metadata": {
296
+ "id": "LqEl4sM0sMPG"
297
+ }
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "source": [
302
+ "#@title Convert the Embedding(s)\n",
303
+ "process_pt_files(file_path, 'embedding', verbose=verbose)"
304
+ ],
305
+ "metadata": {
306
+ "id": "4LEWGfjiUeG1",
307
+ "cellView": "form"
308
+ },
309
+ "execution_count": null,
310
+ "outputs": []
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "source": [
315
+ "#@title Convert the VAE(s)\n",
316
+ "process_pt_files(file_path, 'vae', verbose=verbose)"
317
+ ],
318
+ "metadata": {
319
+ "id": "Jil7A1ckyiHA",
320
+ "cellView": "form"
321
+ },
322
+ "execution_count": null,
323
+ "outputs": []
324
+ }
325
+ ]
326
+ }