lewtun (HF staff) committed
Commit da700a9 • 1 Parent(s): 2184a6f
Files changed (2)
  1. .gitignore +160 -0
  2. app.py +17 -14
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
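
The block above is the standard GitHub Python .gitignore template. As a quick sanity check that generated artifacts are actually excluded after this commit, a minimal sketch in Python (hypothetical helper, not part of this repo; assumes it runs inside the checkout with git on PATH):

import subprocess

def is_ignored(path: str) -> bool:
    # `git check-ignore -q` exits 0 when the path matches a .gitignore rule,
    # 1 when it does not, and 128 on error (e.g. not inside a git repo).
    result = subprocess.run(["git", "check-ignore", "-q", path])
    return result.returncode == 0

print(is_ignored("__pycache__/app.cpython-310.pyc"))  # expected: True
print(is_ignored("app.py"))                           # expected: False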
app.py CHANGED
@@ -19,6 +19,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_id = "trl-lib/llama-se-rl-merged"
+print(f"Loading model: {model_id}")
 if device == "cpu":
     model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN)
 else:
@@ -28,11 +29,14 @@ else:
 
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
 
-PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer: """
+PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:"""
 
 
-def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=0):
+def generate(instruction, temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=40):
     formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
+
+    temperature = float(temperature)
+    top_p = float(top_p)
     streamer = TextIteratorStreamer(tokenizer)
     model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048).to(device)
 
@@ -56,8 +60,8 @@ def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=0):
             hidden_output += new_text
             continue
         # replace eos token
-        if tokenizer.eos_token in new_text:
-            new_text = new_text.replace(tokenizer.eos_token, "")
+        # if tokenizer.eos_token in new_text:
+        #     new_text = new_text.replace(tokenizer.eos_token, "")
         output += new_text
         yield output
     return output
@@ -66,8 +70,7 @@ def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=0):
 examples = [
     "How do I create an array in C++ of length 5 which contains all even numbers between 1 and 10?",
     "How can I write a Java function to generate the nth Fibonacci number?",
-    "How can I write a Python function that checks if a given number is a palindrome or not?",
-    "I have a lion in my garden. How can I get rid of it?",
+    "How can I sort a list in Python?",
 ]
 
 
@@ -77,7 +80,7 @@ def process_example(args):
     return x
 
 
-with gr.Blocks(theme=theme) as demo:
+with gr.Blocks(theme=theme, analytics_enabled=False) as demo:
     with gr.Column():
         gr.Markdown(
             """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>
@@ -111,7 +114,7 @@ with gr.Blocks(theme=theme) as demo:
         with gr.Column(scale=1):
             temperature = gr.Slider(
                 label="Temperature",
-                value=1.0,
+                value=0.7,
                 minimum=0.0,
                 maximum=2.0,
                 step=0.1,
@@ -120,25 +123,25 @@ with gr.Blocks(theme=theme) as demo:
             )
             max_new_tokens = gr.Slider(
                 label="Max new tokens",
-                value=256,
+                value=64,
                 minimum=0,
                 maximum=2048,
-                step=5,
+                step=4,
                 interactive=True,
                 info="The maximum numbers of new tokens",
             )
             top_p = gr.Slider(
                 label="Top-p (nucleus sampling)",
-                value=1.0,
+                value=0.95,
                 minimum=0.0,
                 maximum=1,
                 step=0.05,
                 interactive=True,
-                info="Higher values sample fewer low-probability tokens",
+                info="Higher values sample more low-probability tokens",
             )
             top_k = gr.Slider(
                 label="Top-k",
-                value=0,
+                value=40,
                 minimum=0,
                 maximum=100,
                 step=2,
@@ -150,4 +153,4 @@ with gr.Blocks(theme=theme) as demo:
     instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
 
     demo.queue(concurrency_count=1)
-    demo.launch(enable_queue=True)
+    demo.launch(enable_queue=True, share=True)
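
For context on how the streaming pieces touched by this diff fit together: below is a minimal sketch of the usual transformers TextIteratorStreamer pattern, assuming model, tokenizer, and device are defined as in the hunks above. The function name stream_answer and the skip_prompt/skip_special_tokens shortcuts are illustrative assumptions, not the exact code in this app.py (the app strips the echoed prompt manually via hidden_output and, after this commit, no longer strips the EOS token).

from threading import Thread

from transformers import TextIteratorStreamer

def stream_answer(formatted_instruction, temperature=0.7, max_new_tokens=64, top_p=0.95, top_k=40):
    # Tokenize the prompt and move it to the same device as the model.
    model_inputs = tokenizer(
        formatted_instruction, return_tensors="pt", truncation=True, max_length=2048
    ).to(device)
    # skip_prompt=True yields only newly generated text; skip_special_tokens drops EOS.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )
    # model.generate blocks, so it runs in a background thread while we consume the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output

The default sampling values in the sketch (temperature 0.7, top_p 0.95, top_k 40, 64 new tokens) simply mirror the new slider defaults introduced by this commit.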