lvwerra HF staff committed on
Commit
39e12e7
1 Parent(s): dc9a7be

add gpu support

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
3
  from transformers import pipeline
 
4
  import json
5
 
6
  @st.cache(allow_output_mutation=True)
@@ -20,10 +21,8 @@ def load_examples():
20
 
21
  st.set_page_config(page_icon=':parrot:', layout="wide")
22
 
23
- default_code = '''\
24
- def print_hello_world():\
25
- '''
26
-
27
  model_ckpt = "lvwerra/codeparrot"
28
  tokenizer = load_tokenizer(model_ckpt)
29
  model = load_model(model_ckpt)
@@ -36,7 +35,7 @@ gen_kwargs = {}
36
  st.title("CodeParrot 🦜")
37
  st.markdown('##')
38
 
39
- pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
40
  st.sidebar.header("Examples:")
41
  selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
42
  example_text = examples[name2id[selected_example]]["value"]
@@ -48,7 +47,7 @@ if gen_kwargs["do_sample"]:
48
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
49
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
50
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
51
-
52
  gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
53
  if st.button("Generate code!"):
54
  with st.spinner("Generating code..."):
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
3
  from transformers import pipeline
4
+ import torch
5
  import json
6
 
7
  @st.cache(allow_output_mutation=True)
 
21
 
22
  st.set_page_config(page_icon=':parrot:', layout="wide")
23
 
24
+ device = 0 if torch.cuda.is_available() else -1
25
+ device_name = "GPU" if torch.cuda.is_available() else "CPU"
 
 
26
  model_ckpt = "lvwerra/codeparrot"
27
  tokenizer = load_tokenizer(model_ckpt)
28
  model = load_model(model_ckpt)
 
35
  st.title("CodeParrot 🦜")
36
  st.markdown('##')
37
 
38
+ pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
39
  st.sidebar.header("Examples:")
40
  selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
41
  example_text = examples[name2id[selected_example]]["value"]
 
47
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
48
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
49
  gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
50
+ st.sidebar.markdown(f"Device: _{device_name}_")
51
  gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
52
  if st.button("Generate code!"):
53
  with st.spinner("Generating code..."):