Spaces:
Paused
Paused
masanorihirano
commited on
Commit
•
dfb3b68
1
Parent(s):
90ca338
added
Browse files- .gitignore +165 -0
- Dockerfile +47 -0
- Makefile +35 -0
- README.md +8 -4
- app.py +406 -0
- model_pull.py +18 -0
- pyproject.toml +63 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
secret.txt
|
2 |
+
slack_url.txt
|
3 |
+
.idea
|
4 |
+
.env
|
5 |
+
poetry.lock
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/#use-with-ide
|
115 |
+
.pdm.toml
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
Dockerfile
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# syntax=docker/dockerfile:1.4
|
2 |
+
FROM docker.io/nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
|
3 |
+
ENV TZ=Asia/Tokyo
|
4 |
+
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
5 |
+
RUN sed -i 's http://deb.debian.org http://cdn-aws.deb.debian.org g' /etc/apt/sources.list && \
|
6 |
+
sed -i 's http://archive.ubuntu.com http://us-east-1.ec2.archive.ubuntu.com g' /etc/apt/sources.list && \
|
7 |
+
sed -i '/security/d' /etc/apt/sources.list && apt-get update && \
|
8 |
+
apt-get install -y \
|
9 |
+
git \
|
10 |
+
make build-essential libssl-dev zlib1g-dev \
|
11 |
+
libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
|
12 |
+
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev git-lfs \
|
13 |
+
ffmpeg libsm6 libxext6 cmake libgl1-mesa-glx \
|
14 |
+
python3.9-dev cuda-cudart-11-7 && \
|
15 |
+
rm -rf /var/lib/apt/lists/* && \
|
16 |
+
git lfs install
|
17 |
+
|
18 |
+
RUN useradd -m -u 1000 user
|
19 |
+
USER user
|
20 |
+
|
21 |
+
RUN curl https://pyenv.run | bash
|
22 |
+
ENV PYENV_ROOT /home/user/.pyenv
|
23 |
+
ENV PATH ${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH}
|
24 |
+
RUN eval "$(pyenv init -)" && \
|
25 |
+
eval "$(pyenv virtualenv-init -)" && \
|
26 |
+
pyenv install 3.9.7 && \
|
27 |
+
pyenv global 3.9.7 && \
|
28 |
+
pyenv rehash && \
|
29 |
+
pip install --no-cache-dir --upgrade pip==22.3.1 setuptools wheel && \
|
30 |
+
pip install --no-cache-dir datasets "huggingface-hub>=0.12.1" "protobuf<4" "click<8.1" && \
|
31 |
+
curl -sSL https://install.python-poetry.org | python -
|
32 |
+
ENV PATH /home/user/.local/bin:${PATH}
|
33 |
+
|
34 |
+
COPY --link --chown=1000 ./pyproject.toml /home/user/app/pyproject.toml
|
35 |
+
COPY --link --chown=1000 ./model_pull.py /home/user/app/model_pull.py
|
36 |
+
WORKDIR /home/user/app
|
37 |
+
|
38 |
+
RUN poetry install
|
39 |
+
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
40 |
+
git config --global credential.helper store && \
|
41 |
+
huggingface-cli login --token $(cat /run/secrets/HF_TOKEN) --add-to-git-credential
|
42 |
+
RUN poetry run python model_pull.py
|
43 |
+
|
44 |
+
COPY --link --chown=1000 ./app.py /home/user/app/app.py
|
45 |
+
|
46 |
+
EXPOSE 7860
|
47 |
+
ENTRYPOINT ["/home/user/.local/bin/poetry", "run", "python", "app.py"]
|
Makefile
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
RUN := poetry run
|
3 |
+
|
4 |
+
.PHONY: check
|
5 |
+
check: lint mypy
|
6 |
+
|
7 |
+
.PHONY: lint
|
8 |
+
lint: lint-black lint-isort lint-flake8
|
9 |
+
|
10 |
+
.PHONY: lint-black
|
11 |
+
lint-black:
|
12 |
+
$(RUN) black --check --diff --quiet .
|
13 |
+
|
14 |
+
.PHONY: lint-isort
|
15 |
+
lint-isort:
|
16 |
+
$(RUN) isort --check --quiet .
|
17 |
+
|
18 |
+
.PHONY: lint-flake8
|
19 |
+
lint-flake8:
|
20 |
+
$(RUN) pflake8 .
|
21 |
+
|
22 |
+
.PHONY: mypy
|
23 |
+
mypy:
|
24 |
+
$(RUN) mypy .
|
25 |
+
|
26 |
+
.PHONY: format
|
27 |
+
format: format-black format-isort
|
28 |
+
|
29 |
+
.PHONY: format-black
|
30 |
+
format-black:
|
31 |
+
$(RUN) black --quiet .
|
32 |
+
|
33 |
+
.PHONY: format-isort
|
34 |
+
format-isort:
|
35 |
+
$(RUN) isort --quiet .
|
README.md
CHANGED
@@ -1,11 +1,15 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: docker
|
|
|
7 |
pinned: false
|
8 |
license: mit
|
|
|
|
|
|
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: stormy 7b 10epochs
|
3 |
+
emoji: 💨
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: blue
|
6 |
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
pinned: false
|
9 |
license: mit
|
10 |
+
models:
|
11 |
+
- cyberagent/open-calm-7b
|
12 |
+
- izumi-lab/izumi-lab/stormy-7b-10ep
|
13 |
---
|
14 |
|
15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
from typing import Optional
|
6 |
+
from typing import Tuple
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import requests
|
11 |
+
import torch
|
12 |
+
from fastchat.conversation import Conversation
|
13 |
+
from fastchat.conversation import SeparatorStyle
|
14 |
+
from fastchat.conversation import get_conv_template
|
15 |
+
from fastchat.conversation import register_conv_template
|
16 |
+
from fastchat.model.model_adapter import BaseAdapter
|
17 |
+
from fastchat.model.model_adapter import load_model
|
18 |
+
from fastchat.model.model_adapter import model_adapters
|
19 |
+
from fastchat.serve.cli import SimpleChatIO
|
20 |
+
from fastchat.serve.inference import generate_stream
|
21 |
+
from huggingface_hub import Repository
|
22 |
+
from huggingface_hub import snapshot_download
|
23 |
+
from peft import LoraConfig
|
24 |
+
from peft import PeftModel
|
25 |
+
from peft import get_peft_model
|
26 |
+
from peft import set_peft_model_state_dict
|
27 |
+
from transformers import AutoModelForCausalLM
|
28 |
+
from transformers import AutoTokenizer
|
29 |
+
from transformers import PreTrainedModel
|
30 |
+
from transformers import PreTrainedTokenizerBase
|
31 |
+
|
32 |
+
|
33 |
+
class FastTokenizerAvailableBaseAdapter(BaseAdapter):
|
34 |
+
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
|
35 |
+
try:
|
36 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
37 |
+
except ValueError:
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
39 |
+
model = AutoModelForCausalLM.from_pretrained(
|
40 |
+
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
41 |
+
)
|
42 |
+
return model, tokenizer
|
43 |
+
|
44 |
+
|
45 |
+
model_adapters[-1] = FastTokenizerAvailableBaseAdapter()
|
46 |
+
|
47 |
+
|
48 |
+
def load_lora_model(
|
49 |
+
model_path: str,
|
50 |
+
lora_weight: str,
|
51 |
+
device: str,
|
52 |
+
num_gpus: int,
|
53 |
+
max_gpu_memory: Optional[str] = None,
|
54 |
+
load_8bit: bool = False,
|
55 |
+
cpu_offloading: bool = False,
|
56 |
+
debug: bool = False,
|
57 |
+
) -> Tuple[Union[PreTrainedModel, PeftModel], PreTrainedTokenizerBase]:
|
58 |
+
model: Union[PreTrainedModel, PeftModel]
|
59 |
+
tokenizer: PreTrainedTokenizerBase
|
60 |
+
model, tokenizer = load_model(
|
61 |
+
model_path=model_path,
|
62 |
+
device=device,
|
63 |
+
num_gpus=num_gpus,
|
64 |
+
max_gpu_memory=max_gpu_memory,
|
65 |
+
load_8bit=load_8bit,
|
66 |
+
cpu_offloading=cpu_offloading,
|
67 |
+
debug=debug,
|
68 |
+
)
|
69 |
+
if lora_weight is not None:
|
70 |
+
# model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
|
71 |
+
config = LoraConfig.from_pretrained(lora_weight)
|
72 |
+
model = get_peft_model(model, config)
|
73 |
+
|
74 |
+
# Check the available weights and load them
|
75 |
+
checkpoint_name = os.path.join(
|
76 |
+
lora_weight, "pytorch_model.bin"
|
77 |
+
) # Full checkpoint
|
78 |
+
if not os.path.exists(checkpoint_name):
|
79 |
+
checkpoint_name = os.path.join(
|
80 |
+
lora_weight, "adapter_model.bin"
|
81 |
+
) # only LoRA model - LoRA config above has to fit
|
82 |
+
# The two files above have a different name depending on how they were saved,
|
83 |
+
# but are actually the same.
|
84 |
+
if os.path.exists(checkpoint_name):
|
85 |
+
adapters_weights = torch.load(checkpoint_name)
|
86 |
+
set_peft_model_state_dict(model, adapters_weights)
|
87 |
+
else:
|
88 |
+
raise IOError(f"Checkpoint {checkpoint_name} not found")
|
89 |
+
|
90 |
+
if debug:
|
91 |
+
print(model)
|
92 |
+
|
93 |
+
return model, tokenizer
|
94 |
+
|
95 |
+
|
96 |
+
print(datetime.datetime.now())
|
97 |
+
|
98 |
+
NUM_THREADS = 1
|
99 |
+
|
100 |
+
print(NUM_THREADS)
|
101 |
+
|
102 |
+
print("starting server ...")
|
103 |
+
|
104 |
+
BASE_MODEL = "cyberagent/open-calm-7b"
|
105 |
+
LORA_WEIGHTS_HF = "izumi-lab/stormy-7b-5ep"
|
106 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
107 |
+
DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
|
108 |
+
SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)
|
109 |
+
|
110 |
+
LORA_WEIGHTS = snapshot_download(LORA_WEIGHTS_HF)
|
111 |
+
|
112 |
+
repo = None
|
113 |
+
LOCAL_DIR = "/home/user/data/"
|
114 |
+
|
115 |
+
if HF_TOKEN and DATASET_REPOSITORY:
|
116 |
+
try:
|
117 |
+
shutil.rmtree(LOCAL_DIR)
|
118 |
+
except Exception:
|
119 |
+
pass
|
120 |
+
|
121 |
+
repo = Repository(
|
122 |
+
local_dir=LOCAL_DIR,
|
123 |
+
clone_from=DATASET_REPOSITORY,
|
124 |
+
use_auth_token=HF_TOKEN,
|
125 |
+
repo_type="dataset",
|
126 |
+
)
|
127 |
+
repo.git_pull()
|
128 |
+
|
129 |
+
if torch.cuda.is_available():
|
130 |
+
device = "cuda"
|
131 |
+
else:
|
132 |
+
device = "cpu"
|
133 |
+
|
134 |
+
model, tokenizer = load_lora_model(
|
135 |
+
model_path=BASE_MODEL,
|
136 |
+
lora_weight=LORA_WEIGHTS,
|
137 |
+
device=device,
|
138 |
+
num_gpus=1,
|
139 |
+
max_gpu_memory="16GiB",
|
140 |
+
load_8bit=False,
|
141 |
+
cpu_offloading=False,
|
142 |
+
debug=False,
|
143 |
+
)
|
144 |
+
|
145 |
+
register_conv_template(
|
146 |
+
Conversation(
|
147 |
+
name="japanese",
|
148 |
+
system="以下はタスクを説明する指示です。要求を適切に満たすような返答を書いてください。\n\n",
|
149 |
+
roles=("### 指示", "### 返答"),
|
150 |
+
messages=(),
|
151 |
+
offset=0,
|
152 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
153 |
+
sep="\n###",
|
154 |
+
stop_str="###",
|
155 |
+
)
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
Conversation._get_prompt = Conversation.get_prompt
|
160 |
+
Conversation._append_message = Conversation.append_message
|
161 |
+
|
162 |
+
|
163 |
+
def conversation_append_message(cls, role: str, message: str):
|
164 |
+
cls.offset = -2
|
165 |
+
return cls._append_message(role, message)
|
166 |
+
|
167 |
+
|
168 |
+
def conversation_get_prompt_overrider(cls: Conversation) -> str:
|
169 |
+
cls.messages = cls.messages[-2:]
|
170 |
+
return cls._get_prompt()
|
171 |
+
|
172 |
+
|
173 |
+
def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
|
174 |
+
current_hour = now.strftime("%Y-%m-%d_%H")
|
175 |
+
file_name = f"prompts_{LORA_WEIGHTS_HF.split('/')[-1]}_{current_hour}.jsonl"
|
176 |
+
|
177 |
+
if repo is not None:
|
178 |
+
repo.git_pull(rebase=True)
|
179 |
+
with open(os.path.join(LOCAL_DIR, file_name), "a", encoding="utf-8") as f:
|
180 |
+
json.dump(
|
181 |
+
{
|
182 |
+
"inputs": inputs,
|
183 |
+
"outputs": outputs,
|
184 |
+
"generate_kwargs": generate_kwargs,
|
185 |
+
},
|
186 |
+
f,
|
187 |
+
ensure_ascii=False,
|
188 |
+
)
|
189 |
+
f.write("\n")
|
190 |
+
repo.push_to_hub()
|
191 |
+
|
192 |
+
|
193 |
+
# we cant add typing now
|
194 |
+
# https://github.com/gradio-app/gradio/issues/3514
|
195 |
+
def evaluate(
|
196 |
+
instruction,
|
197 |
+
temperature=0.7,
|
198 |
+
max_tokens=256,
|
199 |
+
repetition_penalty=1.0,
|
200 |
+
):
|
201 |
+
try:
|
202 |
+
conv_template = "japanese"
|
203 |
+
|
204 |
+
inputs = tokenizer(instruction, return_tensors="pt")
|
205 |
+
if len(inputs["input_ids"][0]) > max_tokens - 40:
|
206 |
+
if HF_TOKEN and DATASET_REPOSITORY:
|
207 |
+
try:
|
208 |
+
now = datetime.datetime.now()
|
209 |
+
current_time = now.strftime("%Y-%m-%d %H:%M:%S")
|
210 |
+
print(f"[{current_time}] Pushing prompt and completion to the Hub")
|
211 |
+
save_inputs_and_outputs(
|
212 |
+
now,
|
213 |
+
instruction,
|
214 |
+
"",
|
215 |
+
{
|
216 |
+
"temperature": temperature,
|
217 |
+
"max_tokens": max_tokens,
|
218 |
+
"repetition_penalty": repetition_penalty,
|
219 |
+
},
|
220 |
+
)
|
221 |
+
except Exception as e:
|
222 |
+
print(e)
|
223 |
+
return (
|
224 |
+
f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens - 40}) tokens are used.",
|
225 |
+
gr.update(interactive=True),
|
226 |
+
gr.update(interactive=True),
|
227 |
+
)
|
228 |
+
|
229 |
+
conv = get_conv_template(conv_template)
|
230 |
+
|
231 |
+
conv.append_message(conv.roles[0], instruction)
|
232 |
+
conv.append_message(conv.roles[1], None)
|
233 |
+
|
234 |
+
generate_stream_func = generate_stream
|
235 |
+
prompt = conv.get_prompt()
|
236 |
+
|
237 |
+
gen_params = {
|
238 |
+
"model": BASE_MODEL,
|
239 |
+
"prompt": prompt,
|
240 |
+
"temperature": temperature,
|
241 |
+
"max_new_tokens": max_tokens - len(inputs["input_ids"][0]) - 30,
|
242 |
+
"stop": conv.stop_str,
|
243 |
+
"stop_token_ids": conv.stop_token_ids,
|
244 |
+
"echo": False,
|
245 |
+
"repetition_penalty": repetition_penalty,
|
246 |
+
}
|
247 |
+
chatio = SimpleChatIO()
|
248 |
+
chatio.prompt_for_output(conv.roles[1])
|
249 |
+
output_stream = generate_stream_func(model, tokenizer, gen_params, device)
|
250 |
+
output = chatio.stream_output(output_stream)
|
251 |
+
|
252 |
+
if HF_TOKEN and DATASET_REPOSITORY:
|
253 |
+
try:
|
254 |
+
now = datetime.datetime.now()
|
255 |
+
current_time = now.strftime("%Y-%m-%d %H:%M:%S")
|
256 |
+
print(f"[{current_time}] Pushing prompt and completion to the Hub")
|
257 |
+
save_inputs_and_outputs(
|
258 |
+
now,
|
259 |
+
prompt,
|
260 |
+
output,
|
261 |
+
{
|
262 |
+
"temperature": temperature,
|
263 |
+
"max_tokens": max_tokens,
|
264 |
+
"repetition_penalty": repetition_penalty,
|
265 |
+
},
|
266 |
+
)
|
267 |
+
except Exception as e:
|
268 |
+
print(e)
|
269 |
+
return output, gr.update(interactive=True), gr.update(interactive=True)
|
270 |
+
except Exception as e:
|
271 |
+
print(e)
|
272 |
+
import traceback
|
273 |
+
|
274 |
+
if SLACK_WEBHOOK:
|
275 |
+
payload_dic = {
|
276 |
+
"text": f"BASE_MODEL: {BASE_MODEL}\n LORA_WEIGHTS: {LORA_WEIGHTS_HF}\n"
|
277 |
+
+ f"instruction: {instruction}\ninput: {input}\ntemperature: {temperature}\n"
|
278 |
+
+ f"max_tokens: {max_tokens}\nrepetition_penalty: {repetition_penalty}\n\n"
|
279 |
+
+ str(traceback.format_exc()),
|
280 |
+
"username": "Hugging Face Space",
|
281 |
+
"channel": "#monitor",
|
282 |
+
}
|
283 |
+
|
284 |
+
try:
|
285 |
+
requests.post(SLACK_WEBHOOK, data=json.dumps(payload_dic))
|
286 |
+
except Exception:
|
287 |
+
pass
|
288 |
+
return (
|
289 |
+
"Error happend. Please return later.",
|
290 |
+
gr.update(interactive=True),
|
291 |
+
gr.update(interactive=True),
|
292 |
+
)
|
293 |
+
|
294 |
+
|
295 |
+
def reset_textbox():
|
296 |
+
return gr.update(value=""), gr.update(value=""), gr.update(value="")
|
297 |
+
|
298 |
+
|
299 |
+
def no_interactive() -> Tuple[gr.Request, gr.Request]:
|
300 |
+
return gr.update(interactive=False), gr.update(interactive=False)
|
301 |
+
|
302 |
+
|
303 |
+
title = """<h1 align="center">stormy 7B 10epochs</h1>"""
|
304 |
+
|
305 |
+
theme = gr.themes.Default(primary_hue="green")
|
306 |
+
description = (
|
307 |
+
"The official demo for **[izumi-lab/stormy-7b-10ep](https://huggingface.co/izumi-lab/izumi-lab/stormy-7b-10ep)**. "
|
308 |
+
"It is a 7B-parameter CALM model finetuned to follow instructions. "
|
309 |
+
"It is trained on the dataset specially extracted from [izumi-lab/llm-japanese-dataset](https://huggingface.co/datasets/izumi-lab/llm-japanese-dataset) dataset. "
|
310 |
+
"For more information, please visit [the project's website](https://llm.msuzuki.me). "
|
311 |
+
"This model can output up to 256 tokens. "
|
312 |
+
"It takes about **1 minute** to output. When access is concentrated, the operation may become slow."
|
313 |
+
)
|
314 |
+
with gr.Blocks(
|
315 |
+
css="""#col_container { margin-left: auto; margin-right: auto;}""",
|
316 |
+
theme=theme,
|
317 |
+
) as demo:
|
318 |
+
gr.HTML(title)
|
319 |
+
gr.Markdown(description)
|
320 |
+
with gr.Column(elem_id="col_container", visible=False) as main_block:
|
321 |
+
with gr.Row():
|
322 |
+
with gr.Column():
|
323 |
+
instruction = gr.Textbox(
|
324 |
+
lines=3, label="Instruction", placeholder="こんにちは"
|
325 |
+
)
|
326 |
+
with gr.Row():
|
327 |
+
with gr.Column(scale=3):
|
328 |
+
clear_button = gr.Button("Clear").style(full_width=True)
|
329 |
+
with gr.Column(scale=5):
|
330 |
+
submit_button = gr.Button("Submit").style(full_width=True)
|
331 |
+
outputs = gr.Textbox(lines=4, label="Output")
|
332 |
+
|
333 |
+
# inputs, top_p, temperature, top_k, repetition_penalty
|
334 |
+
with gr.Accordion("Parameters", open=True):
|
335 |
+
temperature = gr.Slider(
|
336 |
+
minimum=0,
|
337 |
+
maximum=1.0,
|
338 |
+
value=0.0,
|
339 |
+
step=0.05,
|
340 |
+
interactive=True,
|
341 |
+
label="Temperature",
|
342 |
+
)
|
343 |
+
max_tokens = gr.Slider(
|
344 |
+
minimum=20,
|
345 |
+
maximum=256,
|
346 |
+
value=128,
|
347 |
+
step=1,
|
348 |
+
interactive=True,
|
349 |
+
label="Max length (Pre-prompt + instruction + input + output)",
|
350 |
+
)
|
351 |
+
repetition_penalty = gr.Slider(
|
352 |
+
minimum=0.0,
|
353 |
+
maximum=5.0,
|
354 |
+
value=1.0,
|
355 |
+
step=0.1,
|
356 |
+
interactive=True,
|
357 |
+
label="Repetition penalty",
|
358 |
+
)
|
359 |
+
|
360 |
+
with gr.Column(elem_id="user_consent_container") as user_consent_block:
|
361 |
+
# Get user consent
|
362 |
+
gr.Markdown(
|
363 |
+
"""
|
364 |
+
## User Consent for Data Collection, Use, and Sharing:
|
365 |
+
By using our app, you acknowledge and agree to the following terms regarding the data you provide:
|
366 |
+
- **Collection**: We may collect inputs you type into our app.
|
367 |
+
- **Use**: We may use the collected data for research purposes, to improve our services, and to develop new products or services, including commercial applications.
|
368 |
+
- **Sharing and Publication**: Your input data may be published, shared with third parties, or used for analysis and reporting purposes.
|
369 |
+
- **Data Retention**: We may retain your input data for as long as necessary.
|
370 |
+
|
371 |
+
By continuing to use our app, you provide your explicit consent to the collection, use, and potential sharing of your data as described above. If you do not agree with our data collection, use, and sharing practices, please do not use our app.
|
372 |
+
|
373 |
+
## データ収集、利用、共有に関するユーザーの同意:
|
374 |
+
本アプリを使用することにより、提供するデータに関する以下の条件に同意するものとします:
|
375 |
+
- **収集**: 本アプリに入力されるテキストデータは収集される場合があります。
|
376 |
+
- **利用**: 収集されたデータは研究や、商用アプリケーションを含むサービスの開発に使用される場合があります。
|
377 |
+
- **共有および公開**: 入力データは第三者と共有されたり、分析や公開の目的で使用される場合があります。
|
378 |
+
- **データ保持**: 入力データは必要な限り保持されます。
|
379 |
+
|
380 |
+
本アプリを引き続き使用することにより、上記のようにデータの収集・利用・共有について同意します。データの利用方法に同意しない場合は、本アプリを使用しないでください。
|
381 |
+
"""
|
382 |
+
)
|
383 |
+
accept_button = gr.Button("I Agree")
|
384 |
+
|
385 |
+
def enable_inputs():
|
386 |
+
return user_consent_block.update(visible=False), main_block.update(
|
387 |
+
visible=True
|
388 |
+
)
|
389 |
+
|
390 |
+
accept_button.click(
|
391 |
+
fn=enable_inputs,
|
392 |
+
inputs=[],
|
393 |
+
outputs=[user_consent_block, main_block],
|
394 |
+
queue=False,
|
395 |
+
)
|
396 |
+
submit_button.click(no_interactive, [], [submit_button, clear_button])
|
397 |
+
submit_button.click(
|
398 |
+
evaluate,
|
399 |
+
[instruction, temperature, max_tokens, repetition_penalty],
|
400 |
+
[outputs, submit_button, clear_button],
|
401 |
+
)
|
402 |
+
clear_button.click(reset_textbox, [], [instruction, outputs], queue=False)
|
403 |
+
|
404 |
+
demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
|
405 |
+
server_name="0.0.0.0", server_port=7860
|
406 |
+
)
|
model_pull.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from peft import PeftModel
|
3 |
+
from transformers import AutoModelForCausalLM
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
|
6 |
+
BASE_MODEL = "cyberagent/open-calm-7b"
|
7 |
+
LORA_WEIGHTS = "izumi-lab/stormy-7b-5ep"
|
8 |
+
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
10 |
+
model = AutoModelForCausalLM.from_pretrained(
|
11 |
+
BASE_MODEL,
|
12 |
+
load_in_8bit=False,
|
13 |
+
torch_dtype=torch.float16,
|
14 |
+
device_map="auto",
|
15 |
+
)
|
16 |
+
model = PeftModel.from_pretrained(
|
17 |
+
model, LORA_WEIGHTS, torch_dtype=torch.float16, use_auth_token=True
|
18 |
+
)
|
pyproject.toml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "space-stormy-7b-10ep"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = ""
|
5 |
+
authors = ["Masanori HIRANO <masa.hirano.1996@gmail.com>"]
|
6 |
+
license = "MIT"
|
7 |
+
readme = "README.md"
|
8 |
+
|
9 |
+
[tool.poetry.dependencies]
|
10 |
+
python = "^3.8.1"
|
11 |
+
peft = "^0.3.0"
|
12 |
+
transformers = "4.28.1"
|
13 |
+
gradio = "3.23.0"
|
14 |
+
torch = "^2.0.1"
|
15 |
+
huggingface-hub = "^0.14.1"
|
16 |
+
bitsandbytes = "^0.38.1"
|
17 |
+
accelerate = "^0.19.0"
|
18 |
+
fschat = "0.2.8"
|
19 |
+
sentencepiece = "^0.1.99"
|
20 |
+
|
21 |
+
[tool.poetry.group.dev.dependencies]
|
22 |
+
black = "^23.3.0"
|
23 |
+
isort = "^5.12.0"
|
24 |
+
mypy = "^1.3.0"
|
25 |
+
flake8 = "^6.0.0"
|
26 |
+
pyproject-flake8 = "^6.0.0.post1"
|
27 |
+
|
28 |
+
[build-system]
|
29 |
+
requires = ["poetry-core"]
|
30 |
+
build-backend = "poetry.core.masonry.api"
|
31 |
+
|
32 |
+
[tool.isort]
|
33 |
+
profile = 'black'
|
34 |
+
force_single_line = true
|
35 |
+
skip = [
|
36 |
+
".git",
|
37 |
+
"__pycache__",
|
38 |
+
"docs",
|
39 |
+
"build",
|
40 |
+
"dist",
|
41 |
+
"examples",
|
42 |
+
".venv",
|
43 |
+
"tests/examples"
|
44 |
+
]
|
45 |
+
|
46 |
+
[tool.mypy]
|
47 |
+
disallow_untyped_defs = true
|
48 |
+
ignore_missing_imports = true
|
49 |
+
|
50 |
+
[tool.flake8]
|
51 |
+
ignore = "E203,E231,E501,W503"
|
52 |
+
max-line-length = 88
|
53 |
+
exclude = [
|
54 |
+
".git",
|
55 |
+
"__pycache__",
|
56 |
+
"docs",
|
57 |
+
"build",
|
58 |
+
"dist",
|
59 |
+
"examples",
|
60 |
+
".venv",
|
61 |
+
"__init__.py"
|
62 |
+
]
|
63 |
+
select = "B,B950,C,E,F,W"
|