# Spaces: Sleeping  (Hugging Face Space page-header residue; kept as a comment so the file parses)
import gradio as gr | |
import glob | |
from docx import Document | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import torch | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
import numpy as np | |
def is_header(txt):
    """Heuristically decide whether ``txt`` looks like a section heading.

    A text counts as a heading when it is empty or shorter than 35
    characters, does not end with ``.``, ``:``, ``?`` or ``!``, and is either
    unchanged by ``str.upper()`` or is title-cased with fewer than six words.

    NOTE(review): the original indentation of this function was lost; both
    casing checks are assumed to sit under the length guard — confirm.

    Args:
        txt: The candidate text (a ``str``).

    Returns:
        ``True`` if the text matches the heading heuristic, else ``False``.
    """
    terminal_punctuation = ('.', ':', '?', '!')
    short_enough = not txt or len(txt) < 35
    if not short_enough or txt.endswith(terminal_punctuation):
        return False
    # Text without any lowercase letters (this includes the empty string).
    if txt == txt.upper():
        return True
    # Title Case With Only A Few Words.
    return txt.istitle() and len(txt.split()) < 6
def get_blocks_from_docx():
    """Extract text blocks from one ``*.docx`` file in the working directory.

    The alphabetically first matching file is read. Paragraphs and table rows
    are collected into two lists:

    * ``blocks`` - every paragraph with more than three words, and every
      table row (non-empty cells joined with ``" | "``) with more than three
      whitespace-separated tokens and more than 25 characters.
    * ``normal_blocks`` - the subset of ``blocks`` that ``is_header`` does not
      classify as a heading; paragraphs must also be longer than 25
      characters.

    Returns:
        A ``(blocks, normal_blocks)`` tuple of lists of unique strings in
        first-occurrence order, or ``([], [])`` when no ``.docx`` file exists.
    """
    # glob returns matches in filesystem-dependent order; sort so that the
    # same file is chosen on every run when several .docx files are present.
    docx_list = sorted(glob.glob("*.docx"))
    if not docx_list:
        return [], []
    doc = Document(docx_list[0])

    blocks = []
    normal_blocks = []

    for p in doc.paragraphs:
        txt = p.text.strip()
        # More than three words already implies a non-empty text that is not
        # a short bare number, so no separate checks are needed for those.
        if len(txt.split()) <= 3:
            continue
        blocks.append(txt)
        if not is_header(txt) and len(txt) > 25:
            normal_blocks.append(txt)

    for table in doc.tables:
        for row in table.rows:
            cell_texts = (cell.text.strip() for cell in row.cells)
            row_text = " | ".join(text for text in cell_texts if text)
            # The " | " separators count as tokens in the word check.
            if len(row_text.split()) <= 3 or len(row_text) <= 25:
                continue
            blocks.append(row_text)
            if not is_header(row_text):
                normal_blocks.append(row_text)

    # Remove duplicates while keeping first-occurrence order.
    return list(dict.fromkeys(blocks)), list(dict.fromkeys(normal_blocks))
# Load the knowledge base once, at import time.
blocks, normal_blocks = get_blocks_from_docx()
if not (blocks and normal_blocks):
    # Placeholder entry so the TF-IDF vectorizer still has something to fit.
    blocks = ["База знаний пуста: проверьте содержимое и структуру вашего .docx!"]
    normal_blocks = list(blocks)

# TF-IDF index over all blocks; rows of `matrix` line up with `blocks`.
vectorizer = TfidfVectorizer(lowercase=True)
matrix = vectorizer.fit_transform(blocks)

