File size: 6,310 Bytes
da67b7d
3fb1820
da67b7d
 
 
3fb1820
 
 
 
 
 
7ce4b70
 
3fb1820
 
7ce4b70
 
 
 
3fb1820
7ce4b70
3fb1820
 
 
 
 
 
 
 
 
9d6d4d4
 
3fb1820
7ce4b70
 
 
 
3fb1820
 
 
 
 
 
 
 
 
 
1eff0a6
3fb1820
1eff0a6
3fb1820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eff0a6
3fb1820
1eff0a6
3fb1820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5dc3e4
3fb1820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00f9a93
d8e2c18
3fb1820
 
1eff0a6
3fb1820
7ce4b70
 
2599853
 
7ce4b70
8127dfc
3fb1820
e1b63be
 
 
 
8127dfc
7ce4b70
fe32c8f
8469ea4
c5dc3e4
3fb1820
1eff0a6
3fb1820
1eff0a6
3fb1820
1eff0a6
3fb1820
1eff0a6
 
 
 
8469ea4
e51c162
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import re

import torch
import gradio as gr

from transformers import AutoModelForSequenceClassification
from transformers import pipeline, GPT2Tokenizer, AutoTokenizer

############################################################################################

# Run every model on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Identifiers passed to `from_pretrained` for the two task generators.
generator_name_1 = 'MasterAlex69/gpt2_edline_gan'
generator_name_0 = 'MasterAlex69/gpt2_edline'


