ncoop57
fix bug with newlines not showing up and causing model errors
1d3f8ed
raw history blame
No virus
4.82 kB
import urllib
import urllib.parse

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face checkpoint served by the app. The larger sibling checkpoint
# "flax-community/gpt-neo-1.3B-apps-all" was previously referenced here as an
# alternative.
model_name = "flax-community/gpt-neo-125M-apps-all"


@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Load the causal LM and tokenizer named by ``model_name``.

    Wrapped in Streamlit's legacy ``st.cache`` so the pair is reused across
    script reruns instead of being reloaded; ``max_entries=1`` keeps a single
    cached pair.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token so ``tokenizer.pad_token_id`` is defined
    # when generate_solution forwards it to ``model.generate``.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def format_input(question, starter_code=""):
    """Build the QUESTION/ANSWER prompt the model completes.

    The prompt lists the question, then the (possibly empty) starter code,
    then an instruction naming the expected answer format, and finally the
    ``ANSWER:`` cue. Non-empty starter code selects the "Call-Based" format;
    empty starter code selects the "Standard Input" format.

    Args:
        question: Natural-language problem statement.
        starter_code: Optional code prefix the answer should build on.

    Returns:
        The prompt string.
    """
    if starter_code:
        answer_type = "\nUse Call-Based format\n"
    else:
        answer_type = "\nUse Standard Input format\n"
    header = f"\nQUESTION:\n{question}\n"
    body = f"{starter_code}\n{answer_type}"
    return f"{header}{body}\nANSWER:\n"
def clean_text(generation, stop_sequences=("\ndef", "\nclass", "\n#", "\nif")):
    """Truncate a model completion at top-level stop sequences.

    This is the clean-up heuristic discussed in OpenAI's paper "Evaluating
    Large Language Models Trained on Code". Each stop sequence is applied in
    order, and the text is cut just before its first occurrence.

    Args:
        generation: Decoded completion text.
        stop_sequences: Strings to cut at. Defaults to the original
            hard-coded set (a new top-level ``def``, ``class``, comment, or
            ``if``).

    Returns:
        The prefix of ``generation`` that precedes any stop sequence.
        ``generation`` is returned unchanged if none occurs.
    """
    for stop in stop_sequences:
        # partition(stop)[0] equals split(stop)[0] but stops scanning at the
        # first match instead of splitting the whole string.
        generation = generation.partition(stop)[0]
    return generation
def generate_solution(
    model,
    tokenizer,
    question,
    starter_code="",
    temperature=1.0,
    num_beams=1,
    max_new_tokens=150,
    top_p=0.95,
):
    """Sample a candidate solution for a coding problem.

    Args:
        model: Causal language model exposing ``generate``.
        tokenizer: Matching tokenizer; its ``pad_token_id`` and
            ``eos_token_id`` are forwarded to ``generate``.
        question: Natural-language problem statement.
        starter_code: Optional code prefix. A non-empty value selects the
            call-based prompt format (see ``format_input``).
        temperature: Sampling temperature.
        num_beams: Number of beams; coerced with ``int()`` before use.
        max_new_tokens: Token budget for the completion beyond the prompt
            (default 150, the previously hard-coded value).
        top_p: Nucleus-sampling probability mass (default 0.95, the previously
            hard-coded value).

    Returns:
        The decoded completion with the prompt excluded, stripped of
        surrounding whitespace and truncated by ``clean_text``.
    """
    prompt = format_input(question, starter_code)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids
    prompt_len = len(input_ids[0])
    output = model.generate(
        input_ids,
        # Pass the mask explicitly when the tokenizer provides one: get_model
        # sets pad_token == eos_token, so generate() cannot tell padding from
        # real EOS tokens on its own.
        attention_mask=encoded.get("attention_mask"),
        max_length=prompt_len + max_new_tokens,
        do_sample=True,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
        temperature=temperature,
        num_beams=int(num_beams),
        no_repeat_ngram_size=None,
        repetition_penalty=None,
        num_return_sequences=None,
    )
    # Decode only the newly generated tokens of the first sequence.
    output_str = tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return clean_text(output_str)
