Spaces:
Running
Running
File size: 6,310 Bytes
da67b7d 3fb1820 da67b7d 3fb1820 7ce4b70 3fb1820 7ce4b70 3fb1820 7ce4b70 3fb1820 9d6d4d4 3fb1820 7ce4b70 3fb1820 1eff0a6 3fb1820 1eff0a6 3fb1820 1eff0a6 3fb1820 1eff0a6 3fb1820 c5dc3e4 3fb1820 00f9a93 d8e2c18 3fb1820 1eff0a6 3fb1820 7ce4b70 2599853 7ce4b70 8127dfc 3fb1820 e1b63be 8127dfc 7ce4b70 fe32c8f 8469ea4 c5dc3e4 3fb1820 1eff0a6 3fb1820 1eff0a6 3fb1820 1eff0a6 3fb1820 1eff0a6 8469ea4 e51c162 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import torch
import gradio as gr
from transformers import AutoModelForSequenceClassification
from transformers import pipeline, GPT2Tokenizer, AutoTokenizer
############################################################################################
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hub ids of the two generator checkpoints used by the UI callbacks.
generator_name_1 = 'MasterAlex69/gpt2_edline_gan'
generator_name_0 = 'MasterAlex69/gpt2_edline'

# Each tokenizer reuses its end-of-sequence token as the padding token.
generator_tokenizer_1 = GPT2Tokenizer.from_pretrained(generator_name_1)
generator_tokenizer_1.pad_token_id = generator_tokenizer_1.eos_token_id
generator_tokenizer_0 = GPT2Tokenizer.from_pretrained(generator_name_0)
generator_tokenizer_0.pad_token_id = generator_tokenizer_0.eos_token_id

# One text-generation pipeline per checkpoint, both placed on `device`.
generator_pipeline_1 = pipeline(
    'text-generation',
    model=generator_name_1,
    tokenizer=generator_tokenizer_1,
    device=device,
)
generator_pipeline_0 = pipeline(
    'text-generation',
    model=generator_name_0,
    tokenizer=generator_tokenizer_0,
    device=device,
)
############################################################################################
# Hub id of the sequence-classification checkpoint used as the discriminator.
discriminator_name_1 = 'MasterAlex69/bert_edline_gan'

# Load the classifier and move it to the same device as the generators.
discriminator_1 = AutoModelForSequenceClassification.from_pretrained(
    discriminator_name_1
).to(device)

# Tokenizer that matches the discriminator checkpoint.
discriminator_tokenizer_1 = AutoTokenizer.from_pretrained(discriminator_name_1)
############################################################################################
def generate_text_1():
    """Generate one task text with ``generator_pipeline_1``.

    The pipeline is prompted with a fixed Russian prefix and limited to
    ``max_length=225``.

    Returns:
        list: The generated text twice, one entry per Gradio output
        component this callback is wired to.
    """
    generations = generator_pipeline_1(
        "Строка состоит из символов",
        max_length=225,
        truncation=False,
    )
    task_text = generations[0]['generated_text']
    return [task_text] * 2
def generate_text_0():
    """Generate one task text with ``generator_pipeline_0``.

    The pipeline is prompted with a fixed Russian prefix and limited to
    ``max_length=225``.

    Returns:
        list: The generated text twice, one entry per Gradio output
        component this callback is wired to.
    """
    generations = generator_pipeline_0(
        "Строка состоит из символов",
        max_length=225,
        truncation=False,
    )
    task_text = generations[0]['generated_text']
    return [task_text] * 2