def _gpt2_tokenizer_with_eos_padding(model_name):
    """Load the GPT-2 tokenizer for ``model_name`` and reuse EOS as the pad id.

    The pad id is set explicitly (presumably because the tokenizer defines
    none) so that batched generation calls can pad their inputs.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def _text_generation_pipeline(model_name, tokenizer):
    """Build a ``text-generation`` pipeline for ``model_name`` placed on ``device``."""
    return pipeline('text-generation', model = model_name, tokenizer = tokenizer, device = device)


# Tokenizers are loaded first, then the pipelines (same order as before).
generator_tokenizer_1 = _gpt2_tokenizer_with_eos_padding(generator_name_1)
generator_tokenizer_0 = _gpt2_tokenizer_with_eos_padding(generator_name_0)

generator_pipeline_1 = _text_generation_pipeline(generator_name_1, generator_tokenizer_1)
generator_pipeline_0 = _text_generation_pipeline(generator_name_0, generator_tokenizer_0)

############################################################################################

# Sequence classifier used by discriminate_text_1, plus its tokenizer.
discriminator_name_1 = 'MasterAlex69/bert_edline_gan'
discriminator_1 = AutoModelForSequenceClassification.from_pretrained(discriminator_name_1)
discriminator_1 = discriminator_1.to(device)
discriminator_tokenizer_1 = AutoTokenizer.from_pretrained(discriminator_name_1)

############################################################################################
_TASK_PROMPT = "Строка состоит из символов"


def _generate_task(text_pipeline):
    """Generate one task from the shared prompt and return it twice.

    The text is duplicated so that a single Gradio click can fill two output
    components: the generator's own textbox and the answer-checker input.
    """
    generations = text_pipeline(_TASK_PROMPT, max_length = 225, truncation = False)
    task_text = generations[0]['generated_text']
    return [task_text, task_text]


def generate_text_1():
    """Generate a task with ``generator_pipeline_1``; returns ``[text, text]``."""
    return _generate_task(generator_pipeline_1)


def generate_text_0():
    """Generate a task with ``generator_pipeline_0``; returns ``[text, text]``."""
    return _generate_task(generator_pipeline_0)

def discriminate_text_1(text):
    """Score ``text`` with ``discriminator_1`` and return a 0/1 label.

    The text is tokenized (padded, truncated to the tokenizer's limit),
    moved to ``device`` and passed through the classifier. The last logit
    of the first sequence is squashed with a sigmoid and rounded.

    Args:
        text: a single task string.

    Returns:
        int: 1 if ``sigmoid(last logit)`` rounds up, otherwise 0.
    """
    inputs = discriminator_tokenizer_1(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Pure inference: disable autograd so no computation graph or saved
    # activations are kept for every call (this runs once per generated text
    # during batch evaluation).
    with torch.no_grad():
        logits = discriminator_1(**inputs).logits

    # NOTE(review): only the last logit column is used, with a sigmoid;
    # presumably the head was trained that way — confirm against training code.
    return torch.round(torch.sigmoid(logits[:, -1])).long().tolist()[0]

def d_test_1(count = 100):
    """Measure agreement between the discriminator and the rule-based checker.

    Generates ``count`` tasks with ``generator_pipeline_1`` in a single batch.
    For each task the reference label is 0 if ``get_correct_answer`` marks it
    ``'(не корректно)'`` and 1 otherwise; the predicted label comes from
    ``discriminate_text_1``. The share of tasks where both labels agree is
    returned as a percentage.

    Args:
        count: number of tasks to generate; an int or numeric string, with
            ``""`` treated as 0. Defaults to 100, which the UI button (wired
            without inputs) relies on.

    Returns:
        str: the agreement rounded to 2 decimals with a ``%`` suffix, or a
        user-facing message when ``count`` is not positive or exceeds 256.
    """
    if count == "":
        count = 0
    count = int(count)
    # A non-positive count would build an empty/negative batch; ask for input.
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    # A list of prompts yields one list of generations per prompt; flatten it.
    batches = generator_pipeline_1(['Строка состоит из символов'] * count, max_length = 225, batch_size = count)
    texts = [item['generated_text'] for generations in batches for item in generations]

    matches = 0
    for text in texts:
        expected = 0 if '(не корректно)' in get_correct_answer(text) else 1
        if discriminate_text_1(text) == expected:
            matches += 1

    return str(round(matches / count * 100, 2)) + '%'


def test(count = 100):
    """Report the share of generated tasks whose stated answer is correct.

    Generates ``count`` tasks with ``generator_pipeline_1`` in a single batch
    and counts those that ``get_correct_answer`` does not mark as
    ``'не корректно'``.

    Args:
        count: number of tasks to generate; an int or numeric string, with
            ``""`` treated as 0. Defaults to 100, which the UI button (wired
            without inputs) relies on.

    Returns:
        str: the correct share rounded to 2 decimals with a ``%`` suffix, or a
        user-facing message when ``count`` is not positive or exceeds 256.
    """
    if count == "":
        count = 0
    count = int(count)
    # A non-positive count would build an empty/negative batch; ask for input.
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    # A list of prompts yields one list of generations per prompt; flatten it.
    batches = generator_pipeline_1(['Строка состоит из символов'] * count, max_length = 225, batch_size = count)
    texts = [item['generated_text'] for generations in batches for item in generations]

    right = sum('не корректно' not in get_correct_answer(text) for text in texts)

    return str(round(right / count * 100, 2)) + '%'

def _text_between(text, marker, skip, terminator):
    """Return the part of ``text`` after ``marker`` (offset by ``skip``).

    ``start = text.find(marker)``; the slice runs from ``start + skip`` up to
    the first ``terminator`` found from ``start`` on. Raw ``str.find``
    semantics are kept deliberately: a missing marker gives ``start == -1``
    and a missing terminator gives an end of ``-1``, so malformed text yields
    an odd slice rather than an exception.
    """
    start = text.find(marker)
    return text[start + skip:text.find(terminator, start)]


def get_correct_answer(t):
    """Recompute the answer to a generated task and compare it with the stated one.

    The task text is parsed with fixed markers:

    * stated answer: after the first ``"("`` plus 7 more characters (the
      length of ``"(ответ: "`` — presumably the expected wording; confirm)
      up to the next ``")"``;
    * target symbol: after ``"д символов "`` up to the next ``"."``;
    * scanned string: after ``"а: "`` up to the next ``"."``.

    The correct answer is the longest run of consecutive, non-overlapping
    occurrences of the target symbol in the scanned string.

    Args:
        t: the task text.

    Returns:
        str: ``"<n> (корректно)"`` when the stated answer equals ``n``,
        ``"<n> (не корректно)"`` otherwise, ``"Введите задание..."`` for an
        empty task, and ``"0 (не корректно)"`` when no target symbol could be
        extracted.
    """
    if len(t) == 0:
        return 'Введите задание...'

    stated = _text_between(t, "(", 8, ")")
    symbol = _text_between(t, "д символов ", 11, ".")
    scanned = _text_between(t, "а: ", 3, ".")

    # An empty symbol would match between every character (the old
    # str.replace('', '*') produced a bogus run of 1); such a task is malformed.
    if not symbol:
        return '0 (не корректно)'

    # Match runs of the symbol directly instead of replacing it with a '*'
    # sentinel, so literal '*' characters in the text are not miscounted.
    runs = re.findall('(?:' + re.escape(symbol) + ')+', scanned)
    longest = max((len(run) // len(symbol) for run in runs), default=0)

    return str(longest) + (' (корректно)' if str(longest) == stated else ' (не корректно)')

############################################################################################
# Gradio UI. Components are laid out in the order they are created inside the
# `with` blocks, so statement order here determines the page layout.
with gr.Blocks(theme = gr.themes.Monochrome()) as iface:
  
  gr.Markdown("## Генератор учебных заданий ☕")

  # Side-by-side generation: left column is wired to generate_text_0,
  # right column to generate_text_1 (see the .click calls below).
  with gr.Row():
      
    with gr.Column():
      button_gen_0 = gr.Button("Сгенерировать задание (ДО)")
      button_gen_0_output_text = gr.Textbox(label="Результат генерации", interactive = False)
    
    with gr.Column():     
      button_gen_1 = gr.Button("Сгенерировать задание (ПОСЛЕ)")
      button_gen_1_output_text = gr.Textbox(label="Результат генерации", interactive = False)

  # Answer checker: runs get_correct_answer on whatever task is in the input box.
  with gr.Column():
    button_get_correct_answer = gr.Button("Получить правильный ответ")
    get_correct_answer_input_text = gr.Textbox(label = "Задание")
    get_correct_answer_output_text = gr.Textbox(label = "Ответ")

  # Each generator returns the same text twice, so one click fills its own
  # output box and also pre-fills the checker's input box.
  button_gen_0.click(fn = generate_text_0, outputs = [button_gen_0_output_text, get_correct_answer_input_text])
  button_gen_1.click(fn = generate_text_1, outputs = [button_gen_1_output_text, get_correct_answer_input_text])
  button_get_correct_answer.click(fn = get_correct_answer, inputs = get_correct_answer_input_text, outputs = get_correct_answer_output_text)
      
  # Batch evaluations. Neither button passes inputs, so test() and d_test_1()
  # are called with no arguments.
  with gr.Row():
      
    with gr.Column():
      # Share of generated tasks whose stated answer checks out.
      button_test_ = gr.Button("Провести испытание (генератор)")
      test_output_text_ = gr.Textbox(label = "Корректных заданий")
      button_test_.click(fn = test, outputs = test_output_text_)

    with gr.Column():
      # Share of generated tasks where the discriminator agrees with the checker.
      bn_test_d_1 = gr.Button("Провести испытание (дискриминатор)")
      bn_test_d_1_text_output = gr.Textbox(label = "Совпадений")
      bn_test_d_1.click(fn = d_test_1, outputs = bn_test_d_1_text_output)
      
# NOTE(review): launch() runs at import time (no __main__ guard). ssr_mode is
# presumably Gradio's server-side-rendering switch — confirm the installed
# Gradio version supports it.
iface.launch(ssr_mode = True, debug = False)