hysts (HF staff) committed
Commit a9106b7
0 parents

initial commit

Files changed (7):
  1. .gitattributes +35 -0
  2. .pre-commit-config.yaml +60 -0
  3. .vscode/settings.json +30 -0
  4. README.md +12 -0
  5. app.py +150 -0
  6. requirements.txt +6 -0
  7. style.css +11 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
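
These attribute patterns route large binary artifacts (archives, checkpoints, serialized tensors, event logs) through Git LFS so the Space repo itself stays lightweight. As a quick illustration, here is a minimal, hypothetical Python sketch (not part of the commit) that checks which paths such glob patterns would catch, using the standard-library `fnmatch` module:

```python
from fnmatch import fnmatch

# A few of the LFS patterns from .gitattributes above.
LFS_PATTERNS = ["*.bin", "*.safetensors", "*.pt", "*tfevents*"]


def routed_to_lfs(path: str) -> bool:
    """Return True if `path` matches any LFS-tracked pattern.

    Note: fnmatch only approximates gitattributes glob semantics
    (e.g. it does not treat `**` specially), so this is an
    illustrative check, not a faithful reimplementation.
    """
    return any(fnmatch(path, pattern) for pattern in LFS_PATTERNS)


for candidate in ["model.safetensors", "app.py", "events.out.tfevents.123"]:
    print(candidate, "->", "LFS" if routed_to_lfs(candidate) else "regular git")
```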
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.6.0
+     hooks:
+       - id: check-executables-have-shebangs
+       - id: check-json
+       - id: check-merge-conflict
+       - id: check-shebang-scripts-are-executable
+       - id: check-toml
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: mixed-line-ending
+         args: ["--fix=lf"]
+       - id: requirements-txt-fixer
+       - id: trailing-whitespace
+   - repo: https://github.com/myint/docformatter
+     rev: v1.7.5
+     hooks:
+       - id: docformatter
+         args: ["--in-place"]
+   - repo: https://github.com/pycqa/isort
+     rev: 5.13.2
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.10.1
+     hooks:
+       - id: mypy
+         args: ["--ignore-missing-imports"]
+         additional_dependencies:
+           [
+             "types-python-slugify",
+             "types-requests",
+             "types-PyYAML",
+             "types-pytz",
+           ]
+   - repo: https://github.com/psf/black
+     rev: 24.4.2
+     hooks:
+       - id: black
+         language_version: python3.10
+         args: ["--line-length", "119"]
+   - repo: https://github.com/kynan/nbstripout
+     rev: 0.7.1
+     hooks:
+       - id: nbstripout
+         args:
+           [
+             "--extra-keys",
+             "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+           ]
+   - repo: https://github.com/nbQA-dev/nbQA
+     rev: 1.8.5
+     hooks:
+       - id: nbqa-black
+       - id: nbqa-pyupgrade
+         args: ["--py37-plus"]
+       - id: nbqa-isort
+         args: ["--float-to-top"]
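
The hook set above lints, formats (black/isort at line length 119), type-checks (mypy), and strips notebook output on every commit. Activating and exercising the hooks locally is the standard pre-commit workflow; shown here as a small Python wrapper purely for consistency with the rest of the repo (running `pre-commit install` and `pre-commit run --all-files` directly on the command line does the same):

```python
import shlex
import subprocess

# Install the git hook once, then run every configured hook
# against the whole tree. Both are standard pre-commit CLI
# invocations, wrapped in subprocess only for illustration.
for command in ("pre-commit install", "pre-commit run --all-files"):
    subprocess.run(shlex.split(command), check=True)
```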
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "editor.formatOnSave": true,
+   "files.insertFinalNewline": false,
+   "[python]": {
+     "editor.defaultFormatter": "ms-python.black-formatter",
+     "editor.formatOnType": true,
+     "editor.codeActionsOnSave": {
+       "source.organizeImports": "explicit"
+     }
+   },
+   "[jupyter]": {
+     "files.insertFinalNewline": false
+   },
+   "black-formatter.args": [
+     "--line-length=119"
+   ],
+   "isort.args": ["--profile", "black"],
+   "flake8.args": [
+     "--max-line-length=119"
+   ],
+   "ruff.lint.args": [
+     "--line-length=119"
+   ],
+   "notebook.output.scrolling": true,
+   "notebook.formatOnCellExecution": true,
+   "notebook.formatOnSave.enabled": true,
+   "notebook.codeActionsOnSave": {
+     "source.organizeImports": "explicit"
+   }
+ }
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Gemma 2 9B IT
+ emoji: 😻
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 4.37.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
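
The YAML front matter above is what Spaces actually reads when building the demo: it selects the Gradio SDK at version 4.37.1 and points the builder at app.py. A minimal, hypothetical sketch (not part of the commit) of extracting that config with PyYAML:

```python
import yaml  # PyYAML

# The front matter is the block between the two "---" markers
# at the top of README.md; Spaces treats it as build config.
with open("README.md", encoding="utf-8") as f:
    _, front_matter, _body = f.read().split("---", 2)

config = yaml.safe_load(front_matter)
print(config["sdk"], config["sdk_version"], config["app_file"])
# -> gradio 4.37.1 app.py
```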
app.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import shlex
+ import subprocess
+ from threading import Thread
+ from typing import Iterator
+
+ from huggingface_hub import hf_hub_download
+
+ whl_path = hf_hub_download("google/gemma-2-9b-it", "transformers/transformers-4.42.0.dev0-py3-none-any.whl")
+ subprocess.run(shlex.split(f"pip install {whl_path}"))
+
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     GemmaTokenizerFast,
+     TextIteratorStreamer,
+ )
+
+ DESCRIPTION = """\
+ # Gemma 2 9B IT
+
+ Gemma 2 is Google's latest iteration of open LLMs.
+ This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/gemma-2).
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "google/gemma-2-9b-it"
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+ )
+ model.config.sliding_window = 4096
+ model.eval()
+
+
+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend(
+             [
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ]
+         )
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+ )
+
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
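
Because `generate` is a plain Python generator that yields the accumulated response after each streamed chunk, it can also be exercised outside Gradio. A minimal, hypothetical smoke test (not part of the commit; it assumes a GPU with enough memory for the 8-bit model, and that `spaces.GPU` should degrade to a no-op outside a Space) could look like:

```python
# Hypothetical smoke test for the streaming generator defined in app.py.
# Importing app triggers the wheel install and model load, so this is heavy.
from app import generate

history: list[tuple[str, str]] = []  # no prior turns
last = ""
for partial in generate("Explain the plot of Cinderella in a sentence.", history):
    last = partial  # each yield is the full response so far, not a delta
print(last)
```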
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ accelerate==0.31.0
+ bitsandbytes==0.43.1
+ gradio==4.37.1
+ spaces==0.28.3
+ torch==2.2.0
+ #transformers==4.42.0
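
Note that the `transformers` pin is commented out: at the time of this commit, Gemma 2 support lived in a 4.42.0 dev wheel shipped inside the model repo, which app.py downloads and pip-installs at startup. A small, illustrative sanity check (not part of the commit) to confirm the runtime actually got that build:

```python
from importlib.metadata import version

# The wheel installed at startup should report the dev build, e.g. "4.42.0.dev0".
installed = version("transformers")
assert installed.startswith("4.42.0"), f"unexpected transformers version: {installed}"
print("transformers", installed)
```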
style.css ADDED
@@ -0,0 +1,11 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: #fff;
+   background: #1565c0;
+   border-radius: 100vh;
+ }