lvwerra HF staff committed on
Commit
ec23712
1 Parent(s): f1751f2

stack-llama-2 (#37)

Browse files

- update to stack-llama-v2 (0bc403488329943436e326a81e639c0d4efbaf54)

Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -3,13 +3,13 @@ import os
3
  import shutil
4
 
5
  import gradio as gr
6
- from huggingface_hub import Repository
7
  from text_generation import Client
8
 
9
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
10
 
11
  HF_TOKEN = os.environ.get("TRL_TOKEN", None)
12
- API_URL = os.environ.get("API_URL")
13
 
14
 
15
  theme = gr.themes.Monochrome(
@@ -25,10 +25,15 @@ if HF_TOKEN:
25
  except:
26
  pass
27
 
28
- repo = Repository(
29
- local_dir="./data/", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
 
 
 
 
 
30
  )
31
- repo.git_pull()
32
 
33
  client = Client(
34
  API_URL,
@@ -42,8 +47,6 @@ def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
42
  with open(os.path.join("data", "prompts.jsonl"), "a") as f:
43
  json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
44
  f.write("\n")
45
- commit_url = repo.push_to_hub()
46
-
47
 
48
  def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, do_save=True):
49
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
@@ -106,11 +109,11 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
106
  """![](https://huggingface.co/spaces/trl-lib/stack-llama/resolve/main/stackllama_logo.png)
107
 
108
 
109
- StackLLaMa is a 7 billion parameter language model based on [Meta's LLaMA model](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback (RLHF) with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).
110
 
111
  Type in the box below and click the button to generate answers to your most pressing questions!
112
 
113
- ⚠️ **Intended Use**: this app and its [supporting model](https://huggingface.co/trl-lib/llama-7b-se-rl-peft) are provided as educational tools to explain RLHF with the TRL library; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card.](https://huggingface.co/trl-lib/llama-7b-se-rl-peft#intended-uses--limitations)
114
 
115
  ⚠️ **Data Collection**: by default, we are collecting the prompts entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app! You can opt out of this data collection by removing the checkbox below:
116
  """
 
3
  import shutil
4
 
5
  import gradio as gr
6
+ from huggingface_hub import Repository, CommitScheduler
7
  from text_generation import Client
8
 
9
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
10
 
11
  HF_TOKEN = os.environ.get("TRL_TOKEN", None)
12
+ API_URL = "https://api-inference.huggingface.co/models/kashif/stack-llama-2"
13
 
14
 
15
  theme = gr.themes.Monochrome(
 
25
  except:
26
  pass
27
 
28
+ # Schedule regular uploads every 10 minutes. Remote repo and local folder are created if they don't already exist.
29
+ scheduler = CommitScheduler(
30
+ repo_id="trl-lib/stack-llama-2-prompts",
31
+ repo_type="dataset",
32
+ folder_path="./data/",
33
+ path_in_repo="./",
34
+ every=10,
35
  )
36
+
37
 
38
  client = Client(
39
  API_URL,
 
47
  with open(os.path.join("data", "prompts.jsonl"), "a") as f:
48
  json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
49
  f.write("\n")
 
 
50
 
51
  def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, do_save=True):
52
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
 
109
  """![](https://huggingface.co/spaces/trl-lib/stack-llama/resolve/main/stackllama_logo.png)
110
 
111
 
112
+ StackLLaMa-2 is a 7 billion parameter language model based on [Meta's LLaMA 2 model](https://ai.meta.com/llama/) that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Direct Preference Optimization (DPO) with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/dpo-trl).
113
 
114
  Type in the box below and click the button to generate answers to your most pressing questions!
115
 
116
+ ⚠️ **Intended Use**: this app and its [supporting model](https://huggingface.co/kashif/stack-llama-2) are provided as educational tools to explain RLHF with the TRL library; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card.](https://huggingface.co/kashif/stack-llama-2#intended-uses--limitations)
117
 
118
  ⚠️ **Data Collection**: by default, we are collecting the prompts entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app! You can opt out of this data collection by removing the checkbox below:
119
  """