import os
import tempfile
import importlib.util
from enum import Enum

from contextlib import contextmanager
from functools import wraps

os.environ["HF_HUB_CACHE"] = "cache"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr

from huggingface_hub import HfApi
from huggingface_hub import whoami
from huggingface_hub import ModelCard
from huggingface_hub import scan_cache_dir
from huggingface_hub import logging

from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler

from textwrap import dedent
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Optional,
    Tuple,
    Type,
    Union,
    NamedTuple,
)

import mlx.nn as nn
import mlx_lm
from mlx_lm.utils import (
    load_config,
    get_model_path,
)
import mlx_vlm


# mlx-lm/mlx_lm/utils.py
MODEL_REMAPPING_MLX_LM = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}

# mlx-vlm/mlx_vlm/utils.py
MODEL_REMAPPING_MLX_VLM = {
    "llava-qwen2": "llava_bunny",
    "bunny-llama": "llava_bunny",
}

MODEL_REMAPPING = {
    **MODEL_REMAPPING_MLX_LM,
    **MODEL_REMAPPING_MLX_VLM,
}

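# Token for the scheduled Space restart below; user uploads use per-user OAuth tokens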
HF_TOKEN = os.environ.get("HF_TOKEN")

# Quantization bit-widths supported by the UI; extend if more are needed
QUANT_PARAMS = {
    "Q2": 2,
    "Q3": 3,
    "Q4": 4,
    "Q6": 6,
    "Q8": 8,
}

class RuntimeInfo(NamedTuple):
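    """Metadata describing a conversion backend: display name, pip package,
    installed version, conversion entry point, and a README usage snippet."""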
    name: str
    package: str
    version: str
    convert_fn: Callable
    usage_example: Callable[[str], str]
    format: str = "MLX"

class Runtime(RuntimeInfo, Enum):
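    """Available conversion runtimes; each member wraps its RuntimeInfo."""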
    MLX_LM = RuntimeInfo(
        name="MLX LM",
        package="mlx-lm",
        version=mlx_lm.__version__,
        convert_fn=mlx_lm.convert,
        usage_example=lambda upload_repo: dedent(
            f"""
            ## Use with mlx

            ```bash
            pip install mlx-lm
            ```

            ```python
            from mlx_lm import load, generate

            model, tokenizer = load("{upload_repo}")

            prompt="hello"

            if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
                messages = [{{"role": "user", "content": prompt}}]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )

            response = generate(model, tokenizer, prompt=prompt, verbose=True)
            ```
            """
        )
    )
    MLX_VLM = RuntimeInfo(
        name="MLX-VLM",
        package="mlx-vlm",
        version=mlx_vlm.__version__,
        convert_fn=mlx_vlm.convert,
        usage_example=lambda upload_repo: dedent(
            f"""
            ```bash
            pip install -U mlx-vlm
            ```

            ```bash
            python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0 --prompt "Describe this image." --image <path_to_image>
            ```
            """
        )
    )

def list_files_in_folder(folder_path):
    # List all files and directories in the specified folder
    all_items = os.listdir(folder_path)
    
    # Filter out only files
    files = [item for item in all_items if os.path.isfile(os.path.join(folder_path, item))]
    
    return files

def clear_hf_cache_space():
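    """Free disk space by deleting every model revision in the local HF cache."""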
    scan = scan_cache_dir()
    to_delete = []
    for repo in scan.repos:
        if repo.repo_type == "model":
            to_delete.extend([rev.commit_hash for rev in repo.revisions])
    scan.delete_revisions(*to_delete).execute()
    print("Cache has been cleared")

def upload_to_hub(path, upload_repo, hf_path, oauth_token, runtime: Runtime):
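    """Write an updated model card into `path`, then upload every file in it to `upload_repo`."""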
    card = ModelCard.load(hf_path, token=oauth_token.token)
    new_tags = ["mlx", "mlx-my-repo"]
    card.data.tags = new_tags if card.data.tags is None else card.data.tags + new_tags
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to {runtime.format} format from [{hf_path}](https://huggingface.co/{hf_path}) using {runtime.package} version **{runtime.version}**.

        """
    ) + runtime.usage_example(upload_repo)
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi(token=oauth_token.token)
    api.create_repo(repo_id=upload_repo, exist_ok=True)

    files = list_files_in_folder(path)
    print(files)
    for file in files:
        file_path = os.path.join(path, file)
        print(f"Uploading file: {file_path}")
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file,
            repo_id=upload_repo,
        )
        
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")    


@contextmanager
def patch_strict_default_methods_ctx() -> Iterator[Callable[[Any, str], None]]:
    """
    Context manager to temporarily set the default value of the 'strict' arg to `False` 
    for specified class methods.
    Does not affect an explicit `strict=True`.

    (e.g. `def update(self, parameters: dict, strict: bool = True)` 
    becomes `def update(self, parameters: dict, strict: bool = False)`)

    Typical usage:

        with patch_strict_default_methods_ctx() as patch:
            patch(Foo, "bar")
            patch(Foo, "baz")
            patch(Bar, "foo")
            # Patched methods active here
        # Originals restored here
    """

    originals: Dict[Tuple[Type[Any], str], Callable] = {}

    def patch(cls: Any, method_name: str):
        method = getattr(cls, method_name)
        originals[(cls, method_name)] = method

        @wraps(method)
        def wrapper(self, *args, strict=False, **kwargs):
            return method(self, *args, strict=strict, **kwargs)

        setattr(cls, method_name, wrapper)

    try:
        yield patch
    finally:
        # Restore all patched methods
        for (cls, method_name), original in originals.items():
            setattr(cls, method_name, original)
        originals.clear()

def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: Optional[str] = None,
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    quant_predicate: Optional[
        Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]
    ] = None, # mlx-lm
    skip_vision: bool = False, # mlx-vlm
    trust_remote_code: bool = True, # mlx-vlm
) -> Runtime:
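    """Convert `hf_path` to MLX, dispatching to mlx-lm or mlx-vlm based on the
    model's `model_type`; returns the Runtime that performed the conversion."""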
    model_path = get_model_path(hf_path, revision=revision)

    def mlx_lm_convert():
        mlx_lm.convert(
            hf_path=hf_path,
            mlx_path=mlx_path,
            quantize=quantize,
            q_group_size=q_group_size,
            q_bits=q_bits,
            dtype=dtype,
            upload_repo=upload_repo,
            revision=revision,
            dequantize=dequantize,
            quant_predicate=quant_predicate,
        )

    def mlx_vlm_convert():
        def _mlx_vlm_convert():
            mlx_vlm.convert(
                hf_path=hf_path,
                mlx_path=mlx_path,
                quantize=quantize,
                q_group_size=q_group_size,
                q_bits=q_bits,
                dtype=dtype,
                upload_repo=upload_repo,
                revision=revision,
                dequantize=dequantize,
                skip_vision=skip_vision,
                trust_remote_code=trust_remote_code,
            )

        try:
            _mlx_vlm_convert()
        except ValueError as e:
            print(e)
            print(f"Error converting, try again with strict = False")
            with patch_strict_default_methods_ctx() as patch:
                import mlx.nn as n
                patch(nn.Module, "load_weights")
                patch(nn.Module, "update")
                patch(nn.Module, "update_modules")
                # patched strict=False by default, try again
                _mlx_vlm_convert()

    config = load_config(model_path)
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)

    is_lm = importlib.util.find_spec(f"mlx_lm.models.{model_type}") is not None
    is_vlm = importlib.util.find_spec(f"mlx_vlm.models.{model_type}") is not None

    if is_lm and (not is_vlm):
        mlx_lm_convert()
        runtime = Runtime.MLX_LM
    elif is_vlm and (not is_lm):
        mlx_vlm_convert()
        runtime = Runtime.MLX_VLM
    else:
        # fallback in case our MODEL_REMAPPING is outdated: try mlx-vlm first, then mlx-lm
        try:
            mlx_vlm_convert()
            runtime = Runtime.MLX_VLM
        except Exception:
            mlx_lm_convert()
            runtime = Runtime.MLX_LM
    return runtime


def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
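    """Convert `model_id` to MLX (FP16 or quantized) and upload the result to a
    repo under the logged-in user's namespace."""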
    if oauth_token is None or oauth_token.token is None:
        raise ValueError("You must be logged in to use MLX-my-repo")
    
    model_name = model_id.split('/')[-1]
    username = whoami(oauth_token.token)["name"]
    try:
        if q_method == "FP16":
            upload_repo = f"{username}/{model_name}-mlx-fp16"
            with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
                # The target directory must not exist
                mlx_path = os.path.join(tmpdir, "mlx")
                runtime = convert(model_id, mlx_path=mlx_path, quantize=False, dtype="float16")
                print("Conversion done")
                upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token, runtime=runtime)
                print("Upload done")
        else:
            q_bits = QUANT_PARAMS[q_method]
            upload_repo = f"{username}/{model_name}-mlx-{q_bits}Bit"
            with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
                # The target directory must not exist
                mlx_path = os.path.join(tmpdir, "mlx")
                runtime = convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
                print("Conversion done")
                upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token, runtime=runtime)
                print("Upload done")
        return (
            f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
            "llama.png",
        )
    except Exception as e:
        return (f"Error: {e}", "error.png")
    finally:
        clear_hf_cache_space()
        print("Folder cleaned up successfully!")

css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""
# Create Gradio interface
with gr.Blocks(css=css) as demo: 
    gr.Markdown("You must be logged in to use MLX-my-repo.")
    gr.LoginButton(min_width=250)

    model_id = HuggingfaceHubSearch(
        label="Hub Model ID",
        placeholder="Search for a model ID on Hugging Face",
        search_type="model",
    )

    q_method = gr.Dropdown(
        ["FP16", "Q2", "Q3", "Q4", "Q6", "Q8"],
        label="Conversion Method",
        info="MLX conversion type (FP16 for float16, Q2–Q8 for quantized models)",
        value="Q4",
        filterable=False,
        visible=True
    )
    
    iface = gr.Interface(
        fn=process_model,
        inputs=[
            model_id,
            q_method,
        ],
        outputs=[
            gr.Markdown(label="output"),
            gr.Image(show_label=False),
        ],
        title="Create your own MLX Models, blazingly fast ⚡!",
        description="The space takes an HF repo as input, converts it to MLX format (FP16 or quantized), and creates a public repo under your HF user namespace.",
        api_name=False
    )

def restart_space():
    HfApi().restart_space(repo_id="reach-vb/mlx-my-repo", token=HF_TOKEN, factory_reboot=True)

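# Restart the Space every 6 hours (21600 seconds) with a factory reboot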
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=21600)
scheduler.start()

# Launch the interface
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)