import gradio as gr
import transformers
import os
import re
import json
import random

device = "cpu"

model = None
tokenizer = None

def init_model():
    global model, tokenizer

    # MODEL_ID and HUB_TOKEN can be set in the Space's environment; HUB_TOKEN is only
    # needed for private checkpoints (True falls back to the locally cached Hub token).
    model_id = os.environ.get("MODEL_ID") or "treadon/prompt-fungineer-355M"
    auth_token = os.environ.get("HUB_TOKEN") or True

    print(f"Using model {model_id}.")

    if auth_token is not True:
        print("Using auth token.")

    model = transformers.AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=auth_token)
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
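
# Illustrative only (placeholder values, not from this repo): the defaults above can
# be overridden through the environment, e.g.
#   export MODEL_ID=your-org/your-model
#   export HUB_TOKEN=hf_...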


def format_prompt(prompt, enhancers=True, inspiration=False, negative_prompt=False):
    """Parse the model's tagged output (BRF:/POS:/ENH:/INS:/NEG:) into a single prompt string."""
    try:
        pattern = r"(BRF:|POS:|ENH:|INS:|NEG:) (.*?)(?= (BRF:|POS:|ENH:|INS:|NEG:)|$)"
        matches = re.findall(pattern, prompt)
        # Each match is (tag, text, lookahead-tag); the third group is unused.
        vals = {key: value.strip() for key, value, _ in matches}
        result = vals["POS:"]
        if enhancers:
            result += " " + vals["ENH:"]
        if inspiration:
            result += " " + vals["INS:"]
        if negative_prompt:
            result += "\n\n--no " + vals["NEG:"]

        return result
    except Exception:
        # A malformed generation (e.g. a missing tag) lands here.
        return "Failed to generate prompt."

    
def generate_text(prompt, extra=False, top_k=100, top_p=0.95, temperature=0.85, enhancers=True, inspiration=False, negative_prompt=False):
    global model, tokenizer
    
    try:
        if model is None:
            init_model()
    except Exception as e:
        print(e)
        return ["Try Again"] * 4

    if model is None:
        return ["Try Again"] * 4
    
    prompt = prompt.strip()

    # The model expects a "BRF: <brief>" prefix and emits POS:/ENH:/INS:/NEG: sections.
    if not prompt.startswith("BRF:"):
        prompt = "BRF: " + prompt

    # Unless "Wild Imagination" (extra) is on, append " POS:" so generation starts
    # directly at the expanded positive prompt.
    if not extra:
        prompt = prompt + " POS:"

    model.eval()
    # Sample four candidate completions in a single generate() call.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    samples = []
    try:
        print(f"Generating samples for prompt: {prompt}")
        outputs = model.generate(**inputs, max_length=256, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, num_return_sequences=4, pad_token_id=tokenizer.eos_token_id)
        print(f"Generated {len(outputs)} samples.")
        for output in outputs:
            sample = tokenizer.decode(output, skip_special_tokens=True)
            sample = format_prompt(sample, enhancers, inspiration, negative_prompt)
            print(f"Sample: {sample}")
            samples.append(sample)
    except Exception as e:
        print(e)

    # Gradio expects exactly four outputs, so pad if generation failed partway.
    while len(samples) < 4:
        samples.append("Try Again")

    return samples
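
# Illustrative call (defaults as defined above): generate_text("An astronaut in space.")
# returns a list of four fungineered prompt strings, or ["Try Again"] * 4 on failure.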

if __name__ == "__main__":
    with gr.Blocks() as fungineer:
        with gr.Row():
            gr.Markdown("""# Midjourney / Dalle 2 / Stable Diffusion Prompt Generator
    This is the 355M parameter model.  There is also a 7B parameter model that is much better but far slower (access coming soon).
    Just enter a basic prompt and the fungineering model will use its wildest imagination to expand the prompt in detail.  You can then use this prompt to generate images with Midjourney, Dalle 2, Stable Diffusion, Bing Image Creator, or any other image generation model.  Read more about this project [on my blog post](https://riteshkhanna.com/2023/04/12/image-prompt-generator/).
    ## TIP: Keep the base prompt short and simple.  The model will do the rest.
    """)
        with gr.Row():
            with gr.Column():

                base_prompt = gr.Textbox(lines=1, label="Base Prompt (Shorter is Better)", placeholder="An astronaut in space.", info="Enter a very simple prompt that will be fungineered into something exciting!")
                submit = gr.Button("Fungineer", variant="primary")

                extra = gr.Checkbox(value=False, label="Wild Imagination", info="If checked, the model will be allowed to go wild with its imagination.")

                with gr.Accordion("Advanced Generation Settings", open=False):
                    top_k = gr.Slider(minimum=10, maximum=1000, value=100, label="Top K", info="Top K sampling")
                    top_p = gr.Slider(minimum=0.1, maximum=1, value=0.95, step=0.01, label="Top P", info="Top P sampling")
                    temperature = gr.Slider(minimum=0.1, maximum=1.2, value=0.85, step=0.01, label="Temperature", info="Temperature sampling.  Higher values will make the model more creative")

                with gr.Accordion("Advanced Output Settings", open=False):
                    enh = gr.Checkbox(value=True, label="Enhancers", info="Add image meta information such as lens type, shutter speed, camera model, etc.")
                    insp = gr.Checkbox(value=False, label="Inspiration", info="Include inspirational photographers known for this type of photography.  Sometimes random names appear here; the model needs more training.")
                    neg = gr.Checkbox(value=False, label="Negative Prompt", info="Include a negative prompt, more often used in Stable Diffusion.  If you're a Stable Diffusion user, chances are you already have a better negative prompt you like to use.")

            with gr.Column():
                outputs = [
                    gr.Textbox(lines=2, label="Fungineered Text 1"),
                    gr.Textbox(lines=2, label="Fungineered Text 2"),
                    gr.Textbox(lines=2, label="Fungineered Text 3"),
                    gr.Textbox(lines=2, label="Fungineered Text 4"),
                ]

                gr.Markdown("### Got something good? [Share it](https://huggingface.co/spaces/treadon/prompt-fungineer-355M/discussions/1) with the community in the showcase!")

        for textbox in outputs:
            textbox.style(show_copy_button=True)

        inputs = [base_prompt, extra, top_k, top_p, temperature, enh, insp, neg]

        submit.click(generate_text, inputs=inputs, outputs=outputs)
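        # Note: the `inputs` list above is positional; it must line up with
        # generate_text's parameter order.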

        # Load the showcase examples bundled with the Space.
        with open("examples.json") as f:
            examples = json.load(f)

        for i, example in enumerate(examples):
            with gr.Tab(f"Example {i+1}", id=i):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown(f"### Base Prompt")
                        gr.HTML(f"<img src='https://huggingface.co/spaces/treadon/prompt-fungineer-355M/resolve/main/{example['base']['src']}' style='width: 100%; border-radius: 15px; border: 1px solid #444' />")
                        gr.Markdown(f"{example['base']['prompt']}")
                    with gr.Column():
                        gr.Markdown(f"### 355M Prompt Fungineered")
                        gr.HTML(f"<img src='https://huggingface.co/spaces/treadon/prompt-fungineer-355M/resolve/main/{example['355M']['src']}' style='width: 100%; border-radius: 15px; border: 1px solid #444' />")
                        gr.Markdown(f"{example['355M']['prompt']}")
                    with gr.Column():
                        gr.Markdown(f"### 7B Prompt Fungineered")
                        gr.HTML(f"<img src='https://huggingface.co/spaces/treadon/prompt-fungineer-355M/resolve/main/{example['7B']['src']}' style='width: 100%; border-radius: 15px; border: 1px solid #444' />")
                        gr.Markdown(f"{example['7B']['prompt']}")



    init_model()
    fungineer.launch(enable_queue=True, show_api=False, debug=True)