Spaces:
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# Sidebar label -> checkpoint name handed to `from_pretrained`.
# Insertion order matters: the sidebar selectbox and the `pipelines` list are
# both built positionally from this mapping, so index i refers to the same
# model in both places.
# NOTE: the second label is spelled "eluether" while the org id is
# "EleutherAI"; the label is user-visible text and is intentionally untouched.
model_names = dict(
    [
        ("gpt2-medium", "gpt2-medium"),
        ("eluether1.3b", "EleutherAI/gpt-neo-1.3B"),
    ]
)
def generate_texts(pipeline, input_text, **generator_args):
    """Invoke ``pipeline`` on a prompt and hand back its raw result.

    Args:
        pipeline: Callable to run (in this app, a transformers text-generation
            pipeline). The name shadows the module-level ``pipeline`` factory,
            but only inside this function.
        input_text: Prompt, passed as the single positional argument.
        **generator_args: Forwarded verbatim as keyword arguments.

    Returns:
        Exactly what the ``pipeline`` call returns; no post-processing.
    """
    return pipeline(input_text, **generator_args)
def load_tokenizer(model_name):
    """Return the tokenizer for ``model_name`` via ``AutoTokenizer.from_pretrained``."""
    return AutoTokenizer.from_pretrained(model_name)
def load_model(model_name, eos_token_id):
    """Load the causal-LM checkpoint ``model_name`` with ``pad_token_id`` overridden.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``.
        eos_token_id: Value forwarded as ``pad_token_id``; the caller passes the
            tokenizer's EOS id (presumably because these tokenizers define no
            dedicated pad token — confirm).

    Returns:
        The model object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        pad_token_id=eos_token_id,
    )
    return causal_lm
# Streamlit re-executes this entire script on every widget interaction. Without
# caching, every slider move reloaded both checkpoints (one of them the 1.3B
# GPT-Neo) from scratch, which is slow and likely contributes to memory
# exhaustion on small hosts. Cache the loaded objects across reruns instead.
# `st.cache_resource` exists in Streamlit >= 1.18; older releases only have the
# legacy `st.cache`, which needs `allow_output_mutation=True` for model objects.
# The `or` short-circuits, so `st.cache` is only touched when it is needed.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def _load_tokenizers_and_pipelines():
    """Load one tokenizer and one text-generation pipeline per entry in ``model_names``.

    Runs once per process thanks to the cache decorator; later reruns reuse the
    same objects.

    Returns:
        A ``(tokenizers, pipelines)`` tuple: a dict mapping checkpoint name to
        its tokenizer, and a list of pipelines in ``model_names.values()`` order
        so that the sidebar's model index selects the matching pipeline.
    """
    loaded_tokenizers = {
        model_name: load_tokenizer(model_name) for model_name in model_names.values()
    }
    loaded_pipelines = [
        pipeline(
            "text-generation",
            model=load_model(model_name, loaded_tokenizers[model_name].eos_token_id),
            tokenizer=loaded_tokenizers[model_name],
        )
        for model_name in model_names.values()
    ]
    # Printed here so the log line only appears when models are actually loaded.
    print("loaded the pipelines")
    return loaded_tokenizers, loaded_pipelines


tokenizers, pipelines = _load_tokenizers_and_pipelines()
default_value = "But not just any roof cleaning will do."

# prompts
st.title("Text Extension or Generation")
st.write("Command + Enter for generation...")
sent = st.text_area("Text", default_value, height=150)
# generate_button = st.button("Generate")

# Sidebar: model choice and generation settings. Generation below runs
# unconditionally, so every widget change (which reruns the script) regenerates.
model_index = st.sidebar.selectbox(
    "Select Model",
    range(len(model_names)),
    # Show the label but return the positional index into `pipelines`.
    format_func=lambda x: list(model_names.keys())[x],
)
max_length = st.sidebar.slider("Max Length", value=100, min_value=30, max_value=256)
# Minimum is one step above 0.0: when transformers samples, it divides the
# logits by the temperature and rejects 0. `do_sample` is not passed below, so
# whether sampling happens depends on the model's own generation defaults; a
# 0.0 setting would break generation whenever it does.
temperature = st.sidebar.slider(
    "Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05
)
num_return_sequences = st.sidebar.slider(
    "Num Return Sequences", min_value=1, max_value=4, value=1
)
num_beams = st.sidebar.slider("Num Beams", min_value=2, max_value=6, value=4)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=90)
top_p = st.sidebar.slider(
    "Top-p", min_value=0.4, max_value=1.0, step=0.05, value=0.9
)
repetition_penalty = st.sidebar.slider(
    "Repetition-Penalty", min_value=0.45, max_value=2.0, step=0.1, value=1.2
)

# Beam search cannot return more sequences than it keeps beams; recent
# transformers releases raise ValueError when num_return_sequences > num_beams.
# The sliders above allow e.g. 4 sequences with 2 beams, so widen the beam
# (the max of 4 sequences stays within the beam slider's range) and say so.
if num_return_sequences > num_beams:
    st.sidebar.info(
        f"Num Beams raised to {num_return_sequences} to return "
        f"{num_return_sequences} sequences."
    )
    num_beams = num_return_sequences

if len(sent) < 10:  # character count of the raw prompt, not tokens
    st.write("Input prompt is too small to generate")
else:
    print(sent)
    st.write(f"Generating for prompt {sent}....")
    output_sequences = generate_texts(
        pipelines[model_index],  # same order as model_names, so the index lines up
        sent,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        no_repeat_ngram_size=2,  # fixed: never repeat any bigram
        repetition_penalty=repetition_penalty,
        early_stopping=False,
        top_p=top_p,
    )
    st.write(output_sequences)