FredZhang7 committed ab114ee (1 parent: 0aabb74)

Update README.md

Files changed (1): README.md (+53, -1)

README.md CHANGED
@@ -1,3 +1,55 @@
  ---
- license: mit
+ license: creativeml-openrail-m
+ tags:
+ - stable-diffusion
+ - prompt-generator
+ - distilgpt2
  ---
+ # Distilgpt2 Stable Diffusion Model Card
+ Distilgpt2 Stable Diffusion is a text-to-text model used to generate creative and coherent prompts given any text.
+ This model was finetuned on 2.03M Stable Diffusion prompts from the [Stable Diffusion Discord](https://huggingface.co/datasets/bartman081523/stable-diffusion-discord-prompts), [Lexica.art](https://huggingface.co/datasets/Gustavosta/Stable-Diffusion-Prompts), and my hand-picked [Krea.ai](./krea.ai.txt) prompts. I filtered the hand-picked prompts based on the output results from Stable Diffusion v1.4.
+ 
+ ### PyTorch
+ 
+ ```bash
+ pip install --upgrade transformers
+ ```
+ 
+ ```python
+ # download the DistilGPT2 Stable Diffusion checkpoint if it isn't already present
+ import os
+ if not os.path.exists('./distil-sd-gpt2.pt'):
+     import urllib.request
+     print('Downloading model...')
+     urllib.request.urlretrieve('https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion/resolve/main/distil-sd-gpt2.pt', './distil-sd-gpt2.pt')
+     print('Model downloaded.')
+ 
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ 
+ # load the pretrained tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ tokenizer.model_max_length = 512
+ 
+ # load the fine-tuned weights into the distilgpt2 architecture
+ import torch
+ model = GPT2LMHeadModel.from_pretrained('distilgpt2')
+ model.load_state_dict(torch.load('./distil-sd-gpt2.pt'))
+ 
+ # generate text with the fine-tuned model
+ from transformers import pipeline
+ nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
+ ins = "a beautiful city"
+ 
+ # sample 10 prompt completions
+ outs = nlp(ins, max_length=80, num_return_sequences=10, do_sample=True)
+ 
+ # clean up and print the 10 samples
+ for i in range(len(outs)):
+     outs[i] = str(outs[i]['generated_text']).replace('  ', ' ')
+ print('\033[96m' + ins + '\033[0m')
+ print('\033[93m' + '\n\n'.join(outs) + '\033[0m')
+ ```
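The snippet above leans on the pipeline's default decoding behaviour. For noticeably more varied prompts, explicit sampling parameters can be passed through the same pipeline call; the sketch below reuses `nlp` and `ins` from the snippet, and the `temperature`, `top_k`, and `repetition_penalty` values are illustrative choices rather than settings documented for this model.

```python
# a rough sketch, reusing `nlp` and `ins` from the snippet above;
# the decoding values below are illustrative, not from the model card
samples = nlp(
    ins,
    max_length=80,
    num_return_sequences=10,
    do_sample=True,          # sample rather than decode greedily
    temperature=0.9,         # <1.0 is more conservative, >1.0 more random
    top_k=50,                # restrict sampling to the 50 most likely tokens
    repetition_penalty=1.2,  # discourage repeated phrases
)
for sample in samples:
    print(sample['generated_text'].strip())
```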
+ 
+ Example Output:
+ ![Example Output](https://media.discordapp.net/attachments/884528247998664744/1049544706163482704/image.png?width=1440&height=479)
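Because the generator exists to feed Stable Diffusion, a produced prompt can be handed straight to an image pipeline. The sketch below is only an illustration: it assumes the `diffusers` library is installed and that the `CompVis/stable-diffusion-v1-4` checkpoint is accessible (neither is part of this model card), and it reuses one of the prompt strings produced by the generation snippet above.

```python
# a minimal sketch, assuming `pip install diffusers` and access to the
# CompVis/stable-diffusion-v1-4 checkpoint (not part of this model card)
import torch
from diffusers import StableDiffusionPipeline

prompt = outs[0]  # one of the prompt strings from the generation snippet above

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")  # assumes a CUDA-capable GPU

image = pipe(prompt).images[0]
image.save("generated.png")
```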