def discriminate_text_1(text):
    """Classify ``text`` with the discriminator and return a 0/1 verdict.

    Args:
        text: Text to classify. Callers in this file pass a single string.

    Returns:
        int: ``round(sigmoid(logit))`` of the last logit column for the
        first item of the batch. ``d_test_1`` compares this value with 1
        for tasks the rule-based checker accepts, so 1 presumably means
        "task is correct" -- confirm against the training setup.
    """
    inputs = discriminator_tokenizer_1(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    # Pure inference: disable autograd so no graph is built or kept alive
    # for every classified text.
    with torch.no_grad():
        logits = discriminator_1(**inputs).logits[:, -1]
    return torch.round(torch.sigmoid(logits)).long().tolist()[0]
def d_test_1(count=100):
    """Measure how often the discriminator agrees with the rule-based checker.

    Generates ``count`` tasks with ``generator_pipeline_1``, classifies each
    one with ``discriminate_text_1`` and compares the verdict with whether
    ``get_correct_answer`` reports the task's stated answer as correct.

    Args:
        count: Number of tasks to generate, as an int or a numeric string.
            An empty string or ``None`` counts as 0. Defaults to 100, the
            value that was previously hard-coded, so calling the function
            with no arguments behaves as before.

    Returns:
        str: Share of matching verdicts as a percentage rounded to two
        decimals followed by ``'%'``, or a Russian hint message when
        ``count`` is not in the range 1..256.

    Raises:
        ValueError: If ``count`` is a non-empty, non-numeric string.
    """
    count = 0 if count in ("", None) else int(count)
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_1(prompts, max_length=225, batch_size=count)
    # For a list of prompts the pipeline yields one list of candidates per
    # prompt; flatten them into a single list of texts.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    matches = 0
    for text in texts:
        predicted = discriminate_text_1(text)
        # The checker tags wrong stated answers with '(не корректно)'.
        expected = 0 if '(не корректно)' in get_correct_answer(text) else 1
        if predicted == expected:
            matches += 1
    return str(round(matches / count * 100, 2)) + '%'
def test(count=100):
    """Measure how many generated tasks state a correct answer.

    Generates ``count`` tasks with ``generator_pipeline_1`` and checks each
    one with ``get_correct_answer``.

    Args:
        count: Number of tasks to generate, as an int or a numeric string.
            An empty string or ``None`` counts as 0. Defaults to 100, the
            value that was previously hard-coded, so calling the function
            with no arguments behaves as before.

    Returns:
        str: Share of tasks whose stated answer the checker accepts, as a
        percentage rounded to two decimals followed by ``'%'``, or a Russian
        hint message when ``count`` is not in the range 1..256.

    Raises:
        ValueError: If ``count`` is a non-empty, non-numeric string.
    """
    count = 0 if count in ("", None) else int(count)
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_1(prompts, max_length=225, batch_size=count)
    # For a list of prompts the pipeline yields one list of candidates per
    # prompt; flatten them into a single list of texts.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    # The checker marks wrong stated answers with 'не корректно'.
    right = sum('не корректно' not in get_correct_answer(text) for text in texts)
    return str(round(right / count * 100, 2)) + '%'
def _extract_field(text, marker, skip, terminator):
    """Return the text from ``skip`` chars past the start of ``marker`` up to ``terminator``.

    Returns an empty string when ``marker`` is absent, and reads to the end
    of ``text`` when ``terminator`` does not occur after the marker.
    """
    start = text.find(marker)
    if start == -1:
        return ''
    end = text.find(terminator, start)
    if end == -1:
        end = len(text)
    return text[start + skip:end]


def _longest_run(sequence, symbol):
    """Return the largest number of back-to-back copies of ``symbol`` in ``sequence``."""
    if not symbol:
        return 0
    best = run = 0
    pos = 0
    while pos < len(sequence):
        if sequence.startswith(symbol, pos):
            run += 1
            best = max(best, run)
            pos += len(symbol)
        else:
            run = 0
            pos += 1
    return best


def get_correct_answer(t):
    """Compute the answer to a generated task and compare it with the stated one.

    Three fields are cut out of the task text:

    * the stated answer: from 8 characters after the first ``"("`` up to the
      next ``")"`` (presumably the text follows ``"(Ответ: "`` -- confirm
      against the generator's output format);
    * the target symbol: the text after the first ``"д символов "`` up to
      the next ``"."``;
    * the sequence: the text after the first ``"а: "`` up to the next ``"."``.

    The computed answer is the longest run of consecutive target symbols in
    the sequence. Runs are counted on the symbol itself, so literal ``*``
    characters in the sequence are not mistaken for the target symbol, and a
    missing field is treated as empty instead of being sliced from index -1.

    Args:
        t: Task text; may be empty.

    Returns:
        str: ``'Введите задание...'`` for empty input, otherwise the computed
        run length followed by ``' (корректно)'`` when it equals the stated
        answer and ``' (не корректно)'`` when it does not.
    """
    if not t:
        return 'Введите задание...'

    claimed = _extract_field(t, "(", 8, ")")
    symbol = _extract_field(t, "д символов ", 11, ".")
    sequence = _extract_field(t, "а: ", 3, ".")

    longest = _longest_run(sequence, symbol)
    verdict = ' (корректно)' if str(longest) == claimed else ' (не корректно)'
    return str(longest) + verdict
############################################################################################
# Gradio UI: a row with the two generators and the answer checker, and a
# second row with the batch experiments for the generator and discriminator.
with gr.Blocks(theme=gr.themes.Monochrome()) as iface:
    gr.Markdown("## Генератор учебных заданий ☕")

    with gr.Row():
        with gr.Column():
            button_gen_0 = gr.Button("Сгенерировать задание (ДО)")
            button_gen_0_output_text = gr.Textbox(label="Результат генерации", interactive=False)
        with gr.Column():
            button_gen_1 = gr.Button("Сгенерировать задание (ПОСЛЕ)")
            button_gen_1_output_text = gr.Textbox(label="Результат генерации", interactive=False)
        with gr.Column():
            button_get_correct_answer = gr.Button("Получить правильный ответ")
            get_correct_answer_input_text = gr.Textbox(label="Задание")
            get_correct_answer_output_text = gr.Textbox(label="Ответ")

    # Each generator fills its own result box and also copies the text into
    # the checker's input box, hence two outputs per callback.
    button_gen_0.click(
        fn=generate_text_0,
        outputs=[button_gen_0_output_text, get_correct_answer_input_text],
    )
    button_gen_1.click(
        fn=generate_text_1,
        outputs=[button_gen_1_output_text, get_correct_answer_input_text],
    )
    button_get_correct_answer.click(
        fn=get_correct_answer,
        inputs=get_correct_answer_input_text,
        outputs=get_correct_answer_output_text,
    )

    with gr.Row():
        with gr.Column():
            button_test_ = gr.Button("Провести испытание (генератор)")
            test_output_text_ = gr.Textbox(label="Корректных заданий")
            button_test_.click(fn=test, outputs=test_output_text_)
        with gr.Column():
            bn_test_d_1 = gr.Button("Провести испытание (дискриминатор)")
            bn_test_d_1_text_output = gr.Textbox(label="Совпадений")
            bn_test_d_1.click(fn=d_test_1, outputs=bn_test_d_1_text_output)

# Start the web server (the stray trailing "|" after this call was removed;
# it made the statement a syntax error).
iface.launch(ssr_mode=True, debug=False)