import torch
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, GPT2Tokenizer, pipeline
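# Demo Space "Генератор учебных заданий" (educational task generator): two GPT-2 generators,
# one original ("ДО") and one GAN-trained ("ПОСЛЕ"), plus a BERT discriminator that judges
# whether a generated task is correct.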
############################################################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two GPT-2 generators: the GAN-trained checkpoint ("ПОСЛЕ") and the original one ("ДО").
generator_name_1 = 'MasterAlex69/gpt2_edline_gan'
generator_name_0 = 'MasterAlex69/gpt2_edline'

generator_tokenizer_1 = GPT2Tokenizer.from_pretrained(generator_name_1)
generator_tokenizer_1.pad_token_id = generator_tokenizer_1.eos_token_id
generator_tokenizer_0 = GPT2Tokenizer.from_pretrained(generator_name_0)
generator_tokenizer_0.pad_token_id = generator_tokenizer_0.eos_token_id

generator_pipeline_1 = pipeline('text-generation', model=generator_name_1, tokenizer=generator_tokenizer_1, device=device)
generator_pipeline_0 = pipeline('text-generation', model=generator_name_0, tokenizer=generator_tokenizer_0, device=device)
############################################################################################
# BERT discriminator: its thresholded output is compared against the rule-based checker
# (see d_test_1 below).
discriminator_name_1 = 'MasterAlex69/bert_edline_gan'
discriminator_1 = AutoModelForSequenceClassification.from_pretrained(discriminator_name_1).to(device)
discriminator_tokenizer_1 = AutoTokenizer.from_pretrained(discriminator_name_1)
############################################################################################
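# Each generate_text_* helper returns the same string twice so that a single click can fill
# both the display textbox and the answer-checker input (see the .click() wiring below).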
def generate_text_1():
    # Prompt "Строка состоит из символов" ("The string consists of the characters ...").
    result = generator_pipeline_1("Строка состоит из символов", max_length=225, truncation=False)[0]['generated_text']
    return [result, result]

def generate_text_0():
    result = generator_pipeline_0("Строка состоит из символов", max_length=225, truncation=False)[0]['generated_text']
    return [result, result]
def discriminate_text_1(text):
    inputs = discriminator_tokenizer_1(text, return_tensors="pt", padding=True, truncation=True).to(device)
    # Take the logit of the last class and threshold it with a sigmoid:
    # 1 is treated as a correct task, 0 as an incorrect one.
    result = discriminator_1(**inputs).logits[:, -1]
    return torch.round(torch.sigmoid(result)).long().tolist()[0]
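# d_test_1: generate `count` tasks with the GAN-trained generator and report how often the
# discriminator's verdict agrees with the rule-based check in get_correct_answer().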
def d_test_1():
    count = 100
    result = generator_pipeline_1(['Строка состоит из символов'] * count, max_length=225, batch_size=count)
    texts = [item['generated_text'] for sublist in result for item in sublist]
    predictions = [discriminate_text_1(text) for text in texts]
    matches = 0
    for text, prediction in zip(texts, predictions):
        # Ground truth: 1 when the rule-based checker does not flag the task as incorrect.
        real_result = 0 if get_correct_answer(text).find('(не корректно)') != -1 else 1
        if prediction == real_result: matches += 1
    return str(round(matches / count * 100, 2)) + '%'
def test():
    count = 100
    right = 0
    result = generator_pipeline_1(['Строка состоит из символов'] * count, max_length=225, batch_size=count)
    texts = [item['generated_text'] for sublist in result for item in sublist]
    for text in texts:
        # A task counts as correct when the rule-based checker does not flag it.
        if get_correct_answer(text).find('не корректно') == -1: right += 1
    return str(round(right / count * 100, 2)) + '%'
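# get_correct_answer: rule-based checker. It extracts the stated answer (the text inside the
# first parentheses), the target character (after "д символов "), and the string itself
# (after "а: "), then computes the longest run of that character and compares it to the answer.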
def get_correct_answer(t):
    if len(t) == 0: return 'Введите задание...'  # "Enter a task..."
    # Stated answer: the text inside the first parentheses (after the answer label).
    start_index = t.find("(")
    end_index = t.find(")", start_index)
    a = t[start_index + 8: end_index]
    # Target character: the text after "д символов " up to the next period.
    start_index = t.find("д символов ")
    end_index = t.find(".", start_index)
    c = t[start_index + 11: end_index]
    # The string to analyse: the text after "а: " up to the next period.
    start_index = t.find("а: ")
    end_index = t.find(".", start_index)
    t = t[start_index + 3: end_index]
    # Longest run of the target character (its occurrences are marked with '*').
    t = t.replace(c, '*')
    max_length = 0
    current_length = 0
    for char in t:
        if char == '*':
            current_length += 1
            if current_length > max_length: max_length = current_length
        else:
            current_length = 0
    # "(корректно)" = correct, "(не корректно)" = incorrect.
    return str(max_length) + (' (корректно)' if str(max_length) == a else ' (не корректно)')
############################################################################################
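# Gradio UI: generation with the original model ("ДО" = before) and the GAN-trained model
# ("ПОСЛЕ" = after), a correct-answer checker, and two evaluations: generator accuracy and
# generator/discriminator agreement.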
with gr.Blocks(theme=gr.themes.Monochrome()) as iface:
    gr.Markdown("## Генератор учебных заданий ☕")
    with gr.Row():
        with gr.Column():
            button_gen_0 = gr.Button("Сгенерировать задание (ДО)")
            button_gen_0_output_text = gr.Textbox(label="Результат генерации", interactive=False)
        with gr.Column():
            button_gen_1 = gr.Button("Сгенерировать задание (ПОСЛЕ)")
            button_gen_1_output_text = gr.Textbox(label="Результат генерации", interactive=False)
        with gr.Column():
            button_get_correct_answer = gr.Button("Получить правильный ответ")
            get_correct_answer_input_text = gr.Textbox(label="Задание")
            get_correct_answer_output_text = gr.Textbox(label="Ответ")
    button_gen_0.click(fn=generate_text_0, outputs=[button_gen_0_output_text, get_correct_answer_input_text])
    button_gen_1.click(fn=generate_text_1, outputs=[button_gen_1_output_text, get_correct_answer_input_text])
    button_get_correct_answer.click(fn=get_correct_answer, inputs=get_correct_answer_input_text, outputs=get_correct_answer_output_text)
    with gr.Row():
        with gr.Column():
            button_test_ = gr.Button("Провести испытание (генератор)")
            test_output_text_ = gr.Textbox(label="Корректных заданий")
            button_test_.click(fn=test, outputs=test_output_text_)
        with gr.Column():
            bn_test_d_1 = gr.Button("Провести испытание (дискриминатор)")
            bn_test_d_1_text_output = gr.Textbox(label="Совпадений")
            bn_test_d_1.click(fn=d_test_1, outputs=bn_test_d_1_text_output)

iface.launch(ssr_mode=True, debug=False)