johnsu6616 committed
Commit 25faaaf
0 Parent(s):

Duplicate from johnsu6616/SD_Helper

Files changed (4):
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +143 -0
  4. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: SD_Helper
+ emoji: 📊
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ license: openrail
+ duplicated_from: johnsu6616/SD_Helper
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,143 @@
+ import random
+ import re
+
+ import gradio as gr
+ import torch
+
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM
+ from transformers import AutoProcessor
+ from transformers import pipeline
+ from transformers import set_seed
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # GIT image-captioning model for the image-to-text tab
+ big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
+ big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
+
+ # prompt-expansion pipeline for the text-to-text tab
+ text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+
+ # Chinese <-> English translation models
+ zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
+ zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
+
+ en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
+ en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
+
+ def load_prompter():
+     # Promptist rewrites plain text into Stable Diffusion-style prompts
+     prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"
+     return prompter_model, tokenizer
+
+ prompter_model, prompter_tokenizer = load_prompter()
+
+ def generate_prompter(plain_text, max_new_tokens=75, num_return_sequences=3):
+     input_ids = prompter_tokenizer(plain_text.strip() + " Rephrase:", return_tensors="pt").input_ids
+     eos_id = prompter_tokenizer.eos_token_id
+     outputs = prompter_model.generate(
+         input_ids,
+         do_sample=False,
+         max_new_tokens=max_new_tokens,
+         num_beams=6,
+         num_return_sequences=num_return_sequences,
+         eos_token_id=eos_id,
+         pad_token_id=eos_id,
+         length_penalty=-1
+     )
+     output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     results = []
+     for output_text in output_texts:
+         results.append(output_text.replace(plain_text + " Rephrase:", "").strip())
+     return "\n".join(results)
+
+ def translate_zh2en(text):
+     with torch.no_grad():
+         text = text.replace('\n', ',').replace('\r', ',')
+         text = re.sub(r'^,+', ',', text)
+         encoded = zh2en_tokenizer([text], return_tensors='pt')
+         sequences = zh2en_model.generate(**encoded)
+         return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+ def translate_en2zh(text):
+     with torch.no_grad():
+         encoded = en2zh_tokenizer([text], return_tensors="pt")
+         sequences = en2zh_model.generate(**encoded)
+         return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+ def text_generate(text):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+
+     text_in_english = translate_zh2en(text)
+     result = ""
+     for _ in range(6):  # retry up to six times until a usable batch survives filtering
+         sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
+         lines = []
+         for sequence in sequences:
+             line = sequence['generated_text'].strip()
+             # keep candidates that add material and do not end mid-phrase
+             if line != text_in_english and len(line) > (len(text_in_english) + 4) \
+                     and not line.endswith((':', '-', '—')):
+                 lines.append(line)
+
+         result = "\n".join(lines)
+         result = re.sub(r'[^ ]+\.[^ ]+', '', result)  # drop URL/filename-like tokens
+         result = result.replace('<', '').replace('>', '').replace('"', '')
+         if result != '':
+             break
+
+     return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
+
+ def get_prompt_from_image(input_image):
+     image = input_image.convert('RGB')
+     pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
+     generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
+     generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     print(generated_caption)
+     return generated_caption
+
+ with gr.Blocks() as block:
+     with gr.Column():
+         with gr.Tab('文生文'):  # "text to text"
+             with gr.Row():
+                 input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此輸入文字...')  # "input text" / "type text here..."
+             with gr.Row():
+                 txt_prompter_btn = gr.Button('執行')  # "run"
+
+         with gr.Tab('圖生文'):  # "image to text"
+             with gr.Row():
+                 input_image = gr.Image(type='pil')
+             with gr.Row():
+                 pic_prompter_btn = gr.Button('執行')  # "run"
+
+         Textbox_1 = gr.Textbox(lines=6, label='輸出結果')  # "output"
+         Textbox_2 = gr.Textbox(lines=6, label='中文翻譯')  # "Chinese translation"
+
+         txt_prompter_btn.click(
+             fn=text_generate,
+             inputs=input_text,
+             outputs=[Textbox_1, Textbox_2]
+         )
+
+         pic_prompter_btn.click(
+             fn=get_prompt_from_image,
+             inputs=input_image,
+             outputs=Textbox_1
+         )
+
+ block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')
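
For reference, the two text paths in app.py can be exercised without the Gradio UI. A minimal sketch, assuming the function definitions above are available in the current session (importing app.py as-is would also start the UI; the sample string below is illustrative, not from the Space):

    # smoke test for app.py's text paths; the Chinese sample string is illustrative
    english = translate_zh2en("一隻在雪地裡的紅色狐狸")           # zh -> en, roughly "a red fox in the snow"
    prompts = generate_prompter(english, num_return_sequences=3)  # Promptist rephrasings, one per line
    print(prompts)
    print(translate_en2zh(prompts.split("\n")[0]))                # back-translate the first prompt
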
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers==4.27.4
+ torch==2.0.0
+ gradio==3.24.1
+ sentencepiece==0.1.97
+ sacremoses==0.0.53
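
To run the Space locally, install these pins and launch app.py. A quick sanity check of the environment, assuming a recent Python 3 interpreter and the exact versions pinned above:

    # verify the pinned versions resolved correctly before launching app.py
    import gradio, torch, transformers
    assert transformers.__version__ == "4.27.4"
    assert gradio.__version__ == "3.24.1"
    assert torch.__version__.startswith("2.0.0")  # may carry a +cu/+cpu build suffix
    print("env ok; CUDA available:", torch.cuda.is_available())
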