_EXAMPLES = [
[
"""
Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row.
For Example:
```python
[
[1, 2, 3, 4, 5], # minimum value of row is 1
[5, 6, 7, 8, 9], # minimum value of row is 5
[20, 21, 34, 56, 100] # minimum value of row is 20
]
```
So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26`
""",
"",
0.8,
],
[
"""
# Personalized greeting
Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`.
""",
"""
Use conditionals to return the proper message:
case| return
--- | ---
name equals owner | 'Hello boss'
otherwise | 'Hello guest'
def greet(name, owner):
""",
0.8,
],
]
def _render_sidebar():
    """Draw the sidebar branding and sampling controls.

    Returns:
        ``(temperature, num_beams)`` as selected on the two sliders.
    """
    st.sidebar.title("Code Clippy")
    st.sidebar.image(
        "https://raw.githubusercontent.com/ncoop57/gpt-code-clippy/camera-ready/code_clippy_logo.jpg",
        caption="(c) awesome Aimee Trevett",
    )
    st.sidebar.markdown("[Github](https://github.com/ncoop57/gpt-code-clippy)")
    st.sidebar.markdown("[Report](https://github.com/ncoop57/gpt-code-clippy/wiki)")
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
    )
    return temperature, num_beams


def _carbon_url(question, output):
    """Return a carbon.now.sh URL pre-loaded with the question and code.

    The question is prepended as a ``#`` comment. ``urllib.parse.quote``
    percent-encodes the text (newlines included) so it fits in the ``code``
    query parameter.
    """
    code = urllib.parse.quote(f"# {question}\n{output}")
    return (
        "https://carbon.now.sh/?bg=rgba%280%2C0%2C0%2C0%29&t=seti&wt=none&l=python"
        "&ds=false&dsyoff=20px&dsblur=68px&wc=true&wa=false&pv=56px&ph=56px"
        "&ln=false&fl=1&fm=Hack&fs=14px&lh=133%25&si=false&es=2x&wm=false"
        f"&code={code}"
    )


def run():
    """Render the Code Clippy problem-solver page.

    Draws the sidebar controls and reads the problem and starter code. When
    "Solve" is pressed, generates a solution and shows it alongside a Carbon
    screenshot link.
    """
    # Kept as the first Streamlit call: Streamlit's docs require
    # set_page_config to precede other page commands.
    st.set_page_config(page_title="Code Clippy Problem Solver")
    temperature, num_beams = _render_sidebar()

    # main body
    model, tokenizer = get_model()
    question = st.text_input(
        "Problem: ",
        value="A function that can greet user by name. Given a name it should say hello to user.",
        help="Text description of the coding problem to be solved",
    )
    starter_code = st.text_input(
        "Starter code: ", value="def greet(name):", help="Optional starter code"
    )
    if not st.button("Solve"):
        return

    status = st.text("Generating solution...")
    # gif from https://giphy.com/gifs/alan-DfSXiR60W9MVq
    loading_gif = st.image("./loading.gif")
    try:
        output = generate_solution(
            model, tokenizer, question, starter_code, temperature, num_beams
        )
    finally:
        # Clear the progress indicators even if generation raises, so the page
        # does not keep showing "Generating solution..." next to the error.
        status.empty()
        loading_gif.empty()

    st.text("Solution:")
    st.code(output, language="python")
    # Link to Carbon so the user can render a shareable screenshot of the code.
    st.markdown(f"[Would you like a Carbon Copy?]({_carbon_url(question, output)})")
# Render the app when this file is executed as the main script.
# NOTE(review): presumably launched via `streamlit run`; the launcher is not shown here.
if __name__ == "__main__":
    run()