codeShare commited on
Commit
e5232dc
1 Parent(s): ecd5d22

Delete Text

Browse files
Text/pt_to_safetensors_converter.ipynb DELETED
@@ -1,326 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": []
7
- },
8
- "kernelspec": {
9
- "name": "python3",
10
- "display_name": "Python 3"
11
- },
12
- "language_info": {
13
- "name": "python"
14
- },
15
- "widgets": {
16
- "application/vnd.jupyter.widget-state+json": {
17
- "a44dd6024769456a8262a17b0ce6a2ed": {
18
- "model_module": "@jupyter-widgets/controls",
19
- "model_name": "ButtonModel",
20
- "model_module_version": "1.5.0",
21
- "state": {
22
- "_dom_classes": [],
23
- "_model_module": "@jupyter-widgets/controls",
24
- "_model_module_version": "1.5.0",
25
- "_model_name": "ButtonModel",
26
- "_view_count": null,
27
- "_view_module": "@jupyter-widgets/controls",
28
- "_view_module_version": "1.5.0",
29
- "_view_name": "ButtonView",
30
- "button_style": "success",
31
- "description": "✔ Done",
32
- "disabled": true,
33
- "icon": "",
34
- "layout": "IPY_MODEL_49441085d85a4f219a6ccbf2a197f527",
35
- "style": "IPY_MODEL_f084b7dfcae445a58d36a9c21971793c",
36
- "tooltip": ""
37
- }
38
- },
39
- "49441085d85a4f219a6ccbf2a197f527": {
40
- "model_module": "@jupyter-widgets/base",
41
- "model_name": "LayoutModel",
42
- "model_module_version": "1.2.0",
43
- "state": {
44
- "_model_module": "@jupyter-widgets/base",
45
- "_model_module_version": "1.2.0",
46
- "_model_name": "LayoutModel",
47
- "_view_count": null,
48
- "_view_module": "@jupyter-widgets/base",
49
- "_view_module_version": "1.2.0",
50
- "_view_name": "LayoutView",
51
- "align_content": null,
52
- "align_items": null,
53
- "align_self": null,
54
- "border": null,
55
- "bottom": null,
56
- "display": null,
57
- "flex": null,
58
- "flex_flow": null,
59
- "grid_area": null,
60
- "grid_auto_columns": null,
61
- "grid_auto_flow": null,
62
- "grid_auto_rows": null,
63
- "grid_column": null,
64
- "grid_gap": null,
65
- "grid_row": null,
66
- "grid_template_areas": null,
67
- "grid_template_columns": null,
68
- "grid_template_rows": null,
69
- "height": null,
70
- "justify_content": null,
71
- "justify_items": null,
72
- "left": null,
73
- "margin": null,
74
- "max_height": null,
75
- "max_width": null,
76
- "min_height": null,
77
- "min_width": "50px",
78
- "object_fit": null,
79
- "object_position": null,
80
- "order": null,
81
- "overflow": null,
82
- "overflow_x": null,
83
- "overflow_y": null,
84
- "padding": null,
85
- "right": null,
86
- "top": null,
87
- "visibility": null,
88
- "width": null
89
- }
90
- },
91
- "f084b7dfcae445a58d36a9c21971793c": {
92
- "model_module": "@jupyter-widgets/controls",
93
- "model_name": "ButtonStyleModel",
94
- "model_module_version": "1.5.0",
95
- "state": {
96
- "_model_module": "@jupyter-widgets/controls",
97
- "_model_module_version": "1.5.0",
98
- "_model_name": "ButtonStyleModel",
99
- "_view_count": null,
100
- "_view_module": "@jupyter-widgets/base",
101
- "_view_module_version": "1.2.0",
102
- "_view_name": "StyleView",
103
- "button_color": null,
104
- "font_weight": ""
105
- }
106
- }
107
- }
108
- }
109
- },
110
- "cells": [
111
- {
112
- "cell_type": "code",
113
- "source": [
114
- "#@title Mount Google Drive\n",
115
- "from google.colab import drive\n",
116
- "from IPython.display import clear_output\n",
117
- "from IPython.display import display\n",
118
- "import ipywidgets as widgets\n",
119
- "import os\n",
120
- "\n",
121
- "def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n",
122
- "Shared_Drive = \"\" #@param {type:\"string\"}\n",
123
- "#@markdown - If you're not using a shared drive, leave this empty\n",
124
- "\n",
125
- "print(\"\u001b[0;33mConnecting...\")\n",
126
- "drive.mount('/content/gdrive')\n",
127
- "\n",
128
- "if Shared_Drive!=\"\" and os.path.exists(\"/content/gdrive/Shareddrives\"):\n",
129
- " mainpth=\"Shareddrives/\"+Shared_Drive\n",
130
- "else:\n",
131
- " mainpth=\"MyDrive\"\n",
132
- "\n",
133
- "clear_output()\n",
134
- "inf('\\u2714 Done','success', '50px')"
135
- ],
136
- "metadata": {
137
- "id": "fCR2boKCTx0z",
138
- "cellView": "form",
139
- "outputId": "baf6303f-9850-4dd2-a6d3-86871ac8aef5",
140
- "colab": {
141
- "base_uri": "https://localhost:8080/",
142
- "height": 49,
143
- "referenced_widgets": [
144
- "a44dd6024769456a8262a17b0ce6a2ed",
145
- "49441085d85a4f219a6ccbf2a197f527",
146
- "f084b7dfcae445a58d36a9c21971793c"
147
- ]
148
- }
149
- },
150
- "execution_count": null,
151
- "outputs": [
152
- {
153
- "output_type": "display_data",
154
- "data": {
155
- "text/plain": [
156
- "Button(button_style='success', description='✔ Done', disabled=True, layout=Layout(min_width='50px'), style=But…"
157
- ],
158
- "application/vnd.jupyter.widget-view+json": {
159
- "version_major": 2,
160
- "version_minor": 0,
161
- "model_id": "a44dd6024769456a8262a17b0ce6a2ed"
162
- }
163
- },
164
- "metadata": {}
165
- }
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "source": [
171
- "#@title Install Required Dependencies\n",
172
- "!pip install torch\n",
173
- "!pip install safetensors\n",
174
- "!pip install pytorch-lightning"
175
- ],
176
- "metadata": {
177
- "id": "5S88gkUJzeqG"
178
- },
179
- "execution_count": null,
180
- "outputs": []
181
- },
182
- {
183
- "cell_type": "code",
184
- "source": [
185
- "def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n",
186
- "file_path = \"\" #@param {type:\"string\"}\n",
187
- "#@markdown - Copy and paste the path to an embedding or VAE file that you are converting, or a directory containing several files\n",
188
- "#@markdown - For example: /content/gdrive/MyDrive/myembedding.pt or /content/gdrive/MyDrive/my_directory\n",
189
- "#@markdown - Pickle files must be in .pt format\n",
190
- "verbose=True"
191
- ],
192
- "metadata": {
193
- "id": "7aLFC6c4O5EW"
194
- },
195
- "execution_count": null,
196
- "outputs": []
197
- },
198
- {
199
- "cell_type": "code",
200
- "source": [
201
- "#@title Define Converter Functions\n",
202
- "import os\n",
203
- "from typing import Any, Dict\n",
204
- "\n",
205
- "import torch\n",
206
- "from safetensors.torch import save_file\n",
207
- "\n",
208
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
209
- "\n",
210
- "def process_pt_files(path: str, model_type: str, verbose=True) -> None:\n",
211
- " if os.path.isdir(path):\n",
212
- " # Path is a directory, process all .pt files in the directory\n",
213
- " for file_name in os.listdir(path):\n",
214
- " if file_name.endswith('.pt'):\n",
215
- " process_file(os.path.join(path, file_name), model_type, verbose)\n",
216
- " elif os.path.isfile(path) and path.endswith('.pt'):\n",
217
- " # Path is a .pt file, process this file\n",
218
- " process_file(path, model_type, verbose)\n",
219
- " else:\n",
220
- " print(f\"{path} is not a valid directory or .pt file.\")\n",
221
- "\n",
222
- "def process_file(file_path: str, model_type: str, verbose: bool) -> None:\n",
223
- " # Load the PyTorch model\n",
224
- " model = torch.load(file_path, map_location=device)\n",
225
- "\n",
226
- " if verbose:\n",
227
- " print(file_path)\n",
228
- "\n",
229
- " if model_type == 'embedding':\n",
230
- " s_model = process_embedding_file(model, verbose)\n",
231
- " elif model_type == 'vae':\n",
232
- " s_model = process_vae_file(model, verbose)\n",
233
- " else:\n",
234
- " raise Exception(f\"model_type `{model_type}` is not supported!\")\n",
235
- "\n",
236
- " # Save the model with the new extension\n",
237
- " if file_path.endswith('.pt'):\n",
238
- " new_file_path = file_path[:-3] + '.safetensors'\n",
239
- " else:\n",
240
- " new_file_path = file_path + '.safetensors'\n",
241
- " save_file(s_model, new_file_path)\n",
242
- "\n",
243
- "def process_embedding_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
244
- " # Extract the embedding tensors\n",
245
- " model_tensors = model.get('string_to_param').get('*')\n",
246
- " s_model = {\n",
247
- " 'emb_params': model_tensors\n",
248
- " }\n",
249
- "\n",
250
- " if verbose:\n",
251
- " # Print the requested training information, if it exists\n",
252
- " if ('sd_checkpoint_name' in model) and (model['sd_checkpoint_name'] is not None):\n",
253
- " print(f\"Trained on {model['sd_checkpoint_name']}.\")\n",
254
- " else:\n",
255
- " print(\"Checkpoint name not found in the model.\")\n",
256
- "\n",
257
- " if ('step' in model) and (model['step'] is not None):\n",
258
- " print(f\"Trained for {model['step']} steps.\")\n",
259
- " else:\n",
260
- " print(\"Step not found in the model.\")\n",
261
- " # Display the tensor's shape\n",
262
- " print(f\"Dimensions of embedding tensor: {model_tensors.shape}\")\n",
263
- " print()\n",
264
- "\n",
265
- " return s_model\n",
266
- "\n",
267
- "def process_vae_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
268
- " # Extract the state dictionary\n",
269
- " s_model = model[\"state_dict\"]\n",
270
- " if verbose:\n",
271
- " # Print the requested training information, if it exists\n",
272
- " step = model.get('step', model.get('global_step'))\n",
273
- " if step is not None:\n",
274
- " print(f\"Trained for {step} steps.\")\n",
275
- " else:\n",
276
- " print(\"Step not found in the model.\")\n",
277
- " print()\n",
278
- " return s_model"
279
- ],
280
- "metadata": {
281
- "id": "UwH1lXmGw9XP"
282
- },
283
- "execution_count": null,
284
- "outputs": []
285
- },
286
- {
287
- "cell_type": "markdown",
288
- "source": [
289
- "## Convert the file(s)\n",
290
- "\n",
291
- "Run whichever of the two following code blocks corresponds to the type of file you are converting.\n",
292
- "\n",
293
- "The converted Safetensor file will be saved in the same directory as the original."
294
- ],
295
- "metadata": {
296
- "id": "LqEl4sM0sMPG"
297
- }
298
- },
299
- {
300
- "cell_type": "code",
301
- "source": [
302
- "#@title Convert the Embedding(s)\n",
303
- "process_pt_files(file_path, 'embedding', verbose=verbose)"
304
- ],
305
- "metadata": {
306
- "id": "4LEWGfjiUeG1",
307
- "cellView": "form"
308
- },
309
- "execution_count": null,
310
- "outputs": []
311
- },
312
- {
313
- "cell_type": "code",
314
- "source": [
315
- "#@title Convert the VAE(s)\n",
316
- "process_pt_files(file_path, 'vae', verbose=verbose)"
317
- ],
318
- "metadata": {
319
- "id": "Jil7A1ckyiHA",
320
- "cellView": "form"
321
- },
322
- "execution_count": null,
323
- "outputs": []
324
- }
325
- ]
326
- }