---
license: creativeml-openrail-m
tags:
  - stable-diffusion
  - prompt-generator
widget:
  - text: amazing
  - text: a photo of
  - text: a sci-fi
  - text: a portrait of
  - text: a person standing
  - text: a boy watching
datasets:
  - poloclub/diffusiondb
  - Gustavosta/Stable-Diffusion-Prompts
  - bartman081523/stable-diffusion-discord-prompts
  - FredZhang7/krea-ai-prompts
---

DistilGPT2 Stable Diffusion V2 Model Card

DistilGPT2 Stable Diffusion V2 is a text generation model that turns any text into a creative, coherent prompt for text-to-image models. It was fine-tuned from the FredZhang7/distilgpt2-stable-diffusion checkpoint on 2.47 million descriptive Stable Diffusion prompts for a further 4.27 million steps.

Compared to other GPT2-based prompt generation models, this one runs its forward pass about 50% faster and uses 40% less disk space and RAM.
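
If you want to sanity-check the speed claim on your own hardware, a minimal micro-benchmark like the following works (this snippet is not from the original card; it compares a single forward pass of distilgpt2 against gpt2, which shares the same tokenizer):

import time
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
ids = tokenizer('a cat sitting on a windowsill at sunset', return_tensors='pt').input_ids

for name in ('distilgpt2', 'gpt2'):
    model = GPT2LMHeadModel.from_pretrained(name).eval()
    with torch.no_grad():
        model(ids)                                  # warm-up pass
        start = time.perf_counter()
        for _ in range(20):
            model(ids)                              # timed forward passes
    ms = (time.perf_counter() - start) / 20 * 1000
    print(f'{name}: {ms:.1f} ms per forward pass')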

Major improvements over v1:

  • 25% more prompt variations
  • more capable of generating story-like prompts
  • cleaned training data (a sketch of these filters follows this list)
    • removed prompts that generate images with NSFW scores > 0.5
    • removed duplicates, including prompts that differ only in capitalization or punctuation
    • removed stray punctuation at random places
    • removed prompts shorter than 15 characters
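
The exact preprocessing code is not published; the sketch below is a hypothetical re-implementation of the listed rules, assuming each prompt comes paired with an NSFW score from an external image classifier:

import string

def clean_prompts(prompts, nsfw_scores):
    # hypothetical re-implementation of the cleaning rules listed above
    seen, kept = set(), []
    for prompt, nsfw in zip(prompts, nsfw_scores):
        if nsfw > 0.5:            # drop prompts whose generated images scored NSFW
            continue
        if len(prompt) < 15:      # drop prompts shorter than 15 characters
            continue
        # deduplicate, ignoring capitalization and punctuation differences
        key = prompt.lower().translate(str.maketrans('', '', string.punctuation))
        if key in seen:
            continue
        seen.add(key)
        kept.append(prompt)
    return kept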

PyTorch

pip install --upgrade transformers

Faster but less fluent generation:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)

prompt = r'a cat sitting'

# generate text using the fine-tuned model
from transformers import pipeline
nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)

# generate 5 samples
outs = nlp(prompt, max_length=80, num_return_sequences=5)

print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(outs)):
    outs[i] = str(outs[i]['generated_text']).replace('  ', ' ')  # collapse double spaces
print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')

Example output: greedy search
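
The pipeline call above uses the model's default decoding settings. To request the greedy search shown in the example output explicitly, pass do_sample=False (note that greedy decoding is deterministic and yields a single sequence):

# explicit greedy decoding: deterministic, one output sequence
outs = nlp(prompt, max_length=80, do_sample=False, num_return_sequences=1)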


Slower but more fluent generation:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
model.eval()

prompt = r'a cat sitting'     # the beginning of the prompt
temperature = 0.9             # a higher temperature will produce more diverse results, but with a higher risk of less coherent text.
top_k = 8                     # the number of tokens to sample from at each step
max_length = 80               # the maximum number of tokens for the output of the model
repetition_penalty = 1.2      # the penalty value for each repetition of a token
num_beams = 10                # the number of beams for beam search
num_return_sequences = 5      # the number of results with the highest probabilities out of num_beams

# generate the results with sampling and beam search (see the note below on contrastive search)
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, num_beams=num_beams, repetition_penalty=repetition_penalty, penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)

# print results
print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(output)):
    print('\033[92m' + tokenizer.decode(output[i], skip_special_tokens=True) + '\033[0m\n')

Example output: contrastive search
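
A note on decoding modes: in transformers v4.24 and later, contrastive search proper is activated by combining penalty_alpha with top_k while leaving sampling and beam search off, so in the beam-sample call above penalty_alpha is effectively ignored. A minimal sketch of pure contrastive search, reusing the variables defined earlier:

# pure contrastive search: deterministic, no sampling or beam search
output = model.generate(input_ids, penalty_alpha=0.6, top_k=top_k, max_length=max_length)
print(tokenizer.decode(output[0], skip_special_tokens=True))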