lewtun HF staff commited on
Commit
5d26322
1 Parent(s): a9e7d31

Refactor for StackLLama:

Browse files
Files changed (1) hide show
  1. app.py +25 -26
app.py CHANGED
@@ -1,10 +1,12 @@
 
1
  import os
 
 
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextIteratorStreamer
4
  import torch
5
- from threading import Thread
6
  from huggingface_hub import Repository
7
- import json
 
8
 
9
  theme = gr.themes.Monochrome(
10
  primary_hue="indigo",
@@ -16,15 +18,15 @@ theme = gr.themes.Monochrome(
16
  # filesystem to save input and outputs
17
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
18
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
- if HF_TOKEN:
20
- repo = Repository(
21
- local_dir="data", clone_from="philschmid/playground-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
22
- )
23
 
24
 
25
  # Load peft config for pre-trained checkpoint etc.
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
- model_id = "philschmid/instruct-igel-001"
28
  if device == "cpu":
29
  model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
30
  else:
@@ -34,11 +36,11 @@ else:
34
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_id)
36
 
37
- prompt_template = f"### Anweisung:\n{{input}}\n\n### Antwort:"
38
 
39
 
40
  def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
41
- formatted_instruction = prompt_template.format(input=instruction)
42
  # COMMENT IN FOR NON STREAMING
43
  # generation_config = GenerationConfig(
44
  # do_sample=True,
@@ -65,9 +67,7 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
65
 
66
  # streaming
67
  streamer = TextIteratorStreamer(tokenizer)
68
- model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048)
69
- # move to gpu
70
- model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
71
 
72
  generate_kwargs = dict(
73
  top_p=top_p,
@@ -93,16 +93,16 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
93
  new_text = new_text.replace(tokenizer.eos_token, "")
94
  output += new_text
95
  yield output
96
- if HF_TOKEN:
97
- save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
98
  return output
99
 
100
 
101
- def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
102
- with open(os.path.join("data", "prompts.jsonl"), "a") as f:
103
- json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
104
- f.write("\n")
105
- commit_url = repo.push_to_hub()
106
 
107
 
108
  examples = [
@@ -124,12 +124,11 @@ Frage: Wann wurde Hugging Face gegründet?""",
124
  with gr.Blocks(theme=theme) as demo:
125
  with gr.Column():
126
  gr.Markdown(
127
- """<h1><center>IGEL - Instruction-tuned German large Language Model for Text</center></h1>
128
- <p>
129
- IGEL is a LLM model family developed for the German language. The first version of IGEL is built on top <a href="https://bigscience.huggingface.co/blog/bloom" target="_blank">BigScience BLOOM</a> adapted to the <a href="https://huggingface.co/malteos/bloom-6b4-clp-german">German language by Malte Ostendorff</a>. IGEL designed to provide accurate and reliable language understanding capabilities for a wide range of natural language understanding tasks, including sentiment analysis, language translation, and question answering.
130
-
131
- The IGEL family includes instruction [instruct-igel-001](https://huggingface.co/philschmid/instruct-igel-001) and `chat-igel-001` _coming soon_.
132
- </p>
133
  """
134
  )
135
  with gr.Row():
 
1
+ import json
2
  import os
3
+ from threading import Thread
4
+
5
  import gradio as gr
 
6
  import torch
 
7
  from huggingface_hub import Repository
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ GenerationConfig, TextIteratorStreamer)
10
 
11
  theme = gr.themes.Monochrome(
12
  primary_hue="indigo",
 
18
  # filesystem to save input and outputs
19
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ # if HF_TOKEN:
22
+ # repo = Repository(
23
+ # local_dir="data", clone_from="philschmid/playground-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
24
+ # )
25
 
26
 
27
  # Load peft config for pre-trained checkpoint etc.
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ model_id = "HuggingFaceH4/llama-se-rl-ed"
30
  if device == "cpu":
31
  model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
32
  else:
 
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id)
38
 
39
+ PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer: """
40
 
41
 
42
  def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
43
+ formatted_instruction = PROMPT_TEMPLATE.format(input=instruction)
44
  # COMMENT IN FOR NON STREAMING
45
  # generation_config = GenerationConfig(
46
  # do_sample=True,
 
67
 
68
  # streaming
69
  streamer = TextIteratorStreamer(tokenizer)
70
+ model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048).to(device)
 
 
71
 
72
  generate_kwargs = dict(
73
  top_p=top_p,
 
93
  new_text = new_text.replace(tokenizer.eos_token, "")
94
  output += new_text
95
  yield output
96
+ # if HF_TOKEN:
97
+ # save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
98
  return output
99
 
100
 
101
+ # def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
102
+ # with open(os.path.join("data", "prompts.jsonl"), "a") as f:
103
+ # json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
104
+ # f.write("\n")
105
+ # commit_url = repo.push_to_hub()
106
 
107
 
108
  examples = [
 
124
  with gr.Blocks(theme=theme) as demo:
125
  with gr.Column():
126
  gr.Markdown(
127
+ """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>
128
+
129
+ StackLLaMa is a 7 billion parameter language model that has been trained on pairs of programming questions and answers from [Stack Overflow](https://stackoverflow.com) using Reinforcement Learning from Human Feedback (RLHF) with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our blog post [ADD LINK].
130
+
131
+ Type in the box below and click the button to generate answers to your most pressing coding questions 🔥!
 
132
  """
133
  )
134
  with gr.Row():