andstor commited on
Commit
b47009b
·
1 Parent(s): f645242

Add initial version

Browse files
Files changed (4) hide show
  1. .gitignore +144 -0
  2. README.md +1 -1
  3. src/app.py +42 -0
  4. src/model_utils.py +76 -0
.gitignore ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # VSCode
132
+ .vscode
133
+
134
+ # IntelliJ
135
+ .idea
136
+
137
+ # Mac .DS_Store
138
+ .DS_Store
139
+
140
+ # More test things
141
+ wandb
142
+
143
+ # ruff
144
+ .ruff_cache
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Model Visualizer
3
  emoji: 👁
4
  colorFrom: gray
5
  colorTo: red
 
1
  ---
2
+ title: Model Representation
3
  emoji: 👁
4
  colorFrom: gray
5
  colorTo: red
src/app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from huggingface_hub.utils import HfHubHTTPError
4
+ from accelerate.commands.estimate import check_has_model, create_empty_model
5
+ from accelerate.utils import calculate_maximum_sizes
6
+ from model_utils import get_model
7
+
8
+ # We need to store them as globals because gradio doesn't have a way for us to pass them in to the button
9
+ MODEL = None
10
+
11
+
12
+ def get_results(model_name: str, library: str, precision: list, training: list, access_token: str, zero_stage: int, num_nodes: int, num_gpus: int, offloading: list, zero_init: list, additional_buffer_factor: float):
13
+ global MODEL
14
+
15
+ MODEL = get_model(model_name, library, access_token)
16
+
17
+ data = MODEL.__repr__()
18
+
19
+ title = f"## Model Representation for '{model_name}'"
20
+ return [title, gr.update(visible=True, value=data)]
21
+
22
+ with gr.Blocks() as demo:
23
+ with gr.Column():
24
+ out_text = gr.Markdown()
25
+ out = gr.Code()
26
+
27
+ with gr.Row():
28
+ inp = gr.Textbox(label="Model Name or URL", value="bert-base-cased")
29
+ with gr.Row():
30
+ library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
31
+ access_token = gr.Textbox(label="API Token", placeholder="Optional (for gated models)")
32
+ with gr.Row():
33
+ btn = gr.Button("Calculate Memory Usage")
34
+
35
+
36
+ btn.click(
37
+ get_results,
38
+ inputs=[inp, library, access_token,],
39
+ outputs=[out_text, out],
40
+ )
41
+
42
+ demo.launch()
src/model_utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utilities related to loading in and working with models/specific models
2
+ from urllib.parse import urlparse
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from accelerate.commands.estimate import check_has_model, create_empty_model
7
+ from accelerate.utils import calculate_maximum_sizes, convert_bytes
8
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
9
+
10
+
11
+ DTYPE_MODIFIER = {"float32": 1, "float16/bfloat16": 2, "int8": 4, "int4": 8}
12
+ PRECISION = {"Mixed precision": "mixed", "Single precision": "single"}
13
+ DTYPE = {"float32": torch.float32, "float16/bfloat16": torch.float16}
14
+
15
+
16
+ def extract_from_url(name: str):
17
+ "Checks if `name` is a URL, and if so converts it to a model name"
18
+ is_url = False
19
+ try:
20
+ result = urlparse(name)
21
+ is_url = all([result.scheme, result.netloc])
22
+ except Exception:
23
+ is_url = False
24
+ # Pass through if not a URL
25
+ if not is_url:
26
+ return name
27
+ else:
28
+ path = result.path
29
+ return path[1:]
30
+
31
+
32
+ def translate_llama2(text):
33
+ "Translates llama-2 to its hf counterpart"
34
+ if not text.endswith("-hf"):
35
+ return text + "-hf"
36
+ return text
37
+
38
+
39
+ def get_model(model_name: str, library: str, access_token: str):
40
+ "Finds and grabs model from the Hub, and initializes on `meta`"
41
+ if "meta-llama" in model_name:
42
+ model_name = translate_llama2(model_name)
43
+ if library == "auto":
44
+ library = None
45
+ model_name = extract_from_url(model_name)
46
+ try:
47
+ model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
48
+ except GatedRepoError:
49
+ raise gr.Error(
50
+ f"Model `{model_name}` is a gated model, please ensure to pass in your access token and try again if you have access. You can find your access token here : https://huggingface.co/settings/tokens. "
51
+ )
52
+ except RepositoryNotFoundError:
53
+ raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.")
54
+ except ValueError:
55
+ raise gr.Error(
56
+ f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
57
+ )
58
+ except (RuntimeError, OSError) as e:
59
+ library = check_has_model(e)
60
+ if library != "unknown":
61
+ raise gr.Error(
62
+ f"Tried to load `{model_name}` with `{library}` but a possible model to load was not found inside the repo."
63
+ )
64
+ raise gr.Error(
65
+ f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
66
+ )
67
+ except ImportError:
68
+ # hacky way to check if it works with `trust_remote_code=False`
69
+ model = create_empty_model(
70
+ model_name, library_name=library, trust_remote_code=False, access_token=access_token
71
+ )
72
+ except Exception as e:
73
+ raise gr.Error(
74
+ f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
75
+ )
76
+ return model