hayas committed on
Commit
5bc0d7f
1 Parent(s): 170fd86
Files changed (6)
  1. .pre-commit-config.yaml +60 -0
  2. .vscode/settings.json +26 -0
  3. README.md +5 -3
  4. app.py +138 -0
  5. requirements.txt +8 -0
  6. style.css +17 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.8.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 23.12.1
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.6.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.7.1
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
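For anyone reproducing this setup locally, a minimal sketch of exercising these hooks from Python (an assumption for illustration: the `pre-commit` package is installed and the script runs at the repository root; this is equivalent to running `pre-commit install` and `pre-commit run --all-files` in a shell):

```python
# Minimal sketch: register and run the hooks defined in .pre-commit-config.yaml.
# Assumes `pip install pre-commit` has been done and CWD is the repo root.
import subprocess

subprocess.run(["pre-commit", "install"], check=True)  # register the git pre-commit hook
subprocess.run(["pre-commit", "run", "--all-files"], check=True)  # run every hook once
```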
.vscode/settings.json ADDED
@@ -0,0 +1,26 @@
+{
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": false,
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": "explicit"
+    }
+  },
+  "[jupyter]": {
+    "files.insertFinalNewline": false
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.lint.args": [
+    "--line-length=119"
+  ],
+  "notebook.output.scrolling": true,
+  "notebook.formatOnCellExecution": true
+}
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
 title: RakutenAI 7B Chat
-emoji: 🐢
-colorFrom: green
-colorTo: pink
+emoji:
+colorFrom: red
+colorTo: purple
 sdk: gradio
 sdk_version: 4.23.0
 app_file: app.py
 pinned: false
+license: mit
+suggested-hardware: t4-small
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,138 @@
+#!/usr/bin/env python
+
+import os
+from threading import Thread
+from typing import Iterator
+
+import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = "# RakutenAI-7B-chat"
+
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
+
+if torch.cuda.is_available():
+    model_id = "Rakuten/RakutenAI-7B-chat"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
+    model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def apply_chat_template(conversation: list[dict[str, str]]) -> str:
+    prompt = "\n".join([f"{c['role']}: {c['content']}" for c in conversation])
+    prompt = f"{prompt}\nASSISTANT: "
+    return prompt
+
+
+@spaces.GPU
+@torch.inference_mode()
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.7,
+    top_p: float = 0.95,
+    top_k: int = 50,
+    repetition_penalty: float = 1.0,
+) -> Iterator[str]:
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "USER", "content": user}, {"role": "ASSISTANT", "content": assistant}])
+    conversation.append({"role": "USER", "content": message})
+
+    prompt = apply_chat_template(conversation)
+    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:  # keep only the most recent tokens
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)  # generate on a background thread
+    t.start()
+
+    outputs = []
+    for text in streamer:  # stream partial output as tokens arrive
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, layout="panel", height=600),
+    additional_inputs_accordion_name="詳細設定",  # "Advanced settings"
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.7,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.95,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.0,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["東京の観光名所を教えて。"],  # "Tell me about tourist spots in Tokyo."
+        ["落武者って何?"],  # "What is an ochimusha (fallen samurai)?"
+        ["暴れん坊将軍って誰のこと?"],  # "Who is the 'Abarenbō Shōgun'?"
+        ["人がヘリを食べるのにかかる時間は?"],  # "How long would it take a person to eat a helicopter?"
+    ],
+)
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+    chat_interface.render()
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
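For reference, a minimal sketch of the prompt `apply_chat_template` above produces: the app builds a flat `ROLE: content` transcript ending in `ASSISTANT: ` so the model continues as the assistant, rather than using the tokenizer's built-in chat template. The conversation history here is hypothetical, purely for illustration:

```python
# Standalone copy of apply_chat_template from app.py above.
def apply_chat_template(conversation: list[dict[str, str]]) -> str:
    prompt = "\n".join(f"{c['role']}: {c['content']}" for c in conversation)
    return f"{prompt}\nASSISTANT: "

history = [
    {"role": "USER", "content": "Hello!"},
    {"role": "ASSISTANT", "content": "Hi, how can I help?"},
    {"role": "USER", "content": "Tell me about tourist spots in Tokyo."},
]
print(apply_chat_template(history))
# USER: Hello!
# ASSISTANT: Hi, how can I help?
# USER: Tell me about tourist spots in Tokyo.
# ASSISTANT:
```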
requirements.txt ADDED
@@ -0,0 +1,8 @@
+accelerate==0.28.0
+bitsandbytes==0.43.0
+gradio==4.23.0
+scipy==1.12.0
+sentencepiece==0.1.99
+spaces==0.24.2
+torch==2.0.0
+transformers==4.39.1
style.css ADDED
@@ -0,0 +1,17 @@
+h1 {
+  text-align: center;
+  display: block;
+}
+
+#duplicate-button {
+  margin: auto;
+  color: white;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+
+.contain {
+  max-width: 900px;
+  margin: auto;
+  padding-top: 1.5rem;
+}