leandro committed on
Commit
1c022e5
1 Parent(s): 891f0a3

add examples

Browse files
Files changed (2) hide show
  1. app.py +13 -2
  2. examples.json +31 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
3
  from transformers import pipeline
 
4
 
5
  @st.cache(allow_output_mutation=True)
6
  def load_tokenizer(model_ckpt):
@@ -11,6 +12,12 @@ def load_model(model_ckpt):
11
  model = AutoModelForCausalLM.from_pretrained(model_ckpt)
12
  return model
13
 
 
 
 
 
 
 
14
  st.set_page_config(page_icon=':parrot:', layout="wide")
15
 
16
  default_code = '''\
@@ -20,6 +27,7 @@ def print_hello_world():\
20
  model_ckpt = "lvwerra/codeparrot"
21
  tokenizer = load_tokenizer(model_ckpt)
22
  model = load_model(model_ckpt)
 
23
  set_seed(42)
24
  gen_kwargs = {}
25
 
@@ -27,15 +35,18 @@ st.title("CodeParrot 🦜")
27
  st.markdown('##')
28
 
29
  pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
 
 
 
30
  st.sidebar.header("Generation settings:")
31
  gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
32
- gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=16, min_value=8, step=8, max_value=256)
33
  if gen_kwargs["do_sample"]:
34
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
35
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
36
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
37
 
38
- gen_prompt = st.text_area("Generate code with prompt:", value=default_code, height=220,).strip()
39
  if st.button("Generate code!"):
40
  with st.spinner("Generating code..."):
41
  generated_text = pipe(gen_prompt, **gen_kwargs)[0]['generated_text']
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
3
  from transformers import pipeline
4
+ import json
5
 
6
  @st.cache(allow_output_mutation=True)
7
  def load_tokenizer(model_ckpt):
 
12
  model = AutoModelForCausalLM.from_pretrained(model_ckpt)
13
  return model
14
 
15
@st.cache()
def load_examples():
    """Load the prompt examples shipped alongside the app.

    Reads ``examples.json`` (a list of ``{"name": ..., "value": ...}``
    objects) from the working directory and returns a mapping from the
    example's display name to its prompt text.

    Returns:
        dict[str, str]: example name -> example prompt code.

    Raises:
        FileNotFoundError: if ``examples.json`` is missing.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding avoids platform-dependent defaults when reading JSON.
    with open("examples.json", "r", encoding="utf-8") as f:
        examples = json.load(f)
    # Dict comprehension instead of dict([(k, v) for ...]) — same result,
    # idiomatic and avoids building an intermediate list (ruff C404).
    return {x["name"]: x["value"] for x in examples}
20
+
21
  st.set_page_config(page_icon=':parrot:', layout="wide")
22
 
23
  default_code = '''\
 
27
  model_ckpt = "lvwerra/codeparrot"
28
  tokenizer = load_tokenizer(model_ckpt)
29
  model = load_model(model_ckpt)
30
+ examples = load_examples()
31
  set_seed(42)
32
  gen_kwargs = {}
33
 
 
35
  st.markdown('##')
36
 
37
  pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
38
+ st.sidebar.header("Examples:")
39
+ selected_example = st.sidebar.selectbox("Select one of the following examples:", examples.keys())
40
+ example_text = examples[selected_example]
41
  st.sidebar.header("Generation settings:")
42
  gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
43
+ gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=32, min_value=8, step=8, max_value=256)
44
  if gen_kwargs["do_sample"]:
45
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
46
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
47
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
48
 
49
+ gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
50
  if st.button("Generate code!"):
51
  with st.spinner("Generating code..."):
52
  generated_text = pipe(gen_prompt, **gen_kwargs)[0]['generated_text']
examples.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "Hello World!",
4
+ "value": "def print_hello_world():\n \"\"\"Print 'Hello World!'.\"\"\""
5
+ },
6
+ {
7
+ "name": "Filesize",
8
+ "value": "def get_file_size(filepath):"
9
+ },
10
+ {
11
+ "name": "Python to Numpy",
12
+ "value": "# calculate mean in native Python:\ndef mean(a):\n return sum(a)/len(a)\n\n# calculate mean numpy:\nimport numpy as np\n\ndef mean(a):"
13
+ },
14
+ {
15
+ "name": "unittest",
16
+ "value": "def is_even(value):\n \"\"\"Returns True if value is an even number.\"\"\"\n return value % 2 == 0\n\n# setup unit tests for is_even\nimport unittest"
17
+
18
+ },
19
+ {
20
+ "name": "Scikit-Learn",
21
+ "value": "import numpy as np\nfrom sklearn.ensemble import RandomForestClassifier\n\n# create training data\nX = np.random.randn(100, 100)\ny = np.random.randint(0, 1, 100)\n\n# setup train test split"
22
+ },
23
+ {
24
+ "name": "Pandas",
25
+ "value": "# load dataframe from csv\ndf = pd.read_csv(filename)\n\n# columns: \"age_group\", \"income\"\n# calculate average income per age group"
26
+ },
27
+ {
28
+ "name": "Transformers",
29
+ "value": "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n\n# build a BERT classifier"
30
+ }
31
+ ]