---
license: creativeml-openrail-m
tags:
- stable-diffusion
- prompt-generator
- arxiv:2210.14140
widget:
- text: "amazing"
- text: "a photo of"
- text: "a sci-fi"
- text: "a portrait of"
- text: "a person standing"
- text: "a boy watching"
datasets:
- FredZhang7/stable-diffusion-prompts-2.47M
- poloclub/diffusiondb
- Gustavosta/Stable-Diffusion-Prompts
- bartman081523/stable-diffusion-discord-prompts
---
# Fast GPT2 PromptGen

<style>
.container {
  padding-left: 20px;
  border-left: 5px solid gray;
}
</style>

<div class="container">
  <p><strong><a href="https://huggingface.co/FredZhang7/anime-anything-promptgen-v2">Fast Anime PromptGen</a></strong> generates descriptive safebooru and danbooru tags for anime text-to-image models.</p>
</div>


This model continues training from the [FredZhang7/distilgpt2-stable-diffusion](https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion) checkpoint for another 4,270,000 steps on 2,470,000 descriptive Stable Diffusion prompts.

Compared with other GPT-2-based prompt generation models, this one runs forward passes 50% faster and uses 40% less disk space and RAM.

Major improvements over v1:
- 25% more variations
- faster and more fluent prompt generation
- cleaned training data
  * removed prompts that generate images with NSFW scores > 0.5
  * removed duplicates, including prompts that differ only in capitalization and punctuation
  * removed stray punctuation at random positions
  * removed prompts shorter than 15 characters
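The cleaning steps above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the author's actual preprocessing script: the NSFW scores are assumed to come from an external classifier run on images generated from each prompt (not shown), `normalize` and `clean_prompts` are hypothetical helpers, duplicates are detected by lowercasing and stripping ASCII punctuation, and the random-punctuation cleanup is not covered.

```python
import string

# Hypothetical thresholds mirroring the list above (not the author's script).
NSFW_THRESHOLD = 0.5
MIN_LENGTH = 15

def normalize(prompt: str) -> str:
    """Duplicate-detection key: lowercase, ASCII punctuation removed, whitespace collapsed."""
    no_punct = prompt.lower().translate(str.maketrans('', '', string.punctuation))
    return ' '.join(no_punct.split())

def clean_prompts(prompts, nsfw_scores):
    """Keep prompts that pass the NSFW, length, and duplicate filters.

    nsfw_scores is assumed to be a list parallel to prompts, produced by some
    external NSFW classifier (hypothetical input, not part of this model card).
    """
    seen = set()
    kept = []
    for prompt, score in zip(prompts, nsfw_scores):
        prompt = prompt.strip()
        if score > NSFW_THRESHOLD:      # drop prompts whose images scored as NSFW
            continue
        if len(prompt) < MIN_LENGTH:    # drop very short prompts
            continue
        key = normalize(prompt)
        if key in seen:                 # drop case/punctuation-only duplicates
            continue
        seen.add(key)
        kept.append(prompt)
    return kept

prompts = [
    'A castle on a hill, sunset',
    'a castle on a hill sunset',          # duplicate after normalization
    'tiny cat',                           # shorter than 15 characters
    'portrait of a knight, oil painting', # paired with a high NSFW score below
]
scores = [0.1, 0.2, 0.0, 0.9]
print(clean_prompts(prompts, scores))
```

The first occurrence of each normalized prompt is kept, so the original capitalization and punctuation of that copy survive in the cleaned data.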


## Live WebUI Demo
See the Prompt Generator tab of the [Paint Journey Demo](https://huggingface.co/spaces/FredZhang7/paint-journey-demo).


## Contrastive Search

```bash
pip install --upgrade transformers
```

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')

prompt = r'a cat sitting'     # the beginning of the prompt
temperature = 0.9             # higher values produce more diverse but potentially less coherent text
top_k = 8                     # the number of most likely tokens considered at each step
max_length = 80               # the maximum number of tokens per output, including the prompt
repetition_penalty = 1.2      # penalty applied to tokens that have already appeared
num_return_sequences = 5      # the number of results to generate

# Generate the results.
# Note: transformers only switches to contrastive search when do_sample=False
# (with top_k > 1 and penalty_alpha > 0); with do_sample=True, as here, it
# performs top-k sampling and penalty_alpha has no effect.
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(
    input_ids,
    do_sample=True,
    temperature=temperature,
    top_k=top_k,
    max_length=max_length,
    num_return_sequences=num_return_sequences,
    repetition_penalty=repetition_penalty,
    penalty_alpha=0.6,
    no_repeat_ngram_size=1,   # each token, including the comma, may appear at most once
    early_stopping=True,
)

print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')   # prompt in cyan
print('\nOutput:\n' + 100 * '-')
for sequence in output:
    print('\033[92m' + tokenizer.decode(sequence, skip_special_tokens=True) + '\033[0m\n')   # results in green
```

Example output (no-comma style):
![Contrastive search output without commas](./constrastive_search.png)


To bring back the commas, call `model.generate()` without `penalty_alpha` and `no_repeat_ngram_size`:
```python
output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, repetition_penalty=repetition_penalty, early_stopping=True)
```

![Contrastive search output with commas](./contrastive_comma_style.png)
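For reference, contrastive search (the decoding method that the `penalty_alpha` argument controls) scores each of the `top_k` candidate tokens as (1 - alpha) * p(token | context) - alpha * (maximum cosine similarity between the candidate's hidden state and the hidden states of the earlier tokens), where alpha is `penalty_alpha`, and picks the highest-scoring candidate. The toy sketch below implements one such step in plain Python, using made-up probabilities and 2-D vectors as stand-ins for hidden states; `contrastive_pick` is a hypothetical helper for illustration, not the transformers implementation, which works on the model's own hidden states in batches.

```python
import math

def cosine(a, b):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def contrastive_pick(candidates, context_states, alpha):
    """Return the candidate token with the best contrastive-search score.

    candidates: list of (token, probability, hidden_state) for the top-k tokens
    context_states: hidden states of the tokens already in the sequence
    alpha: degeneration-penalty weight (penalty_alpha in transformers)
    """
    best_token, best_score = None, -math.inf
    for token, prob, state in candidates:
        # Degeneration penalty: similarity to the most similar previous token.
        penalty = max(cosine(state, h) for h in context_states)
        score = (1 - alpha) * prob - alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token

# Toy data (hypothetical): one previous token whose state matches 'cat' exactly.
context = [[1.0, 0.0]]
candidates = [
    ('cat', 0.5, [1.0, 0.0]),     # most probable, but repeats the context
    ('castle', 0.3, [0.0, 1.0]),  # less probable, unlike the context
]
print(contrastive_pick(candidates, context, alpha=0.0))
print(contrastive_pick(candidates, context, alpha=0.6))
```

With `alpha=0.0` the rule reduces to picking the most probable candidate (`cat`); with `alpha=0.6`, `cat` scores 0.4 * 0.5 - 0.6 * 1.0 = -0.4 while `castle` scores 0.4 * 0.3 = 0.12, so the less repetitive `castle` wins.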