import os
import tempfile
import importlib.util
from enum import Enum
from contextlib import contextmanager
from functools import wraps
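# These must be set before importing gradio / huggingface_hub, which read them at import time.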
os.environ["HF_HUB_CACHE"] = "cache"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr
from huggingface_hub import HfApi
from huggingface_hub import whoami
from huggingface_hub import ModelCard
from huggingface_hub import scan_cache_dir
from huggingface_hub import logging
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler
from textwrap import dedent
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Optional,
    Tuple,
    Type,
    Union,
    NamedTuple,
)
import mlx.nn as nn
import mlx_lm
from mlx_lm.utils import (
load_config,
get_model_path,
)
import mlx_vlm
# mlx-lm/mlx_lm/utils.py
MODEL_REMAPPING_MLX_LM = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
}
# mlx-vlm/mlx_vlm/utils.py
MODEL_REMAPPING_MLX_VLM = {
"llava-qwen2": "llava_bunny",
"bunny-llama": "llava_bunny",
}
MODEL_REMAPPING = {
**MODEL_REMAPPING_MLX_LM,
**MODEL_REMAPPING_MLX_VLM,
}
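# Merged view used by convert() to resolve a checkpoint's model_type to the
# module name that mlx-lm / mlx-vlm actually implement.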
HF_TOKEN = os.environ.get("HF_TOKEN")
# Bit widths for the supported quantization methods; extend as new methods are needed
QUANT_PARAMS = {
"Q2": 2,
"Q3": 3,
"Q4": 4,
"Q6": 6,
"Q8": 8,
}
class RuntimeInfo(NamedTuple):
name: str
package: str
version: str
convert_fn: Callable
usage_example: Callable[[str], str]
format: str = "MLX"
class Runtime(RuntimeInfo, Enum):
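    # Each member bundles its converter and a README usage snippet, e.g.
    # Runtime.MLX_LM.usage_example("user/repo") renders the mlx-lm example below.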
MLX_LM = RuntimeInfo(
name="MLX LM",
package="mlx-lm",
version=mlx_lm.__version__,
convert_fn=mlx_lm.convert,
usage_example=lambda upload_repo: dedent(
f"""
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
"""
)
)
MLX_VLM = RuntimeInfo(
name="MLX-VLM",
package="mlx-vlm",
version=mlx_vlm.__version__,
convert_fn=mlx_vlm.convert,
usage_example=lambda upload_repo: dedent(
f"""
```bash
pip install -U mlx-vlm
```
```bash
python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0 --prompt "Describe this image." --image <path_to_image>
```
"""
)
)
def list_files_in_folder(folder_path):
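    """Return the names of regular files directly inside `folder_path` (non-recursive)."""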
# List all files and directories in the specified folder
all_items = os.listdir(folder_path)
# Filter out only files
files = [item for item in all_items if os.path.isfile(os.path.join(folder_path, item))]
return files
def clear_hf_cache_space():
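    """Delete all cached model revisions to free disk space between conversions."""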
scan = scan_cache_dir()
to_delete = []
for repo in scan.repos:
if repo.repo_type == "model":
to_delete.extend([rev.commit_hash for rev in repo.revisions])
scan.delete_revisions(*to_delete).execute()
print("Cache has been cleared")
def upload_to_hub(path, upload_repo, hf_path, oauth_token, runtime: Runtime):
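    """Create a model card for the converted weights and upload every file in `path` to `upload_repo`."""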
card = ModelCard.load(hf_path, token=oauth_token.token)
    tags = ["mlx", "mlx-my-repo"]
    card.data.tags = tags if card.data.tags is None else card.data.tags + tags
card.data.base_model = hf_path
card.text = dedent(
f"""
# {upload_repo}
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to {runtime.format} format from [{hf_path}](https://huggingface.co/{hf_path}) using {runtime.package} version **{runtime.version}**.
"""
) + runtime.usage_example(upload_repo)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi(token=oauth_token.token)
api.create_repo(repo_id=upload_repo, exist_ok=True)
files = list_files_in_folder(path)
print(files)
for file in files:
file_path = os.path.join(path, file)
print(f"Uploading file: {file_path}")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=file,
repo_id=upload_repo,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
@contextmanager
def patch_strict_default_methods_ctx() -> Iterator[Callable[[Any, str], None]]:
"""
Context manager to temporarily set the default value of the 'strict' arg to `False`
for specified class methods.
    Does not affect explicit `strict=True`.
(e.g. `def update(self, parameters: dict, strict: bool = True)`
becomes `def update(self, parameters: dict, strict: bool = False)`)
Typical usage:
with patch_strict_default_methods_ctx() as patch:
patch(Foo, "bar")
patch(Foo, "baz")
patch(Bar, "foo")
# Patched methods active here
# Originals restored here
"""
originals: Dict[Tuple[Type[Any], str], Callable] = {}
def patch(cls: Any, method_name: str):
method = getattr(cls, method_name)
originals[(cls, method_name)] = method
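        # Rebind the method with strict defaulting to False; an explicit
        # strict=... keyword from the caller still takes precedence.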
@wraps(method)
def wrapper(self, *args, strict=False, **kwargs):
return method(self, *args, strict=strict, **kwargs)
setattr(cls, method_name, wrapper)
try:
yield patch
finally:
# Restore all patched methods
for (cls, method_name), original in originals.items():
setattr(cls, method_name, original)
originals.clear()
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: Optional[str] = None,
    upload_repo: Optional[str] = None,
revision: Optional[str] = None,
dequantize: bool = False,
quant_predicate: Optional[
Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]
] = None, # mlx-lm
skip_vision: bool = False, # mlx-vlm
trust_remote_code: bool = True, # mlx-vlm
) -> Runtime:
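    """
    Convert `hf_path` to MLX format with mlx-lm or mlx-vlm, picking the runtime
    by checking which package implements the config's (remapped) model_type.
    If the check is ambiguous, mlx-vlm is tried first with mlx-lm as fallback.
    Returns the Runtime that performed the conversion.
    """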
model_path = get_model_path(hf_path, revision=revision)
def mlx_lm_convert():
mlx_lm.convert(
hf_path=hf_path,
mlx_path=mlx_path,
quantize=quantize,
q_group_size=q_group_size,
q_bits=q_bits,
dtype=dtype,
upload_repo=upload_repo,
revision=revision,
dequantize=dequantize,
quant_predicate=quant_predicate,
)
def mlx_vlm_convert():
def _mlx_vlm_convert():
mlx_vlm.convert(
hf_path=hf_path,
mlx_path=mlx_path,
quantize=quantize,
q_group_size=q_group_size,
q_bits=q_bits,
dtype=dtype,
upload_repo=upload_repo,
revision=revision,
dequantize=dequantize,
skip_vision=skip_vision,
trust_remote_code=trust_remote_code,
)
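        # Some VLM checkpoints trip strict weight/parameter checks; if the
        # first attempt fails, retry with strict defaulting to False.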
try:
_mlx_vlm_convert()
        except ValueError as e:
            print(e)
            print("Conversion failed; retrying with strict=False as the default")
with patch_strict_default_methods_ctx() as patch:
patch(nn.Module, "load_weights")
patch(nn.Module, "update")
patch(nn.Module, "update_modules")
# patched strict=False by default, try again
_mlx_vlm_convert()
config = load_config(model_path)
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
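    # A runtime can handle the model if it ships a models.<model_type> submodule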
is_lm = importlib.util.find_spec(f"mlx_lm.models.{model_type}") is not None
is_vlm = importlib.util.find_spec(f"mlx_vlm.models.{model_type}") is not None
if is_lm and (not is_vlm):
mlx_lm_convert()
runtime = Runtime.MLX_LM
elif is_vlm and (not is_lm):
mlx_vlm_convert()
runtime = Runtime.MLX_VLM
else:
        # Fallback in case MODEL_REMAPPING is outdated
try:
mlx_vlm_convert()
runtime = Runtime.MLX_VLM
        except Exception:
mlx_lm_convert()
runtime = Runtime.MLX_LM
return runtime
def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
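    """Gradio callback: convert `model_id` with the selected method and upload the result under the user's namespace."""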
    if oauth_token is None or oauth_token.token is None:
        raise ValueError("You must be logged in to use MLX-my-repo")
model_name = model_id.split('/')[-1]
username = whoami(oauth_token.token)["name"]
    # TemporaryDirectory(dir="converted") requires the parent directory to exist
    os.makedirs("converted", exist_ok=True)
    try:
if q_method == "FP16":
upload_repo = f"{username}/{model_name}-mlx-fp16"
with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
# The target directory must not exist
mlx_path = os.path.join(tmpdir, "mlx")
runtime = convert(model_id, mlx_path=mlx_path, quantize=False, dtype="float16")
print("Conversion done")
upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token, runtime=runtime)
print("Upload done")
else:
q_bits = QUANT_PARAMS[q_method]
upload_repo = f"{username}/{model_name}-mlx-{q_bits}Bit"
with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
# The target directory must not exist
mlx_path = os.path.join(tmpdir, "mlx")
runtime = convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
print("Conversion done")
upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token, runtime=runtime)
print("Upload done")
return (
f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
"llama.png",
)
except Exception as e:
return (f"Error: {e}", "error.png")
finally:
clear_hf_cache_space()
print("Folder cleaned up successfully!")
css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""
# Create Gradio interface
with gr.Blocks(css=css) as demo:
gr.Markdown("You must be logged in to use MLX-my-repo.")
gr.LoginButton(min_width=250)
model_id = HuggingfaceHubSearch(
label="Hub Model ID",
placeholder="Search for model id on Huggingface",
search_type="model",
)
q_method = gr.Dropdown(
["FP16", "Q2", "Q3", "Q4", "Q6", "Q8"],
label="Conversion Method",
info="MLX conversion type (FP16 for float16, Q2–Q8 for quantized models)",
value="Q4",
filterable=False,
visible=True
)
iface = gr.Interface(
fn=process_model,
inputs=[
model_id,
q_method,
],
outputs=[
gr.Markdown(label="output"),
gr.Image(show_label=False),
],
title="Create your own MLX Models, blazingly fast ⚡!",
description="The space takes an HF repo as an input, converts it to MLX format (FP16 or quantized), and creates a Public/Private repo under your HF user namespace.",
api_name=False
)
def restart_space():
HfApi().restart_space(repo_id="reach-vb/mlx-my-repo", token=HF_TOKEN, factory_reboot=True)
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=21600)  # factory-reboot the Space every 6 hours
scheduler.start()
# Launch the interface
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)