import json

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed


@st.cache(allow_output_mutation=True)
def load_tokenizer(model_ckpt):
    # Cache the tokenizer so it is only downloaded/loaded once across app reruns.
    return AutoTokenizer.from_pretrained(model_ckpt)

@st.cache(allow_output_mutation=True)
def load_model(model_ckpt):
    # Cache the model as well; low_cpu_mem_usage reduces peak memory while loading the weights.
    model = AutoModelForCausalLM.from_pretrained(model_ckpt, low_cpu_mem_usage=True)
    return model

def load_examples():
    # Load the prompt examples shipped alongside the app.
    with open("examples.json", "r") as f:
        examples = json.load(f)
    return examples
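
# Assumed structure of examples.json, inferred from the "name", "value", and
# "length" keys accessed below (the actual file contents are not shown here):
# [
#     {"name": "Hello World!", "value": "def print_hello_world():", "length": 8},
#     ...
# ]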

st.set_page_config(page_icon=":parrot:", layout="wide")

model_ckpt = "codeparrot/codeparrot"
tokenizer = load_tokenizer(model_ckpt)
model = load_model(model_ckpt)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)  # wraps the cached model and tokenizer

examples = load_examples()
example_names = [example["name"] for example in examples]
name2id = {name: i for i, name in enumerate(example_names)}
set_seed(42)  # fix the random seed so sampled generations are reproducible
gen_kwargs = {}

st.title("CodeParrot 🦜")
st.markdown("##")  # empty heading, used here as vertical spacing

st.sidebar.header("Examples:")
selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
example_text = examples[name2id[selected_example]]["value"]
default_length = examples[name2id[selected_example]]["length"]
st.sidebar.header("Generation settings:")
gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=default_length, min_value=8, step=8, max_value=256)
if gen_kwargs["do_sample"]:
    # Sampling-only knobs; they are not set (and would be ignored) when decoding greedily.
    gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value=0.2, min_value=0.0, max_value=2.0, step=0.05)
    gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=0)
    gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220).strip()
if st.button("Generate code!"):
    with st.spinner("Generating code..."):
        generated_text = pipe(gen_prompt, **gen_kwargs)[0]["generated_text"]
    # The pipeline output includes the prompt followed by the completion.
    st.code(generated_text)
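
# To try the demo locally (assuming this script is saved as app.py and
# streamlit, transformers, and torch are installed):
#   streamlit run app.py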