edline / app.py
MasterAlex69's picture
Update app.py
e1b63be verified
import torch
import gradio as gr
from transformers import AutoModelForSequenceClassification
from transformers import pipeline, GPT2Tokenizer, AutoTokenizer
############################################################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator_name_1 = 'MasterAlex69/gpt2_edline_gan'
generator_name_0 = 'MasterAlex69/gpt2_edline'
generator_tokenizer_1 = GPT2Tokenizer.from_pretrained(generator_name_1)
generator_tokenizer_1.pad_token_id = generator_tokenizer_1.eos_token_id
generator_tokenizer_0 = GPT2Tokenizer.from_pretrained(generator_name_0)
generator_tokenizer_0.pad_token_id = generator_tokenizer_0.eos_token_id
generator_pipeline_1 = pipeline('text-generation', model = generator_name_1, tokenizer = generator_tokenizer_1, device = device)
generator_pipeline_0 = pipeline('text-generation', model = generator_name_0, tokenizer = generator_tokenizer_0, device = device)
############################################################################################
discriminator_name_1 = 'MasterAlex69/bert_edline_gan'
discriminator_1 = AutoModelForSequenceClassification.from_pretrained(discriminator_name_1).to(device)
discriminator_tokenizer_1 = AutoTokenizer.from_pretrained(discriminator_name_1)
############################################################################################
def generate_text_1():
result = generator_pipeline_1("Строка состоит из символов", max_length = 225, truncation = False)[0]['generated_text']
return [result, result]
def generate_text_0():
result = generator_pipeline_0("Строка состоит из символов", max_length = 225, truncation = False)[0]['generated_text']
return [result, result]
def discriminate_text_1(text):
inputs = discriminator_tokenizer_1(text
, return_tensors = "pt"
, padding = True
, truncation = True).to(device)
result = discriminator_1(**inputs).logits[:, -1]
return torch.round(torch.sigmoid(result)).long().tolist()[0]
def d_test_1():
count = 100
if count == "": count = 0
count = int(count)
if count == 0: return 'Введите количество итераций...'
if count > 256: return 'Максимальное количество итераций: 256.'
result = generator_pipeline_1(['Строка состоит из символов'] * count, max_length = 225, batch_size = count)
texts = [item['generated_text'] for sublist in result for item in sublist]
results = [discriminate_text_1(text) for text in texts]
i = 0
m = 0
for result in results:
real_result = 0
if get_correct_answer(texts[i]).find('(не корректно)') == -1: real_result = 1
if result == real_result: m += 1
i += 1
return str(round(m / count * 100, 2)) + '%'
def test():
count = 100
if count == "": count = 0
right = 0
count = int(count)
if count == 0: return 'Введите количество итераций...'
if count > 256: return 'Максимальное количество итераций: 256.'
result = generator_pipeline_1(['Строка состоит из символов'] * count, max_length = 225, batch_size = count)
texts = [item['generated_text'] for sublist in result for item in sublist]
for text in texts:
if get_correct_answer(text).find('не корректно') == -1: right += 1
return str(round(right / count * 100, 2)) + '%'
def get_correct_answer(t):
if len(t) == 0: return 'Введите задание...'
start_index = t.find("(")
end_index = t.find(")", start_index)
a = t[start_index + 8: end_index]
start_index = t.find("д символов ")
end_index = t.find(".", start_index)
c = t[start_index + 11 : end_index]
start_index = t.find("а: ")
end_index = t.find(".", start_index)
t = t[start_index + 3: end_index]
t = t.replace(c, '*')
max_length = 0
current_length = 0
for char in t:
if char == '*':
current_length += 1
if current_length > max_length: max_length = current_length
else: current_length = 0
return str(max_length) + (' (корректно)' if str(max_length) == a else ' (не корректно)')
############################################################################################
with gr.Blocks(theme = gr.themes.Monochrome()) as iface:
gr.Markdown("## Генератор учебных заданий ☕")
with gr.Row():
with gr.Column():
button_gen_0 = gr.Button("Сгенерировать задание (ДО)")
button_gen_0_output_text = gr.Textbox(label="Результат генерации", interactive = False)
with gr.Column():
button_gen_1 = gr.Button("Сгенерировать задание (ПОСЛЕ)")
button_gen_1_output_text = gr.Textbox(label="Результат генерации", interactive = False)
with gr.Column():
button_get_correct_answer = gr.Button("Получить правильный ответ")
get_correct_answer_input_text = gr.Textbox(label = "Задание")
get_correct_answer_output_text = gr.Textbox(label = "Ответ")
button_gen_0.click(fn = generate_text_0, outputs = [button_gen_0_output_text, get_correct_answer_input_text])
button_gen_1.click(fn = generate_text_1, outputs = [button_gen_1_output_text, get_correct_answer_input_text])
button_get_correct_answer.click(fn = get_correct_answer, inputs = get_correct_answer_input_text, outputs = get_correct_answer_output_text)
with gr.Row():
with gr.Column():
button_test_ = gr.Button("Провести испытание (генератор)")
test_output_text_ = gr.Textbox(label = "Корректных заданий")
button_test_.click(fn = test, outputs = test_output_text_)
with gr.Column():
bn_test_d_1 = gr.Button("Провести испытание (дискриминатор)")
bn_test_d_1_text_output = gr.Textbox(label = "Совпадений")
bn_test_d_1.click(fn = d_test_1, outputs = bn_test_d_1_text_output)
iface.launch(ssr_mode = True, debug = False)