masanorihirano committed
Commit: dc15b84
Parent: b3d63a6
Files changed (1)
  1. app.py +40 -47
app.py CHANGED
@@ -23,11 +23,14 @@ print("starting server ...")
 
 BASE_MODEL = "decapoda-research/llama-13b-hf"
 LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
-DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
+DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 
 repo = None
 LOCAL_DIR = "/home/user/data/"
+PROMPT_LANG = "en"
+assert PROMPT_LANG in ["ja", "en"]
+
 if HF_TOKEN and DATASET_REPOSITORY:
     try:
         shutil.rmtree(LOCAL_DIR)
@@ -42,7 +45,6 @@ if HF_TOKEN and DATASET_REPOSITORY:
     )
     repo.git_pull()
 
-
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
 if torch.cuda.is_available():
@@ -62,7 +64,7 @@ if device == "cuda":
         load_in_8bit=True,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True)
+    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True,)
 elif device == "mps":
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -77,10 +79,7 @@ elif device == "mps":
     )
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        low_cpu_mem_usage=True,
-        load_in_8bit=True,
+        BASE_MODEL, device_map={"": device},load_in_8bit=True, low_cpu_mem_usage=True
     )
     model = PeftModel.from_pretrained(
         model,
@@ -91,18 +90,29 @@ else:
 
 
 def generate_prompt(instruction: str, input: Optional[str] = None):
+    print(f"input: {input}")
     if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+        if PROMPT_LANG == "ja":
+            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
+        elif PROMPT_LANG == "en":
+            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 ### Instruction:
 {instruction}
 ### Input:
 {input}
 ### Response:"""
+        else:
+            raise ValueError("PROMPT_LANG")
     else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+        if PROMPT_LANG == "ja":
+            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
+        elif PROMPT_LANG == "en":
+            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
 ### Instruction:
 {instruction}
 ### Response:"""
+        else:
+            raise ValueError("PROMPT_LANG")
 
 
 if device != "cpu":
@@ -114,7 +124,7 @@ if torch.__version__ >= "2":
 
 def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
     current_hour = now.strftime("%Y-%m-%d_%H")
-    file_name = f"prompts_{current_hour}.jsonl"
+    file_name = f"prompts_{LORA_WEIGHTS.split('/')[-1]}{current_hour}.jsonl"
 
     if repo is not None:
         repo.git_pull(rebase=True)
@@ -138,11 +148,11 @@ def evaluate(
     instruction,
     input=None,
     temperature=0.7,
-    top_p=1.0,
-    top_k=40,
-    num_beams=4,
     max_new_tokens=256,
 ):
+    num_beams: int = 1
+    top_p: float = 1.0
+    top_k: int = 0
     prompt = generate_prompt(instruction, input)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
@@ -151,6 +161,8 @@ def evaluate(
         top_p=top_p,
         top_k=top_k,
         num_beams=num_beams,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token=tokenizer.eos_token_id,
     )
     with torch.no_grad():
         generation_output = model.generate(
@@ -161,9 +173,14 @@ def evaluate(
             max_new_tokens=max_new_tokens,
         )
     s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    output = output.split("### Response:")[1].strip()
-    if HF_TOKEN and DATASET_REPOSITORY:
+    output = tokenizer.decode(s, skip_special_tokens=True)
+    if prompt.endswith("Response:"):
+        output = output.split("### Response:")[1].strip()
+    elif prompt.endswith("返答:"):
+        output = output.split("### 返答:")[1].strip()
+    else:
+        raise ValueError(f"No valid prompt ends. {prompt}")
+    if HF_TOKEN:
         try:
             now = datetime.datetime.now()
             current_time = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -215,10 +232,10 @@ with gr.Blocks(
                 clear_button = gr.Button("Clear").style(full_width=True)
             with gr.Column(scale=5):
                 submit_button = gr.Button("Submit").style(full_width=True)
-        outputs = gr.Textbox(lines=5, label="Output")
+        outputs = gr.Textbox(lines=4, label="Output")
 
     # inputs, top_p, temperature, top_k, repetition_penalty
-    with gr.Accordion("Parameters", open=False):
+    with gr.Accordion("Parameters", open=True):
         temperature = gr.Slider(
             minimum=0,
             maximum=1.0,
@@ -227,34 +244,10 @@ with gr.Blocks(
             interactive=True,
             label="Temperature",
         )
-        top_p = gr.Slider(
-            minimum=0,
-            maximum=1.0,
-            value=1.0,
-            step=0.05,
-            interactive=True,
-            label="Top p",
-        )
-        top_k = gr.Slider(
-            minimum=1,
-            maximum=50,
-            value=4,
-            step=1,
-            interactive=True,
-            label="Top k",
-        )
-        num_beams = gr.Slider(
-            minimum=1,
-            maximum=50,
-            value=4,
-            step=1,
-            interactive=True,
-            label="Beams",
-        )
         max_new_tokens = gr.Slider(
             minimum=1,
-            maximum=50,
-            value=4,
+            maximum=256,
+            value=128,
             step=1,
             interactive=True,
             label="Max length",
@@ -301,17 +294,17 @@ with gr.Blocks(
     inputs.submit(no_interactive, [], [submit_button, clear_button])
    inputs.submit(
        evaluate,
-        [instruction, inputs, temperature, top_p, top_k, num_beams, max_new_tokens],
+        [instruction, inputs, temperature, max_new_tokens],
        [outputs, submit_button, clear_button],
    )
    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        evaluate,
-        [instruction, inputs, temperature, top_p, top_k, num_beams, max_new_tokens],
+        [instruction, inputs, temperature, max_new_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
-    share=False, server_name="0.0.0.0", server_port=7860
+    server_name="0.0.0.0", server_port=7860
)
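
For reference, the two prompt templates this commit introduces can be inspected without loading the 13B model. The sketch below is a standalone reproduction of the new PROMPT_LANG branching in generate_prompt; the helper name demo_prompt and the __main__ usage are illustrative additions, while the template strings are taken verbatim from the diff above.

# Standalone sketch (not part of app.py): reproduces the PROMPT_LANG switch added in
# this commit so the English and Japanese templates can be printed and checked.
from typing import Optional

PROMPT_LANG = "en"  # the commit ships with "en"; set to "ja" for the Japanese template
assert PROMPT_LANG in ["ja", "en"]


def demo_prompt(instruction: str, input: Optional[str] = None) -> str:
    # Mirrors the committed generate_prompt branching: with/without an input field,
    # and per-language template selection.
    if input:
        if PROMPT_LANG == "ja":
            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
        elif PROMPT_LANG == "en":
            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:"""
        else:
            raise ValueError("PROMPT_LANG")
    else:
        if PROMPT_LANG == "ja":
            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
        elif PROMPT_LANG == "en":
            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""
        else:
            raise ValueError("PROMPT_LANG")


if __name__ == "__main__":
    # evaluate() in the diff later splits the decoded output on "### Response:" (en)
    # or "### 返答:" (ja) to recover the answer text.
    print(demo_prompt("Summarize the following text.", "LLaMA is a family of language models."))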