# Generation model, kept on the CPU in inference mode.
device = 'cpu'
tokenizer = T5Tokenizer.from_pretrained("cointegrated/rut5-base-multitask")
model = T5ForConditionalGeneration.from_pretrained("cointegrated/rut5-base-multitask")
model.eval()
def rut5_answer(question, context, max_input_tokens=512):
    """Generate an answer to ``question`` from ``context`` with the T5 model.

    Args:
        question: The user's question text.
        context: Text the answer should be based on.
        max_input_tokens: Upper bound on the tokenized prompt length. Longer
            prompts are truncated so that generation cost stays bounded no
            matter how large the retrieved context is.
            NOTE(review): 512 is assumed to suit this checkpoint - confirm.

    Returns:
        The decoded generated text with special tokens removed.
    """
    # NOTE(review): this checkpoint may expect a task-specific prompt prefix;
    # the prompt format is kept as-is here - verify against the model card.
    prompt = f"question: {question} context: {context}"
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    input_ids = encoded.input_ids.to(device)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=250, num_beams=4, min_length=40,
            no_repeat_ngram_size=3, do_sample=False
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def flatten_index(idx):
    """Coerce an index-like value to a plain ``int``, falling back to ``0``.

    Lists, tuples and NumPy arrays (of any nesting depth or dimensionality)
    are reduced to their first element. Objects exposing ``tolist()`` are
    converted through it. Anything that cannot be converted yields ``0``.

    Args:
        idx: A number, a numeric string, a (nested) sequence, a NumPy
            scalar/array, or any object supporting ``tolist()`` or ``int()``.

    Returns:
        The extracted integer, or ``0`` for empty containers and for values
        that cannot be converted (including ``None``, NaN and infinity).
    """
    # Peel containers down to their first element. The loop always descends
    # into the element, so it cannot revisit the same container.
    while isinstance(idx, (list, tuple, np.ndarray)):
        if isinstance(idx, np.ndarray):
            if idx.size == 0:
                return 0
            # `.flat` also works for 0-d arrays, where len() would raise.
            idx = idx.flat[0]
        elif not idx:
            return 0
        else:
            idx = idx[0]

    # NumPy scalars and other array-likes: convert to built-in Python values.
    if hasattr(idx, "tolist"):
        idx = idx.tolist()
        if isinstance(idx, (list, tuple)):
            return flatten_index(idx)

    try:
        return int(idx)
    except (TypeError, ValueError, OverflowError):
        # Not convertible (None, non-numeric text, NaN, infinity, ...).
        return 0
def ask_chatbot(question):
    """Answer ``question`` from the loaded document knowledge base.

    The up to three blocks most similar to the question (cosine similarity
    in the module-level TF-IDF space) are joined into a context and passed to
    ``rut5_answer``. The most similar non-heading block is used as a
    fallback: it is appended when the generated answer is short, and
    replaces the answer when the answer itself looks like a heading.

    Args:
        question: The user's question; ``None`` is treated as empty.

    Returns:
        The answer text, or a user-facing message when the question is empty
        or the knowledge base is unavailable.
    """
    question = (question or "").strip()
    if not question:
        return "Пожалуйста, введите вопрос."
    if not normal_blocks or normal_blocks == ["База знаний пуста: проверьте содержимое и структуру вашего .docx!"]:
        return "Ошибка: база знаний пуста. Проверьте .docx и перезапустите Space."
    if not blocks:
        return "Ошибка: база знаний отсутствует или пуста."

    user_vec = vectorizer.transform([question.lower()])

    # Context: the top-ranked blocks, most similar first.
    sims = cosine_similarity(user_vec, matrix)[0]
    n_blocks = min(3, len(blocks))
    top_idxs = sims.argsort()[-n_blocks:][::-1]
    context = " ".join(blocks[int(i)] for i in top_idxs)

    # Fallback text comes only from paragraphs, never from headings. All
    # candidates are vectorized in a single batch instead of one by one;
    # argmax keeps the first block on ties, like a strict ">" scan would.
    normal_matrix = vectorizer.transform([nb.lower() for nb in normal_blocks])
    normal_sims = cosine_similarity(user_vec, normal_matrix)[0]
    best_normal_block = normal_blocks[int(np.argmax(normal_sims))]

    answer = rut5_answer(question, context)
    # Too short or with fewer than two periods: append the best source block.
    if len(answer.strip().split()) < 8 or answer.count('.') < 2:
        answer += "\n\n" + best_normal_block
    if is_header(answer):
        answer = best_normal_block
    return answer
# Sample questions offered below the input box.
EXAMPLES = [
    "Как оформить список литературы?",
    "Какие сроки сдачи и защиты ВКР?",
    "Какой процент оригинальности требуется?",
    "Как оформлять формулы?",
]


def with_spinner(q):
    """Yield a "thinking" placeholder first, then the chatbot's answer."""
    yield "Чат-бот думает..."
    yield ask_chatbot(q)


# Page layout: title, question box, button, answer area, examples, contacts.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Русскоязычный Чат-бот по методичке (AI+документ)\nЗадайте вопрос — получите развернутый ответ на основании вашего документа!"
    )
    question = gr.Textbox(label="Ваш вопрос", lines=2)
    ask_btn = gr.Button("Получить ответ")
    answer = gr.Markdown(label="Ответ", visible=True)

    # The button click and pressing Enter in the textbox run the same handler.
    ask_btn.click(fn=with_spinner, inputs=question, outputs=answer)
    question.submit(fn=with_spinner, inputs=question, outputs=answer)

    gr.Markdown("#### Примеры вопросов:")
    gr.Examples(EXAMPLES, inputs=question)
    gr.Markdown("""
---
### Контакты (укажите свои)
Преподаватель: ___________________
Email: ___________________________
Кафедра: _________________________
""")

demo.launch()