lewtun HF staff commited on
Commit
9f1411e
1 Parent(s): 7aabaa7

Fix prompt

Browse files
Files changed (1) hide show
  1. app.py +1 -19
app.py CHANGED
@@ -1,10 +1,8 @@
1
- import json
2
  import os
3
  from threading import Thread
4
 
5
  import gradio as gr
6
  import torch
7
- from huggingface_hub import Repository
8
  from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
  GenerationConfig, TextIteratorStreamer)
10
 
@@ -15,13 +13,8 @@ theme = gr.themes.Monochrome(
15
  radius_size=gr.themes.sizes.radius_sm,
16
  font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
17
  )
18
- # filesystem to save input and outputs
19
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
- # if HF_TOKEN:
22
- # repo = Repository(
23
- # local_dir="data", clone_from="philschmid/playground-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
24
- # )
25
 
26
 
27
  # Load peft config for pre-trained checkpoint etc.
@@ -30,8 +23,6 @@ model_id = "HuggingFaceH4/llama-se-rl-ed"
30
  if device == "cpu":
31
  model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN)
32
  else:
33
- # torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
34
- # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map="auto")
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_id, device_map="auto", load_in_8bit=True, use_auth_token=HF_TOKEN
37
  )
@@ -42,7 +33,7 @@ PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer: """
42
 
43
 
44
  def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
45
- formatted_instruction = PROMPT_TEMPLATE.format(input=instruction)
46
  # COMMENT IN FOR NON STREAMING
47
  # generation_config = GenerationConfig(
48
  # do_sample=True,
@@ -95,18 +86,9 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
95
  new_text = new_text.replace(tokenizer.eos_token, "")
96
  output += new_text
97
  yield output
98
- # if HF_TOKEN:
99
- # save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
100
  return output
101
 
102
 
103
- # def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
104
- # with open(os.path.join("data", "prompts.jsonl"), "a") as f:
105
- # json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
106
- # f.write("\n")
107
- # commit_url = repo.push_to_hub()
108
-
109
-
110
  examples = [
111
  "How do I create an array in C++ of length 5 which contains all even numbers between 1 and 10?",
112
  "How can I write a Java function to generate the nth Fibonacci number?",
 
 
1
  import os
2
  from threading import Thread
3
 
4
  import gradio as gr
5
  import torch
 
6
  from transformers import (AutoModelForCausalLM, AutoTokenizer,
7
  GenerationConfig, TextIteratorStreamer)
8
 
 
13
  radius_size=gr.themes.sizes.radius_sm,
14
  font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
15
  )
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
17
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 
 
18
 
19
 
20
  # Load peft config for pre-trained checkpoint etc.
 
23
  if device == "cpu":
24
  model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN)
25
  else:
 
 
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_id, device_map="auto", load_in_8bit=True, use_auth_token=HF_TOKEN
28
  )
 
33
 
34
 
35
  def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
36
+ formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
37
  # COMMENT IN FOR NON STREAMING
38
  # generation_config = GenerationConfig(
39
  # do_sample=True,
 
86
  new_text = new_text.replace(tokenizer.eos_token, "")
87
  output += new_text
88
  yield output
 
 
89
  return output
90
 
91
 
 
 
 
 
 
 
 
92
  examples = [
93
  "How do I create an array in C++ of length 5 which contains all even numbers between 1 and 10?",
94
  "How can I write a Java function to generate the nth Fibonacci number?",