DistilGPT2 Stable Diffusion Model Card

Version 2 is here!

DistilGPT2 Stable Diffusion is a text-generation model that produces creative and coherent prompts for text-to-image models, given any input text. It was fine-tuned on 2.03 million descriptive Stable Diffusion prompts collected from the Stable Diffusion Discord, Lexica.art, and my hand-picked selection from Krea.ai. I filtered the hand-picked prompts based on the images they produced with Stable Diffusion v1.4.

Compared to other GPT2-based prompt generation models, this one runs forward propagation about 50% faster and uses roughly 40% less disk space and RAM.
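
These ratios are not benchmarked in this card. A minimal way to sanity-check the size part of the claim yourself, assuming only the standard transformers API, is to compare the parameter counts of the two base checkpoints (a sketch, not a full benchmark):

# Sketch: compare parameter counts of distilgpt2 (the base of this model) and gpt2.
# Exact speed and disk/RAM ratios depend on hardware and serialization format.
from transformers import GPT2LMHeadModel

def param_count(model):
    return sum(p.numel() for p in model.parameters())

distil = GPT2LMHeadModel.from_pretrained('distilgpt2')
base = GPT2LMHeadModel.from_pretrained('gpt2')
print(f"distilgpt2: {param_count(distil) / 1e6:.1f}M parameters")
print(f"gpt2:       {param_count(base) / 1e6:.1f}M parameters")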

PyTorch

pip install --upgrade transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# load the pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.model_max_length = 512   # `max_len` is deprecated in recent transformers versions

# load the fine-tuned model
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion')

# generate text using fine-tuned model
from transformers import pipeline
nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
ins = "a beautiful city"

# generate 10 samples
outs = nlp(ins, max_length=80, num_return_sequences=10, do_sample=True)   # sampling is needed to return multiple distinct sequences

# print the 10 samples
for i in range(len(outs)):
    # keep only the generated text and strip double spaces
    outs[i] = str(outs[i]['generated_text']).replace('  ', '')
print('\033[96m' + ins + '\033[0m')                 # the input text, in cyan
print('\033[93m' + '\n\n'.join(outs) + '\033[0m')   # the generated prompts, in yellow
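
If the default sampling is too repetitive or too wild for your use case, the usual generation arguments can be passed straight through the pipeline call. The values below are a sketch to illustrate the diversity/coherence trade-off; they are my assumptions, not the settings used to produce the example output.

# Sketch only: illustrative sampling settings, not the model card's defaults.
outs = nlp(
    ins,
    max_length=80,
    num_return_sequences=5,
    do_sample=True,
    temperature=0.9,         # higher values -> more varied prompts
    top_k=8,                 # sample only from the 8 most likely next tokens
    repetition_penalty=1.2,  # discourage repeated phrases within a prompt
)
for out in outs:
    print(out['generated_text'])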

Example Output: [sample image]
