hysts HF staff committed on
Commit
c49fb1b
·
1 Parent(s): e375076
Files changed (7) hide show
  1. .pre-commit-config.yaml +55 -0
  2. .vscode/settings.json +21 -0
  3. LICENSE +21 -0
  4. README.md +1 -0
  5. app.py +136 -0
  6. requirements.txt +9 -0
  7. style.css +16 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.12.0
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.6.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ ["types-python-slugify", "types-requests", "types-PyYAML"]
33
+ - repo: https://github.com/psf/black
34
+ rev: 23.10.0
35
+ hooks:
36
+ - id: black
37
+ language_version: python3.10
38
+ args: ["--line-length", "119"]
39
+ - repo: https://github.com/kynan/nbstripout
40
+ rev: 0.6.1
41
+ hooks:
42
+ - id: nbstripout
43
+ args:
44
+ [
45
+ "--extra-keys",
46
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
47
+ ]
48
+ - repo: https://github.com/nbQA-dev/nbQA
49
+ rev: 1.7.0
50
+ hooks:
51
+ - id: nbqa-black
52
+ - id: nbqa-pyupgrade
53
+ args: ["--py37-plus"]
54
+ - id: nbqa-isort
55
+ args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter",
4
+ "editor.formatOnType": true,
5
+ "editor.codeActionsOnSave": {
6
+ "source.organizeImports": true
7
+ }
8
+ },
9
+ "black-formatter.args": [
10
+ "--line-length=119"
11
+ ],
12
+ "isort.args": ["--profile", "black"],
13
+ "flake8.args": [
14
+ "--max-line-length=119"
15
+ ],
16
+ "ruff.args": [
17
+ "--line-length=119"
18
+ ],
19
+ "editor.formatOnSave": true,
20
+ "files.insertFinalNewline": true
21
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ suggested-hardware: t4-small
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ from threading import Thread
5
+ from typing import Iterator
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
+
12
# Bounds for the "Max new tokens" slider: its maximum and its initial value.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Longest prompt (in tokens) fed to the model; overridable through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown shown at the top of the page.
DESCRIPTION = "# Zephyr-7B-alpha"

if torch.cuda.is_available():
    # The model and tokenizer are only loaded when a CUDA device is present.
    model_id = "HuggingFaceH4/zephyr-7b-alpha"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    # Without CUDA, `model` and `tokenizer` stay undefined; tell the user why the demo is unusable.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
25
+
26
+
27
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    The conversation is assembled as an optional system message, then the
    alternating user/assistant turns from ``chat_history``, then the new user
    message. It is tokenized with the tokenizer's chat template and passed to
    ``model.generate`` on a background thread; decoded text is read back
    through a ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs.
        system_prompt: Optional system message; omitted when empty.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, i.e. every chunk received from the
        streamer up to this point joined into one string.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing MAX_INPUT_TOKEN_LENGTH positions along axis 1
        # (presumably the token axis), so the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs on a worker thread so this generator can consume the
    # streamer and yield partial output while the model is still producing.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
70
+
71
+
72
# Chat UI wired to `generate`. The additional inputs are listed in the same
# order as `generate`'s parameters after (message, chat_history):
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            lines=6,
            label="System prompt",
            placeholder="You are a friendly chatbot who always responds in the style of a pirate.",
        ),
        gr.Slider(
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
            label="Top-k",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
            label="Repetition penalty",
        ),
    ],
    stop_btn=None,
    # Each example is a single-element row holding only the message text.
    examples=[
        [prompt]
        for prompt in (
            "Hello there! How are you doing?",
            "Can you explain briefly to me what is the Python programming language?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        )
    ],
)
125
+
126
# Page layout: description header, optional "duplicate" button, then the chat UI.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        elem_id="duplicate-button",
        value="Duplicate Space for private use",
        # Shown only when the SHOW_DUPLICATE_BUTTON environment variable is exactly "1".
        visible=os.environ.get("SHOW_DUPLICATE_BUTTON") == "1",
    )
    chat_interface.render()

# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ bitsandbytes==0.41.1
3
+ gradio==3.50.2
4
+ protobuf==3.20.3
5
+ scipy==1.11.2
6
+ sentencepiece==0.1.99
7
+ spaces==0.16.3
8
+ torch==2.0.0
9
+ transformers==4.34.1
style.css ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ .contain {
13
+ max-width: 900px;
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }