Text Generation
Transformers
PyTorch
Safetensors
gpt2
stable-diffusion
prompt-generator
arxiv:2210.14140
Inference Endpoints
text-generation-inference
FredZhang7 commited on
Commit
18ab019
1 Parent(s): 83eaed0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -6
README.md CHANGED
@@ -73,7 +73,7 @@ Slower but more fluent generation:
73
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
74
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
75
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
76
- model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
77
  model.eval()
78
 
79
  prompt = r'a cat sitting' # the beginning of the prompt
@@ -81,14 +81,12 @@ temperature = 0.9 # a higher temperature will produce more diverse r
81
  top_k = 8 # the number of tokens to sample from at each step
82
  max_length = 80 # the maximum number of tokens for the output of the model
83
  repitition_penalty = 1.2 # the penalty value for each repetition of a token
84
- num_beams=10
85
- num_return_sequences=5 # the number of results with the highest probabilities out of num_beams
86
 
87
- # generate the result with contrastive search
88
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids
89
- output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, num_beams=num_beams, repetition_penalty=repitition_penalty, penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
90
 
91
- # print results
92
  print('\nInput:\n' + 100 * '-')
93
  print('\033[96m' + prompt + '\033[0m')
94
  print('\nOutput:\n' + 100 * '-')
 
73
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
74
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
75
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
76
+ model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
77
  model.eval()
78
 
79
  prompt = r'a cat sitting' # the beginning of the prompt
 
81
  top_k = 8 # the number of tokens to sample from at each step
82
  max_length = 80 # the maximum number of tokens for the output of the model
83
  repitition_penalty = 1.2 # the penalty value for each repetition of a token
84
+ num_return_sequences=5 # the number of results to generate
 
85
 
86
+ # generate the result with contrastive search. generate 5 results with the highest probability out of 10.
87
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids
88
+ output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, repetition_penalty=repitition_penalty, penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
89
 
 
90
  print('\nInput:\n' + 100 * '-')
91
  print('\033[96m' + prompt + '\033[0m')
92
  print('\nOutput:\n' + 100 * '-')