import datetime
from io import StringIO
from typing import Union
from random import sample
from collections import defaultdict
from streamlit.runtime.uploaded_file_manager import UploadedFile
from utilities_language_bert.rus_sentence_bert import TASK, SENTENCE
from utilities_language_general.rus_utils import compute_frequency_dict, prepare_tasks, prepare_target_words
from utilities_language_general.rus_constants import st, load_bert, load_classifiers, nlp, summarization, BAD_USER_TARGET_WORDS, MINIMUM_SETS
def main_workflow(
        file: Union[UploadedFile, None],
        text: str,
        logs,
        progress,
        progress_d,
        level: str,
        tw_mode_automatic_mode: str,
        target_words: str,
        num_distractors: int,
        save_name: str,
        global_bad_target_words=BAD_USER_TARGET_WORDS):
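    """Generate gap-fill exercises for a Russian text and report progress in the UI.

    A sketch of the assumed contract, inferred from usage below: `logs` is a
    Streamlit status container (supports `.update(label=..., state=...)`),
    `progress` and `progress_d` are `st.progress` bars, and
    `global_bad_target_words` defaults to a module-level list that is shared
    across calls on purpose, so rejected user target words can be reported.
    Returns a dict with student/teacher handouts, answer keys, and metadata.
    """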
    # Clear bad target words from the previous run; mutate in place so the
    # module-level default list stays aliased (rebinding it to [] would detach it)
    global_bad_target_words.clear()
    # Define main global variables
    GLOBAL_DISTRACTORS = set()
    MAX_FREQUENCY = 0
    logs.update(label='Loading language models and other data', state='running')
    pos_dict, scaler, classifier = load_classifiers('model3')
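    # load_bert() presumably returns a fill-mask (masked-LM) pipeline; it is
    # passed to attach_distractors_to_target_word below to propose distractor
    # candidates for each gap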
    mask_filler = load_bert()
    # Get input text
    if file is not None:
        stringio = StringIO(file.getvalue().decode("utf-8"))
        current_text = stringio.read()
    elif text != '':
        current_text = text
    else:
        st.warning('You neither pasted a text nor chose a file 😢')
        st.stop()
    # Process target words
    if tw_mode_automatic_mode == 'Самостоятельно':  # 'Самостоятельно' = 'Manually' in the UI
        if target_words == '':
            st.warning('You did not enter any target words')
            st.stop()
        # Cannot make up a paradigm, so only USER_TARGET_WORDS is used
        USER_TARGET_WORDS = prepare_target_words(target_words)
        tw_mode_automatic_mode = False
    else:
        USER_TARGET_WORDS = None
        tw_mode_automatic_mode = True
    # Text preprocessing
    original_text = current_text
    # Normalize punctuation; '%^&*' is a sentinel that survives sentence
    # splitting and is turned back into '\n' when the output is assembled
    current_text = current_text.replace('.', '. ').replace('. . .', '...').replace('  ', ' ') \
        .replace('…', '...').replace('—', '-').replace('-\n', '').replace('\n', '%^&*')
    current_text_sentences = [sent.text.strip() for sent in nlp(current_text).sents]
    logs.update(label='Got your text!', state='running')
    progress.progress(10)
    # Compute frequency dict
    FREQ_DICT = compute_frequency_dict(current_text)
    # Get maximum frequency (top-5% barrier)
    _frequency_barrier_percent = 0.05
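    # Assumption: compute_frequency_dict returns items sorted by frequency in
    # descending order, so after the loop MAX_FREQUENCY holds the frequency at
    # the top-5% cutoff (floored at 3 below)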
    for j, tp in enumerate(FREQ_DICT.items()):
        if j < len(FREQ_DICT) * _frequency_barrier_percent:
            MAX_FREQUENCY = tp[1]
    MAX_FREQUENCY = max(MAX_FREQUENCY, 3)
logs.update(label="Посчитали немного статистики!", state='running') | |
progress.progress(15) | |
    # Choose the language minimum that matches the user's level
    if level:
        target_minimum, distractor_minimum = MINIMUM_SETS[level]
    else:
        logs.error('You did not choose a language level!')
        st.stop()
    # Start generation process
    workflow = [SENTENCE(original=sent.strip(), n_sentence=num, max_num_distractors=num_distractors)
                for num, sent in enumerate(current_text_sentences)]
    logs.update(label='Starting the task generation process!', state='running')
    progress.progress(20)
    # Define summary length
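    # Heuristic: texts of up to 15 sentences are summarized whole, up to 25
    # sentences get a 15-sentence summary, and longer texts add two sentences
    # for every five sentences beyond ~20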
    text_length = len(current_text_sentences)
    if text_length <= 15:
        summary_length = text_length
    elif text_length <= 25:
        summary_length = 15
    else:
        n = (text_length - 20) // 5
        summary_length = 15 + 2 * n
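    # Round summary_length up to the next multiple of 10: in Python, x % -10
    # lies in (-10, 0], so subtracting it rounds x up (e.g. 17 -> 20)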
    round_summary_length = summary_length - (summary_length % -10)
    # Get summary. May choose between round_summary_length and summary_length
    SUMMARY = summarization(current_text, num_sentences=round_summary_length)
    logs.update(label='Found some interesting sentences. They will come in handy!', state='running')
    progress.progress(25)
    for sentence in workflow:
        sentence.lemmatize_sentence()
    for sentence in workflow:
        sentence.bind_phrases()
    logs.update(label='Prepared the sentences for further work!', state='running')
    progress.progress(30)
    for j, sentence in enumerate(workflow):
        sentence.search_target_words(target_words_automatic_mode=tw_mode_automatic_mode,
                                     target_minimum=target_minimum,
                                     user_target_words=USER_TARGET_WORDS,
                                     frequency_dict=FREQ_DICT,
                                     summary=SUMMARY)
        progress.progress(int(30 + (j * (20 / len(workflow)))))
    progress.progress(50)
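    # Deduplicate target words across the text: group candidates by lemma,
    # keep one occurrence per lemma at random, and record the rejected ones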
    DUPLICATE_TARGET_WORDS = defaultdict(list)
    for sentence in workflow:
        for target_word in sentence.target_words:
            DUPLICATE_TARGET_WORDS[target_word['lemma']].append(target_word)
    RESULT_TW = [sample(tw_data, 1)[0] for tw_data in DUPLICATE_TARGET_WORDS.values()]
    for sentence in workflow:
        # Iterate over a copy: removing items from the list while iterating it
        # would silently skip the element that follows each removal
        for target_word in sentence.target_words[:]:
            if target_word not in RESULT_TW:
                global_bad_target_words.append(target_word['original_text'])
                sentence.target_words.remove(target_word)
    progress.progress(55)
    logs.update(label='Chose the gap words!', state='running')
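    # For every kept target word, rebuild the full text with its sentence
    # swapped for the masked version; '%^&*' is turned back into newlines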
    for sentence in workflow:
        for i, target_word in enumerate(sentence.target_words):
            temp = current_text_sentences[:]
            temp[sentence.n_sentence] = target_word['masked_sentence']
            masked_text = ' '.join(temp).replace('%^&*', '\n')
            sentence.text_with_masked_task = masked_text
            target_word['text_with_masked_task'] = masked_text
    for sentence in workflow:
        sentence.filter_target_words(target_words_automatic_mode=tw_mode_automatic_mode)
    progress.progress(60)
    RESULT_TASKS = []
    for sentence in workflow:
        for target_word in sentence.target_words:
            task = TASK(task_data=target_word, max_num_distractors=num_distractors)
            RESULT_TASKS.append(task)
    for num, task in enumerate(RESULT_TASKS):
        task.attach_distractors_to_target_word(model=mask_filler,
                                               scaler=scaler,
                                               classifier=classifier,
                                               pos_dict=pos_dict,
                                               level_name=level,
                                               global_distractors=GLOBAL_DISTRACTORS,
                                               distractor_minimum=distractor_minimum,
                                               max_frequency=MAX_FREQUENCY)
        # num is 0-based, so report num + 1 completed items; the last iteration
        # logs N/N, so no extra final update is needed
        progress_d.progress((num + 1) / len(RESULT_TASKS))
        logs.update(label=f'Processed {num + 1}/{len(RESULT_TASKS)} target words!', state='running')
    progress_d.progress(100)
    progress.progress(70)
    logs.update(label='Picked the distractors!', state='running')
    for task in RESULT_TASKS:
        task.inflect_distractors()
    progress.progress(80)
    logs.update(label='Declined and conjugated the distractors!', state='running')
    for task in RESULT_TASKS:
        task.sample_distractors(num_distractors=num_distractors)
    progress.progress(85)
    RESULT_TASKS = list(filter(lambda t: not t.bad_target_word, RESULT_TASKS))
    # Compute the number of final tasks (snap to 20, 15, or 10 when possible)
    if len(RESULT_TASKS) >= 20:
        NUMBER_TASKS = 20
    elif len(RESULT_TASKS) >= 15:
        NUMBER_TASKS = 15
    elif len(RESULT_TASKS) >= 10:
        NUMBER_TASKS = 10
    else:
        NUMBER_TASKS = len(RESULT_TASKS)
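    # Prefer tasks whose sentences made it into the summary; top up with a
    # random sample of the remaining tasks if there are not enough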
    RESULT_TASKS_in_summary = list(filter(lambda task: task.in_summary, RESULT_TASKS))
    RESULT_TASKS_not_in_summary = list(filter(lambda task: not task.in_summary, RESULT_TASKS))
    if len(RESULT_TASKS_in_summary) >= NUMBER_TASKS:
        RESULT_TASKS = RESULT_TASKS_in_summary
    else:
        RESULT_TASKS = RESULT_TASKS_in_summary \
            + sample(RESULT_TASKS_not_in_summary, NUMBER_TASKS - len(RESULT_TASKS_in_summary))
    RESULT_TASKS = sorted(RESULT_TASKS, key=lambda t: (t.sentence_number, t.position_in_sentence))
    for task in RESULT_TASKS:
        task.compile_task(max_num_distractors=num_distractors)
    progress.progress(90)
    logs.update(label='Selected the best tasks!', state='running')
    TEXT_WITH_GAPS = []
    VARIANTS = []
    tasks_counter = 1
    for sentence in current_text_sentences:
        for task in RESULT_TASKS:
            if task.sentence_text == sentence:
                # Replace only the first occurrence so a repeated word does not
                # open several identically numbered gaps
                sentence = sentence.replace(task.original_text, f'__________({tasks_counter})', 1)
                VARIANTS.append(task.variants)
                tasks_counter += 1
        TEXT_WITH_GAPS.append(sentence)
    del RESULT_TASKS
    TEXT_WITH_GAPS = ' '.join(TEXT_WITH_GAPS).replace('%^&*', '\n')
    PREPARED_TASKS = prepare_tasks(VARIANTS)
    STUDENT_OUT = f'{TEXT_WITH_GAPS}\n\n{"=" * 70}\n\n{PREPARED_TASKS["TASKS_STUDENT"]}'
    TEACHER_OUT = f'{TEXT_WITH_GAPS}\n\n{"=" * 70}\n\n{PREPARED_TASKS["TASKS_TEACHER"]}\n\n{"=" * 70}\n\n' \
                  f'{PREPARED_TASKS["KEYS_ONLY"]}'
    TOTAL_OUT = f'{original_text}\n\n{"$" * 70}\n\n{STUDENT_OUT}\n\n{"=" * 70}\n\n{PREPARED_TASKS["TASKS_TEACHER"]}' \
                f'\n\n{"$" * 70}\n\n{PREPARED_TASKS["KEYS_ONLY"]}'
    logs.update(label='Almost everything is ready!', state='running')
    progress.progress(95)
    save_name = save_name if save_name != '' else f'{str(datetime.datetime.now())[:-7]}_{original_text[:20]}'
    out = {
        'name': save_name,
        'STUDENT_OUT': STUDENT_OUT,
        'TEACHER_OUT': TEACHER_OUT,
        'TEXT_WITH_GAPS': TEXT_WITH_GAPS,
        'TASKS_ONLY': PREPARED_TASKS["RAW_TASKS"],
        'KEYS_ONLY': PREPARED_TASKS["KEYS_ONLY"],
        'KEYS_ONLY_RAW': PREPARED_TASKS["RAW_KEYS_ONLY"],
        'TOTAL_OUT': TOTAL_OUT,
        'ORIGINAL': original_text,
        'BAD_USER_TARGET_WORDS': sorted(set(global_bad_target_words))
    }
    return out