File size: 5,417 Bytes
b5f6465
 
 
 
 
 
 
 
429022a
 
b5f6465
6fd8787
b5f6465
6fd8787
b5f6465
6fd8787
b5f6465
6fd8787
 
de665e6
 
6fd8787
 
 
b5f6465
 
 
 
6fd8787
b5f6465
 
6fd8787
b5f6465
 
 
 
 
6fd8787
b5f6465
 
 
6fd8787
b5f6465
 
 
 
6fd8787
b5f6465
 
 
6fd8787
b5f6465
 
 
 
 
 
 
 
 
 
6fd8787
b5f6465
 
 
 
 
 
 
429022a
 
439ffd3
b5f6465
4be8834
b5f6465
 
 
 
 
 
 
 
 
 
 
 
 
429022a
b5f6465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88ac7ef
b5f6465
 
6fd8787
 
b5f6465
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import time

from better_transformer import *

def main():
    """Render the Short Story Transformer Streamlit demo.

    Builds the page: project description, prompt box, optional prompting tips and
    optional advanced sampling settings. It then loads the tokenizer and model.
    When the user clicks "Write my story!", it calls ``generate`` and streams the
    prompt and the generated story into the page word by word.

    ``torch``, ``load_tokenizer``, ``load_big_model`` and ``generate`` are expected
    to come from the star import of ``better_transformer``.
    """

    # Use CUDA if available, otherwise fall back to CPU. The same device is passed
    # to the tokenizer loader, the model loader and generate().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer, _ = load_tokenizer(device)  # second return value is not used here

    st.title("Short Story Transformer Demo")
    st.subheader("UCLA DSU Project, Fall 2023")
    st.markdown("By Daniel Mendelevitch, Terry Ming, Casey Tattersall, Sean Tjoa")

    st.header("Data and Training")
    
    st.markdown("""We used the dataset from the [TinyStories Research Paper](https://arxiv.org/pdf/2305.07759.pdf) (Ronen Eldan and Yuanzhi Li, Microsoft), 
    which consists of 2.1 million synthetic short children's stories generated by GPT-4, to train a Transformer LLM that we built from scratch in PyTorch.""")
    st.markdown("""Our final model uses EleutherAI's [gpt-neo-1.3B tokenizer](https://huggingface.co/EleutherAI/gpt-neo-1.3B) (vocab size 50,257) and consists of 8 transformer blocks, 
    16 attention heads, and an embedding dimension of 768, for a total of ~56M non-embedding parameters. The model was trained on 8 H100 GPUs for ~7 hours, achieving a cross-entropy validation loss of 1.16,
    which is superior to any model in the TinyStories paper (likely due to a larger vocab size and far more compute).""")
    st.markdown("""Despite the simple themes and limited vocabulary present in the training data, the model is
    quite effective at generating new short stories. **Try it out below!**""")

    st.header("Let's make some stories! 📖")

    # Prompt from the user; an empty string means unconditional generation.
    user_input = st.text_input("Enter your prompt:", placeholder="Write a prompt to make a story of your own, or leave it empty for a random story!").strip()

    if st.checkbox("Show Prompting Tips"):
        st.markdown("The model can struggle with some prompts, especially those outside of its limited domain. If a response isn't satisfactory, try repeating the generation, or make the following modifications:")
        st.markdown(
            """
            - Use simple vocabulary - words and themes that would appear in a children's story
            - Avoid using idioms - for example, instead of "hit the gym", say "went to the gym"
            - Include plenty of descriptive adjectives
            - The model often struggles with names. **Using common names and sticking with first names only can help.**
            """
        )

    # Default sampling settings, used unless overridden under "Show Advanced Settings".
    user_seed = None  # no fixed seed unless the user supplies one
    generation_method = "top-k"
    specified_k = 5
    specified_nucleus = 0.5
    specified_temperature = 0.9
    max_tokens = 750

    if st.checkbox("Show Advanced Settings"):
        user_seed = st.number_input("Randomness Seed:", value = None, step = 1, placeholder="Use to replicate response", min_value = 1)
        generation_method = st.selectbox("Method of Generation:", ("top-k", "nucleus", "temperature", "multinomial", "greedy"), index = 0).strip()

        # Only the parameter relevant to the chosen method is shown.
        if generation_method == "top-k":
            # k < 1 would leave top-k sampling with no candidate tokens.
            specified_k = st.number_input("Value for k:", value = 5, step = 1, min_value = 1)
        elif generation_method == "nucleus":
            # Threshold passed to generate() as p_nucleus.
            specified_nucleus = st.number_input("Value for p:", value = 0.5, step = 0.05, min_value = 0.0, max_value = 1.0)
        elif generation_method == "temperature":
            # NOTE(review): min_value 0.0 allows temperature 0; confirm generate() handles it.
            specified_temperature = st.number_input("Value for temperature:", value = 0.9, step = 0.05, min_value = 0.0, max_value = 1.0)

        max_tokens = st.slider('Max Tokens Generated:', 50, 750, 750)

    model = load_big_model(tokenizer, device)
    # Move the model to the selected device instead of hard-coding CUDA, so the
    # demo also runs on hosts without a GPU.
    model.to(device)

    if st.button('Write my story!'):
        placeholder = st.empty()

        with st.spinner(""):
            result = generate(model, tokenizer, device, method=generation_method, k=specified_k, 
                            p_nucleus=specified_nucleus, temp=specified_temperature, max_new_tokens=max_tokens, 
                            cond=user_input, deterministic=user_seed)

        if user_input != "":  # conditional generation
            # Stream the prompt in bold, one word at a time.
            streamed_input = ""
            for word in user_input.split(' '):
                streamed_input += word
                with placeholder.container():
                    st.markdown(f"**{streamed_input}**")
                streamed_input += " "
                time.sleep(0.1)

            # Drop the echoed prompt from the generated text.
            # NOTE(review): assumes generate() returns the prompt followed by 3 extra
            # characters before the continuation — confirm against better_transformer.
            result = result[len(user_input) + 3 :]
            streamed_result = f"**{streamed_input[:-1]}**"  # [:-1] drops the trailing space
            time.sleep(1)
        else:  # unconditional generation: nothing to echo
            streamed_result = ""

        # Stream the generated story word by word after the (bold) prompt.
        for word in result.split(' '):
            streamed_result += word + ' '
            with placeholder.container():
                st.write(streamed_result)
            time.sleep(0.1)

        # NOTE(review): this button is nested inside the 'Write my story!' branch.
        # Clicking it reruns the script with the outer button unpressed, so this
        # branch is likely never reached.
        if st.button('Clear Output'):
            placeholder.empty()

    st.markdown('####')
    st.caption(r'Data Attribution: Tinystories (License: CDLA-Sharing-1.0)  https://arxiv.org/abs/2305.07759')


# Script entry point: build the page only when this file is executed directly
# (e.g. via `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()