masanorihirano commited on
Commit
dfb3b68
1 Parent(s): 90ca338
Files changed (7) hide show
  1. .gitignore +165 -0
  2. Dockerfile +47 -0
  3. Makefile +35 -0
  4. README.md +8 -4
  5. app.py +406 -0
  6. model_pull.py +18 -0
  7. pyproject.toml +63 -0
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ secret.txt
2
+ slack_url.txt
3
+ .idea
4
+ .env
5
+ poetry.lock
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # syntax=docker/dockerfile:1.4
2
+ FROM docker.io/nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
3
+ ENV TZ=Asia/Tokyo
4
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
5
+ RUN sed -i 's http://deb.debian.org http://cdn-aws.deb.debian.org g' /etc/apt/sources.list && \
6
+ sed -i 's http://archive.ubuntu.com http://us-east-1.ec2.archive.ubuntu.com g' /etc/apt/sources.list && \
7
+ sed -i '/security/d' /etc/apt/sources.list && apt-get update && \
8
+ apt-get install -y \
9
+ git \
10
+ make build-essential libssl-dev zlib1g-dev \
11
+ libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
12
+ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev git-lfs \
13
+ ffmpeg libsm6 libxext6 cmake libgl1-mesa-glx \
14
+ python3.9-dev cuda-cudart-11-7 && \
15
+ rm -rf /var/lib/apt/lists/* && \
16
+ git lfs install
17
+
18
+ RUN useradd -m -u 1000 user
19
+ USER user
20
+
21
+ RUN curl https://pyenv.run | bash
22
+ ENV PYENV_ROOT /home/user/.pyenv
23
+ ENV PATH ${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH}
24
+ RUN eval "$(pyenv init -)" && \
25
+ eval "$(pyenv virtualenv-init -)" && \
26
+ pyenv install 3.9.7 && \
27
+ pyenv global 3.9.7 && \
28
+ pyenv rehash && \
29
+ pip install --no-cache-dir --upgrade pip==22.3.1 setuptools wheel && \
30
+ pip install --no-cache-dir datasets "huggingface-hub>=0.12.1" "protobuf<4" "click<8.1" && \
31
+ curl -sSL https://install.python-poetry.org | python -
32
+ ENV PATH /home/user/.local/bin:${PATH}
33
+
34
+ COPY --link --chown=1000 ./pyproject.toml /home/user/app/pyproject.toml
35
+ COPY --link --chown=1000 ./model_pull.py /home/user/app/model_pull.py
36
+ WORKDIR /home/user/app
37
+
38
+ RUN poetry install
39
+ RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
40
+ git config --global credential.helper store && \
41
+ huggingface-cli login --token $(cat /run/secrets/HF_TOKEN) --add-to-git-credential
42
+ RUN poetry run python model_pull.py
43
+
44
+ COPY --link --chown=1000 ./app.py /home/user/app/app.py
45
+
46
+ EXPOSE 7860
47
+ ENTRYPOINT ["/home/user/.local/bin/poetry", "run", "python", "app.py"]
Makefile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ RUN := poetry run
3
+
4
+ .PHONY: check
5
+ check: lint mypy
6
+
7
+ .PHONY: lint
8
+ lint: lint-black lint-isort lint-flake8
9
+
10
+ .PHONY: lint-black
11
+ lint-black:
12
+ $(RUN) black --check --diff --quiet .
13
+
14
+ .PHONY: lint-isort
15
+ lint-isort:
16
+ $(RUN) isort --check --quiet .
17
+
18
+ .PHONY: lint-flake8
19
+ lint-flake8:
20
+ $(RUN) pflake8 .
21
+
22
+ .PHONY: mypy
23
+ mypy:
24
+ $(RUN) mypy .
25
+
26
+ .PHONY: format
27
+ format: format-black format-isort
28
+
29
+ .PHONY: format-black
30
+ format-black:
31
+ $(RUN) black --quiet .
32
+
33
+ .PHONY: format-isort
34
+ format-isort:
35
+ $(RUN) isort --quiet .
README.md CHANGED
@@ -1,11 +1,15 @@
1
  ---
2
- title: Stormy 7b 10ep
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: green
6
  sdk: docker
 
7
  pinned: false
8
  license: mit
 
 
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: stormy 7b 10epochs
3
+ emoji: 💨
4
+ colorFrom: blue
5
+ colorTo: blue
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  license: mit
10
+ models:
11
+ - cyberagent/open-calm-7b
12
+ - izumi-lab/izumi-lab/stormy-7b-10ep
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+ import shutil
5
+ from typing import Optional
6
+ from typing import Tuple
7
+ from typing import Union
8
+
9
+ import gradio as gr
10
+ import requests
11
+ import torch
12
+ from fastchat.conversation import Conversation
13
+ from fastchat.conversation import SeparatorStyle
14
+ from fastchat.conversation import get_conv_template
15
+ from fastchat.conversation import register_conv_template
16
+ from fastchat.model.model_adapter import BaseAdapter
17
+ from fastchat.model.model_adapter import load_model
18
+ from fastchat.model.model_adapter import model_adapters
19
+ from fastchat.serve.cli import SimpleChatIO
20
+ from fastchat.serve.inference import generate_stream
21
+ from huggingface_hub import Repository
22
+ from huggingface_hub import snapshot_download
23
+ from peft import LoraConfig
24
+ from peft import PeftModel
25
+ from peft import get_peft_model
26
+ from peft import set_peft_model_state_dict
27
+ from transformers import AutoModelForCausalLM
28
+ from transformers import AutoTokenizer
29
+ from transformers import PreTrainedModel
30
+ from transformers import PreTrainedTokenizerBase
31
+
32
+
33
+ class FastTokenizerAvailableBaseAdapter(BaseAdapter):
34
+ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
35
+ try:
36
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
37
+ except ValueError:
38
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
41
+ )
42
+ return model, tokenizer
43
+
44
+
45
+ model_adapters[-1] = FastTokenizerAvailableBaseAdapter()
46
+
47
+
48
+ def load_lora_model(
49
+ model_path: str,
50
+ lora_weight: str,
51
+ device: str,
52
+ num_gpus: int,
53
+ max_gpu_memory: Optional[str] = None,
54
+ load_8bit: bool = False,
55
+ cpu_offloading: bool = False,
56
+ debug: bool = False,
57
+ ) -> Tuple[Union[PreTrainedModel, PeftModel], PreTrainedTokenizerBase]:
58
+ model: Union[PreTrainedModel, PeftModel]
59
+ tokenizer: PreTrainedTokenizerBase
60
+ model, tokenizer = load_model(
61
+ model_path=model_path,
62
+ device=device,
63
+ num_gpus=num_gpus,
64
+ max_gpu_memory=max_gpu_memory,
65
+ load_8bit=load_8bit,
66
+ cpu_offloading=cpu_offloading,
67
+ debug=debug,
68
+ )
69
+ if lora_weight is not None:
70
+ # model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
71
+ config = LoraConfig.from_pretrained(lora_weight)
72
+ model = get_peft_model(model, config)
73
+
74
+ # Check the available weights and load them
75
+ checkpoint_name = os.path.join(
76
+ lora_weight, "pytorch_model.bin"
77
+ ) # Full checkpoint
78
+ if not os.path.exists(checkpoint_name):
79
+ checkpoint_name = os.path.join(
80
+ lora_weight, "adapter_model.bin"
81
+ ) # only LoRA model - LoRA config above has to fit
82
+ # The two files above have a different name depending on how they were saved,
83
+ # but are actually the same.
84
+ if os.path.exists(checkpoint_name):
85
+ adapters_weights = torch.load(checkpoint_name)
86
+ set_peft_model_state_dict(model, adapters_weights)
87
+ else:
88
+ raise IOError(f"Checkpoint {checkpoint_name} not found")
89
+
90
+ if debug:
91
+ print(model)
92
+
93
+ return model, tokenizer
94
+
95
+
96
+ print(datetime.datetime.now())
97
+
98
+ NUM_THREADS = 1
99
+
100
+ print(NUM_THREADS)
101
+
102
+ print("starting server ...")
103
+
104
+ BASE_MODEL = "cyberagent/open-calm-7b"
105
+ LORA_WEIGHTS_HF = "izumi-lab/stormy-7b-5ep"
106
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
107
+ DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
108
+ SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)
109
+
110
+ LORA_WEIGHTS = snapshot_download(LORA_WEIGHTS_HF)
111
+
112
+ repo = None
113
+ LOCAL_DIR = "/home/user/data/"
114
+
115
+ if HF_TOKEN and DATASET_REPOSITORY:
116
+ try:
117
+ shutil.rmtree(LOCAL_DIR)
118
+ except Exception:
119
+ pass
120
+
121
+ repo = Repository(
122
+ local_dir=LOCAL_DIR,
123
+ clone_from=DATASET_REPOSITORY,
124
+ use_auth_token=HF_TOKEN,
125
+ repo_type="dataset",
126
+ )
127
+ repo.git_pull()
128
+
129
+ if torch.cuda.is_available():
130
+ device = "cuda"
131
+ else:
132
+ device = "cpu"
133
+
134
+ model, tokenizer = load_lora_model(
135
+ model_path=BASE_MODEL,
136
+ lora_weight=LORA_WEIGHTS,
137
+ device=device,
138
+ num_gpus=1,
139
+ max_gpu_memory="16GiB",
140
+ load_8bit=False,
141
+ cpu_offloading=False,
142
+ debug=False,
143
+ )
144
+
145
+ register_conv_template(
146
+ Conversation(
147
+ name="japanese",
148
+ system="以下はタスクを説明する指示です。要求を適切に満たすような返答を書いてください。\n\n",
149
+ roles=("### 指示", "### 返答"),
150
+ messages=(),
151
+ offset=0,
152
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
153
+ sep="\n###",
154
+ stop_str="###",
155
+ )
156
+ )
157
+
158
+
159
+ Conversation._get_prompt = Conversation.get_prompt
160
+ Conversation._append_message = Conversation.append_message
161
+
162
+
163
+ def conversation_append_message(cls, role: str, message: str):
164
+ cls.offset = -2
165
+ return cls._append_message(role, message)
166
+
167
+
168
+ def conversation_get_prompt_overrider(cls: Conversation) -> str:
169
+ cls.messages = cls.messages[-2:]
170
+ return cls._get_prompt()
171
+
172
+
173
+ def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
174
+ current_hour = now.strftime("%Y-%m-%d_%H")
175
+ file_name = f"prompts_{LORA_WEIGHTS_HF.split('/')[-1]}_{current_hour}.jsonl"
176
+
177
+ if repo is not None:
178
+ repo.git_pull(rebase=True)
179
+ with open(os.path.join(LOCAL_DIR, file_name), "a", encoding="utf-8") as f:
180
+ json.dump(
181
+ {
182
+ "inputs": inputs,
183
+ "outputs": outputs,
184
+ "generate_kwargs": generate_kwargs,
185
+ },
186
+ f,
187
+ ensure_ascii=False,
188
+ )
189
+ f.write("\n")
190
+ repo.push_to_hub()
191
+
192
+
193
+ # we cant add typing now
194
+ # https://github.com/gradio-app/gradio/issues/3514
195
+ def evaluate(
196
+ instruction,
197
+ temperature=0.7,
198
+ max_tokens=256,
199
+ repetition_penalty=1.0,
200
+ ):
201
+ try:
202
+ conv_template = "japanese"
203
+
204
+ inputs = tokenizer(instruction, return_tensors="pt")
205
+ if len(inputs["input_ids"][0]) > max_tokens - 40:
206
+ if HF_TOKEN and DATASET_REPOSITORY:
207
+ try:
208
+ now = datetime.datetime.now()
209
+ current_time = now.strftime("%Y-%m-%d %H:%M:%S")
210
+ print(f"[{current_time}] Pushing prompt and completion to the Hub")
211
+ save_inputs_and_outputs(
212
+ now,
213
+ instruction,
214
+ "",
215
+ {
216
+ "temperature": temperature,
217
+ "max_tokens": max_tokens,
218
+ "repetition_penalty": repetition_penalty,
219
+ },
220
+ )
221
+ except Exception as e:
222
+ print(e)
223
+ return (
224
+ f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens - 40}) tokens are used.",
225
+ gr.update(interactive=True),
226
+ gr.update(interactive=True),
227
+ )
228
+
229
+ conv = get_conv_template(conv_template)
230
+
231
+ conv.append_message(conv.roles[0], instruction)
232
+ conv.append_message(conv.roles[1], None)
233
+
234
+ generate_stream_func = generate_stream
235
+ prompt = conv.get_prompt()
236
+
237
+ gen_params = {
238
+ "model": BASE_MODEL,
239
+ "prompt": prompt,
240
+ "temperature": temperature,
241
+ "max_new_tokens": max_tokens - len(inputs["input_ids"][0]) - 30,
242
+ "stop": conv.stop_str,
243
+ "stop_token_ids": conv.stop_token_ids,
244
+ "echo": False,
245
+ "repetition_penalty": repetition_penalty,
246
+ }
247
+ chatio = SimpleChatIO()
248
+ chatio.prompt_for_output(conv.roles[1])
249
+ output_stream = generate_stream_func(model, tokenizer, gen_params, device)
250
+ output = chatio.stream_output(output_stream)
251
+
252
+ if HF_TOKEN and DATASET_REPOSITORY:
253
+ try:
254
+ now = datetime.datetime.now()
255
+ current_time = now.strftime("%Y-%m-%d %H:%M:%S")
256
+ print(f"[{current_time}] Pushing prompt and completion to the Hub")
257
+ save_inputs_and_outputs(
258
+ now,
259
+ prompt,
260
+ output,
261
+ {
262
+ "temperature": temperature,
263
+ "max_tokens": max_tokens,
264
+ "repetition_penalty": repetition_penalty,
265
+ },
266
+ )
267
+ except Exception as e:
268
+ print(e)
269
+ return output, gr.update(interactive=True), gr.update(interactive=True)
270
+ except Exception as e:
271
+ print(e)
272
+ import traceback
273
+
274
+ if SLACK_WEBHOOK:
275
+ payload_dic = {
276
+ "text": f"BASE_MODEL: {BASE_MODEL}\n LORA_WEIGHTS: {LORA_WEIGHTS_HF}\n"
277
+ + f"instruction: {instruction}\ninput: {input}\ntemperature: {temperature}\n"
278
+ + f"max_tokens: {max_tokens}\nrepetition_penalty: {repetition_penalty}\n\n"
279
+ + str(traceback.format_exc()),
280
+ "username": "Hugging Face Space",
281
+ "channel": "#monitor",
282
+ }
283
+
284
+ try:
285
+ requests.post(SLACK_WEBHOOK, data=json.dumps(payload_dic))
286
+ except Exception:
287
+ pass
288
+ return (
289
+ "Error happend. Please return later.",
290
+ gr.update(interactive=True),
291
+ gr.update(interactive=True),
292
+ )
293
+
294
+
295
+ def reset_textbox():
296
+ return gr.update(value=""), gr.update(value=""), gr.update(value="")
297
+
298
+
299
+ def no_interactive() -> Tuple[gr.Request, gr.Request]:
300
+ return gr.update(interactive=False), gr.update(interactive=False)
301
+
302
+
303
+ title = """<h1 align="center">stormy 7B 10epochs</h1>"""
304
+
305
+ theme = gr.themes.Default(primary_hue="green")
306
+ description = (
307
+ "The official demo for **[izumi-lab/stormy-7b-10ep](https://huggingface.co/izumi-lab/izumi-lab/stormy-7b-10ep)**. "
308
+ "It is a 7B-parameter CALM model finetuned to follow instructions. "
309
+ "It is trained on the dataset specially extracted from [izumi-lab/llm-japanese-dataset](https://huggingface.co/datasets/izumi-lab/llm-japanese-dataset) dataset. "
310
+ "For more information, please visit [the project's website](https://llm.msuzuki.me). "
311
+ "This model can output up to 256 tokens. "
312
+ "It takes about **1 minute** to output. When access is concentrated, the operation may become slow."
313
+ )
314
+ with gr.Blocks(
315
+ css="""#col_container { margin-left: auto; margin-right: auto;}""",
316
+ theme=theme,
317
+ ) as demo:
318
+ gr.HTML(title)
319
+ gr.Markdown(description)
320
+ with gr.Column(elem_id="col_container", visible=False) as main_block:
321
+ with gr.Row():
322
+ with gr.Column():
323
+ instruction = gr.Textbox(
324
+ lines=3, label="Instruction", placeholder="こんにちは"
325
+ )
326
+ with gr.Row():
327
+ with gr.Column(scale=3):
328
+ clear_button = gr.Button("Clear").style(full_width=True)
329
+ with gr.Column(scale=5):
330
+ submit_button = gr.Button("Submit").style(full_width=True)
331
+ outputs = gr.Textbox(lines=4, label="Output")
332
+
333
+ # inputs, top_p, temperature, top_k, repetition_penalty
334
+ with gr.Accordion("Parameters", open=True):
335
+ temperature = gr.Slider(
336
+ minimum=0,
337
+ maximum=1.0,
338
+ value=0.0,
339
+ step=0.05,
340
+ interactive=True,
341
+ label="Temperature",
342
+ )
343
+ max_tokens = gr.Slider(
344
+ minimum=20,
345
+ maximum=256,
346
+ value=128,
347
+ step=1,
348
+ interactive=True,
349
+ label="Max length (Pre-prompt + instruction + input + output)",
350
+ )
351
+ repetition_penalty = gr.Slider(
352
+ minimum=0.0,
353
+ maximum=5.0,
354
+ value=1.0,
355
+ step=0.1,
356
+ interactive=True,
357
+ label="Repetition penalty",
358
+ )
359
+
360
+ with gr.Column(elem_id="user_consent_container") as user_consent_block:
361
+ # Get user consent
362
+ gr.Markdown(
363
+ """
364
+ ## User Consent for Data Collection, Use, and Sharing:
365
+ By using our app, you acknowledge and agree to the following terms regarding the data you provide:
366
+ - **Collection**: We may collect inputs you type into our app.
367
+ - **Use**: We may use the collected data for research purposes, to improve our services, and to develop new products or services, including commercial applications.
368
+ - **Sharing and Publication**: Your input data may be published, shared with third parties, or used for analysis and reporting purposes.
369
+ - **Data Retention**: We may retain your input data for as long as necessary.
370
+
371
+ By continuing to use our app, you provide your explicit consent to the collection, use, and potential sharing of your data as described above. If you do not agree with our data collection, use, and sharing practices, please do not use our app.
372
+
373
+ ## データ収集、利用、共有に関するユーザーの同意:
374
+ 本アプリを使用することにより、提供するデータに関する以下の条件に同意するものとします:
375
+ - **収集**: 本アプリに入力されるテキストデータは収集される場合があります。
376
+ - **利用**: 収集されたデータは研究や、商用アプリケーションを含むサービスの開発に使用される場合があります。
377
+ - **共有および公開**: 入力データは第三者と共有されたり、分析や公開の目的で使用される場合があります。
378
+ - **データ保持**: 入力データは必要な限り保持されます。
379
+
380
+ 本アプリを引き続き使用することにより、上記のようにデータの収集・利用・共有について同意します。データの利用方法に同意しない場合は、本アプリを使用しないでください。
381
+ """
382
+ )
383
+ accept_button = gr.Button("I Agree")
384
+
385
+ def enable_inputs():
386
+ return user_consent_block.update(visible=False), main_block.update(
387
+ visible=True
388
+ )
389
+
390
+ accept_button.click(
391
+ fn=enable_inputs,
392
+ inputs=[],
393
+ outputs=[user_consent_block, main_block],
394
+ queue=False,
395
+ )
396
+ submit_button.click(no_interactive, [], [submit_button, clear_button])
397
+ submit_button.click(
398
+ evaluate,
399
+ [instruction, temperature, max_tokens, repetition_penalty],
400
+ [outputs, submit_button, clear_button],
401
+ )
402
+ clear_button.click(reset_textbox, [], [instruction, outputs], queue=False)
403
+
404
+ demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
405
+ server_name="0.0.0.0", server_port=7860
406
+ )
model_pull.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel
3
+ from transformers import AutoModelForCausalLM
4
+ from transformers import AutoTokenizer
5
+
6
+ BASE_MODEL = "cyberagent/open-calm-7b"
7
+ LORA_WEIGHTS = "izumi-lab/stormy-7b-5ep"
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ BASE_MODEL,
12
+ load_in_8bit=False,
13
+ torch_dtype=torch.float16,
14
+ device_map="auto",
15
+ )
16
+ model = PeftModel.from_pretrained(
17
+ model, LORA_WEIGHTS, torch_dtype=torch.float16, use_auth_token=True
18
+ )
pyproject.toml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "space-stormy-7b-10ep"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Masanori HIRANO <masa.hirano.1996@gmail.com>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.8.1"
11
+ peft = "^0.3.0"
12
+ transformers = "4.28.1"
13
+ gradio = "3.23.0"
14
+ torch = "^2.0.1"
15
+ huggingface-hub = "^0.14.1"
16
+ bitsandbytes = "^0.38.1"
17
+ accelerate = "^0.19.0"
18
+ fschat = "0.2.8"
19
+ sentencepiece = "^0.1.99"
20
+
21
+ [tool.poetry.group.dev.dependencies]
22
+ black = "^23.3.0"
23
+ isort = "^5.12.0"
24
+ mypy = "^1.3.0"
25
+ flake8 = "^6.0.0"
26
+ pyproject-flake8 = "^6.0.0.post1"
27
+
28
+ [build-system]
29
+ requires = ["poetry-core"]
30
+ build-backend = "poetry.core.masonry.api"
31
+
32
+ [tool.isort]
33
+ profile = 'black'
34
+ force_single_line = true
35
+ skip = [
36
+ ".git",
37
+ "__pycache__",
38
+ "docs",
39
+ "build",
40
+ "dist",
41
+ "examples",
42
+ ".venv",
43
+ "tests/examples"
44
+ ]
45
+
46
+ [tool.mypy]
47
+ disallow_untyped_defs = true
48
+ ignore_missing_imports = true
49
+
50
+ [tool.flake8]
51
+ ignore = "E203,E231,E501,W503"
52
+ max-line-length = 88
53
+ exclude = [
54
+ ".git",
55
+ "__pycache__",
56
+ "docs",
57
+ "build",
58
+ "dist",
59
+ "examples",
60
+ ".venv",
61
+ "__init__.py"
62
+ ]
63
+ select = "B,B950,C,E,F,W"