lvwerra HF staff committed on
Commit
e804387
1 Parent(s): b212dd9

modify loading logic to match highlighting

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -5,10 +5,14 @@ import torch
5
  import json
6
 
7
 
8
- #@st.cache(allow_output_mutation=True)
9
- def load_pipeline(model_ckpt, device):
10
- pipe = pipeline('text-generation', model=model_ckpt, device=device)
11
- return pipe
 
 
 
 
12
 
13
  @st.cache()
14
  def load_examples():
@@ -18,10 +22,11 @@ def load_examples():
18
 
19
  st.set_page_config(page_icon=':parrot:', layout="wide")
20
 
21
- device = 0 if torch.cuda.is_available() else -1
22
- device_name = "GPU" if torch.cuda.is_available() else "CPU"
23
  model_ckpt = "lvwerra/codeparrot"
24
- pipe = load_pipeline(model_ckpt, device)
 
 
 
25
  examples = load_examples()
26
  example_names = [example["name"] for example in examples]
27
  name2id = dict([(name, i) for i, name in enumerate(example_names)])
@@ -42,7 +47,6 @@ if gen_kwargs["do_sample"]:
42
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
43
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
44
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
45
- st.sidebar.markdown(f"Device: _{device_name}_")
46
  gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
47
  if st.button("Generate code!"):
48
  with st.spinner("Generating code..."):
 
5
  import json
6
 
7
 
8
+ @st.cache(allow_output_mutation=True)
9
+ def load_tokenizer(model_ckpt):
10
+ return AutoTokenizer.from_pretrained(model_ckpt)
11
+
12
+ @st.cache(allow_output_mutation=True)
13
+ def load_model(model_ckpt):
14
+ model = AutoModelForCausalLM.from_pretrained(model_ckpt)
15
+ return model
16
 
17
  @st.cache()
18
  def load_examples():
 
22
 
23
  st.set_page_config(page_icon=':parrot:', layout="wide")
24
 
 
 
25
  model_ckpt = "lvwerra/codeparrot"
26
+ tokenizer = load_tokenizer(model_ckpt)
27
+ model = load_model(model_ckpt)
28
+ pipe = pipeline(model=model, tokenizer=tokenizer)
29
+
30
  examples = load_examples()
31
  example_names = [example["name"] for example in examples]
32
  name2id = dict([(name, i) for i, name in enumerate(example_names)])
 
47
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
48
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
49
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
 
50
  gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
51
  if st.button("Generate code!"):
52
  with st.spinner("Generating code..."):