king007 commited on
Commit
fb5f4e7
1 Parent(s): 4897c17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -1,23 +1,33 @@
1
  import torch
 
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
 
3
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
4
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
5
  model = GPT2LMHeadModel.from_pretrained('FredZhang7/anime-anything-promptgen-v2')
6
 
7
- prompt = r'1girl, genshin'
8
 
9
  # generate text using fine-tuned model
10
  nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
11
 
12
- # generate 10 samples using contrastive search
13
- outs = nlp(prompt, max_length=76, num_return_sequences=10, do_sample=True, repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
 
 
 
 
 
14
 
15
- print('\nInput:\n' + 100 * '-')
16
- print('\033[96m' + prompt + '\033[0m')
17
- print('\nOutput:\n' + 100 * '-')
18
- print(outs)
19
  # for i in range(len(outs)):
20
  # remove trailing commas and double spaces
21
  # outs[i] = str(outs[i]['generated_text']).replace(' ', '').rstrip(',')
22
  # print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
23
  # print(str(outs[i]['generated_text']))
 
 
 
 
 
 
 
 
1
  import torch
2
+ import gradio as gr
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
4
+
5
  tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
6
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})
7
  model = GPT2LMHeadModel.from_pretrained('FredZhang7/anime-anything-promptgen-v2')
8
 
9
+ # prompt = r'1girl, genshin'
10
 
11
  # generate text using fine-tuned model
12
  nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
13
 
14
+ def generate(prompt):
15
+ # generate 10 samples using contrastive search
16
+ outs = nlp(prompt, max_length=76, num_return_sequences=10, do_sample=True, repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
17
+ print(prompt)
18
+ print(outs)
19
+ return outs
20
+
21
 
 
 
 
 
22
  # for i in range(len(outs)):
23
  # remove trailing commas and double spaces
24
  # outs[i] = str(outs[i]['generated_text']).replace(' ', '').rstrip(',')
25
  # print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
26
  # print(str(outs[i]['generated_text']))
27
+
28
+ input_component = gr.Textbox(label = "Input a prompt", value = "1girl, genshin")
29
+ output_component = gr.Textbox(label = "detail Prompt")
30
+ examples = []
31
+ description = ""
32
+ gr.Interface(generate, inputs = input_component, outputs=output_component, examples=examples, title = "anything prompt", description=description).launch()
33
+