|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
import matplotlib.pyplot as plt |
|
import networkx as nx |
|
import io |
|
from PIL import Image |
|
import torch |
|
import os |
|
|
|
# Load the summarization model and the question-generation pipeline once at
# import time so every request handled by the UI reuses them.
print("Loading models...")

# mT5 fine-tuned on the multilingual XL-Sum dataset (handles Arabic and English).
model_name = "csebuetnlp/mT5_multilingual_XLSum"
# use_fast=False loads the slow (sentencepiece-based) tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)

# End-to-end question generation (emits several questions per call).
# The pipeline gets an integer device index: 0 = first GPU, -1 = CPU. The
# integer form is what older transformers releases require, and newer ones accept it too.
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1,
)
|
|
|
def summarize_text(text, src_lang): |
|
inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device) |
|
|
|
|
|
summary_ids = model.generate( |
|
inputs["input_ids"], |
|
max_length=150, |
|
min_length=30, |
|
length_penalty=2.0, |
|
num_beams=4, |
|
early_stopping=True |
|
) |
|
|
|
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) |
|
return summary |
|
|
|
def generate_questions(summary): |
|
questions = [] |
|
for _ in range(3): |
|
result = question_generator( |
|
summary, |
|
max_length=64, |
|
num_beams=4, |
|
do_sample=True, |
|
top_k=30, |
|
top_p=0.95, |
|
temperature=0.7 |
|
) |
|
questions.append(result[0]['generated_text']) |
|
|
|
questions = list(set(questions)) |
|
return questions |
|
|
|
def generate_concept_map(summary, questions): |
|
|
|
G = nx.DiGraph() |
|
|
|
|
|
summary_short = summary[:50] + "..." if len(summary) > 50 else summary |
|
G.add_node("summary", label=summary_short) |
|
|
|
|
|
for i, question in enumerate(questions): |
|
q_short = question[:30] + "..." if len(question) > 30 else question |
|
node_id = f"Q{i}" |
|
G.add_node(node_id, label=q_short) |
|
G.add_edge("summary", node_id) |
|
|
|
|
|
plt.figure(figsize=(10, 8)) |
|
pos = nx.spring_layout(G, seed=42) |
|
nx.draw(G, pos, with_labels=False, node_color='skyblue', |
|
node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1', |
|
edgecolors='black', linewidths=1) |
|
|
|
|
|
labels = nx.get_node_attributes(G, 'label') |
|
nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, |
|
font_family='sans-serif') |
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') |
|
buf.seek(0) |
|
plt.close() |
|
|
|
return Image.open(buf) |
|
|
|
def analyze_text(text, lang):
    """Run the full pipeline on *text*: summary, questions, then concept map.

    Args:
        text: Raw text from the input box.
        lang: Language code, forwarded to summarize_text.

    Returns:
        A (summary, bullet-list of questions, image) triple for the three
        Gradio outputs. Whitespace-only input short-circuits with placeholder
        messages. Any exception is logged with its traceback and reported
        through the first slot, with "" and None in the others.
    """
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        concept_map = generate_concept_map(summary, questions)

        bullets = "\n".join(f"- {q}" for q in questions)
    except Exception as err:
        import traceback

        message = f"Error processing text: {str(err)}"
        print(message)
        print(traceback.format_exc())
        return message, "", None

    return summary, bullets, concept_map
|
|
|
def generate_simple_concept_map(summary, questions): |
|
"""Fallback concept map generator with minimal dependencies""" |
|
plt.figure(figsize=(10, 8)) |
|
|
|
n_questions = len(questions) |
|
|
|
plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black') |
|
plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary, |
|
ha='center', va='center', fontsize=9) |
|
|
|
radius = 5 |
|
for i, question in enumerate(questions): |
|
angle = 2 * 3.14159 * i / max(n_questions, 1) |
|
x = radius * 0.8 * -1 * (max(n_questions, 1) - 1) * ((i / max(n_questions - 1, 1)) - 0.5) |
|
y = radius * 0.6 * (i % 2 * 2 - 1) |
|
|
|
plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black') |
|
|
|
plt.plot([0, x], [0, y], 'k-', alpha=0.6) |
|
|
|
plt.text(x, y, question[:30] + "..." if len(question) > 30 else question, |
|
ha='center', va='center', fontsize=8) |
|
|
|
plt.axis('equal') |
|
plt.axis('off') |
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') |
|
buf.seek(0) |
|
plt.close() |
|
|
|
return Image.open(buf) |
|
|
|
# Sample inputs shown under the Gradio form. Each entry is [text, language
# code], matching the order of the interface's inputs (textbox, dropdown).
# One Arabic and one English paragraph, both describing artificial intelligence.
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
]

print("Creating Gradio interface...")
|
|
|
def _render_concept_map_safely(summary, questions):
    """Draw the concept map with networkx, then the simple renderer, else return None."""
    import traceback

    try:
        return generate_concept_map(summary, questions)
    except Exception as e:
        print(f"Main concept map failed: {e}, using fallback")

    try:
        return generate_simple_concept_map(summary, questions)
    except Exception:
        # The map is optional; keep the summary and questions the user already has.
        print("Fallback concept map failed; returning no image")
        print(traceback.format_exc())
        return None


def analyze_text_with_fallback(text, lang):
    """Gradio handler: summarize *text*, generate questions, and draw a concept map.

    If the networkx concept map fails, the plain-matplotlib renderer is tried.
    If that also fails, the summary and questions are still returned, with no
    image.

    Args:
        text: Raw input from the textbox. None or whitespace-only input
            short-circuits with placeholder messages.
        lang: Language code from the dropdown, forwarded to summarize_text.

    Returns:
        A (summary, questions_text, image_or_None) triple for the three output
        components. If summarization or question generation raises, the first
        slot carries the error message and the others are "" and None.
    """
    if not text or not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)
    except Exception as e:
        # Top-level UI boundary: log the full traceback, show a short message.
        import traceback

        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None

    print("Creating concept map...")
    concept_map_image = _render_concept_map_safely(summary, questions)

    # Show a placeholder rather than an empty box when nothing was generated.
    questions_text = "\n".join(f"- {q}" for q in questions) or "No questions generated."

    return summary, questions_text, concept_map_image
|
|
|
# Web UI: text + language in; summary, question list and concept-map image out.
iface = gr.Interface(
    fn=analyze_text_with_fallback,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text here...", label="Text"),
        gr.Dropdown(["ar", "en"], label="Language"),
    ],
    outputs=[
        gr.Textbox(label="Summary"),
        gr.Textbox(label="Questions"),
        gr.Image(label="Concept Map"),
    ],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map.",
)


# Launch only when run as a script, so importing this module (e.g. for tests)
# does not start a server. share=True asks Gradio for a temporary public URL.
if __name__ == "__main__":
    iface.launch(share=True)