Text Generation
Transformers
PyTorch
Safetensors
gpt2
stable-diffusion
prompt-generator
arxiv:2210.14140
Inference Endpoints
text-generation-inference
FredZhang7 commited on
Commit
c776d18
1 Parent(s): 8f8b4b7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -6
README.md CHANGED
@@ -1,5 +1,8 @@
1
  ---
2
  license: creativeml-openrail-m
 
 
 
3
  widget:
4
  - text: "amazing"
5
  - text: "a photo of"
@@ -7,14 +10,92 @@ widget:
7
  - text: "a portrait of"
8
  - text: "a person standing"
9
  - text: "a boy watching"
 
 
 
 
 
10
  ---
 
 
 
11
 
12
- **Under Development**
13
 
14
- Major improvements from v1 will include:
15
  - 25% more variations
16
- - more capable of generating story-like prompts, using contrastive search instead of greedy search
17
  - cleaned training data
18
- * removed prompts that generate images with nsfw scores >= 0.5
19
- * removed duplicates (including prompts that differ by capitalization and punctuations)
20
- * removed punctuations in random places
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: creativeml-openrail-m
3
+ tags:
4
+ - stable-diffusion
5
+ - prompt-generator
6
  widget:
7
  - text: "amazing"
8
  - text: "a photo of"
 
10
  - text: "a portrait of"
11
  - text: "a person standing"
12
  - text: "a boy watching"
13
+ datasets:
14
+ - poloclub/diffusiondb
15
+ - Gustavosta/Stable-Diffusion-Prompts
16
+ - bartman081523/stable-diffusion-discord-prompts
17
+ - FredZhang7/krea-ai-prompts
18
  ---
19
+ # DistilGPT2 Stable Diffusion V2 Model Card
20
+ DistilGPT2 Stable Diffusion V2 is a text generation model used to generate creative and coherent prompts for text-to-image models, given any text.
21
+ This model was trained on 2.47 million descriptive stable diffusion prompts on the [FredZhang7/distilgpt2-stable-diffusion](https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion) checkpoint for 4.27 million steps.
22
 
23
+ Compared to other prompt generation models using GPT2, this one runs with 50% faster forwardpropagation and 40% less disk space & RAM.
24
 
25
+ Major improvements from v1 are:
26
  - 25% more variations
27
+ - more capable of generating story-like prompts
28
  - cleaned training data
29
+ * removed prompts that generate images with nsfw scores > 0.5
30
+ * removed duplicates, including prompts that differ by capitalization and punctuations
31
+ * removed punctuations at random places
32
+ * removed prompts shorter than 15 characters
33
+
34
+
35
+ ### PyTorch
36
+
37
+ ```bash
38
+ pip install --upgrade transformers
39
+ ```
40
+
41
+ Faster but less fluent generation:
42
+
43
+ ```python
44
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
45
+ tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
46
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
47
+ model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
48
+
49
+ prompt = r'a cat sitting'
50
+
51
+ # generate text using fine-tuned model
52
+ from transformers import pipeline
53
+ nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
54
+
55
+ # generate 5 samples
56
+ outs = nlp(prompt, max_length=80, num_return_sequences=5)
57
+
58
+ print('\nInput:\n' + 100 * '-')
59
+ print('\033[96m' + prompt + '\033[0m')
60
+ print('\nOutput:\n' + 100 * '-')
61
+ for i in range(len(outs)):
62
+ outs[i] = str(outs[i]['generated_text']).replace(' ', '')
63
+ print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
64
+ ```
65
+
66
+ Example output:
67
+ ![greedy search](./greedy_search.png)
68
+
69
+ <br>
70
+
71
+ Slower but more fluent generation:
72
+
73
+ ```python
74
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
75
+ tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
76
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
77
+ model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
78
+ model.eval()
79
+
80
+ prompt = r'a cat sitting' # the beginning of the prompt
81
+ temperature = 0.9 # a higher temperature will produce more diverse results, but with a higher risk of less coherent text.
82
+ top_k = 8 # the number of tokens to sample from at each step
83
+ max_length = 80 # the maximum number of tokens for the output of the model
84
+ repitition_penalty = 1.2 # the penalty value for each repetition of a token
85
+ num_beams=10
86
+ num_return_sequences=5 # the number of results with the highest probabilities out of num_beams
87
+
88
+ # generate the result with contrastive search.
89
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids
90
+ output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, num_beams=num_beams, repetition_penalty=repitition_penalty, penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
91
+
92
+ # print results
93
+ print('\nInput:\n' + 100 * '-')
94
+ print('\033[96m' + prompt + '\033[0m')
95
+ print('\nOutput:\n' + 100 * '-')
96
+ for i in range(len(output)):
97
+ print('\033[92m' + tokenizer.decode(output[i], skip_special_tokens=True) + '\033[0m\n')
98
+ ```
99
+
100
+ Example output:
101
+ ![constrastive search](./constrastive_search.png)