import spaces
import gradio as gr
from PyPDF2 import PdfReader
from io import BytesIO
from fpdf import FPDF  # FPDF is used to build the downloadable flashcard PDF
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
import os
import tempfile

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the language model and tokenizer once at startup."""
    global model, tokenizer
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model_name = "microsoft/phi-2"  # Smaller model that does not require quantization
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.float32  # Use float32 for CPU compatibility
    )
    if device == "cpu":
        model = model.to(device)
    print("Model loaded successfully")


def extract_text_from_pdf(pdf_file):
    """Extract plain text from the uploaded PDF (raw bytes)."""
    text = ""
    try:
        pdf_reader = PdfReader(BytesIO(pdf_file))
        for page in pdf_reader.pages:
            text += page.extract_text() or ""  # extract_text() may return None for empty pages
    except Exception as e:
        text = f"Fehler beim Lesen der PDF: {str(e)}"
    return text


def generate_flashcards(text):
    """Prompt the model for question/answer flashcards and parse them into (question, answer) tuples."""
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Modell wurde nicht geladen. Bitte versuchen Sie es erneut."

    prompt = (
        "Erstelle Karteikarten mit Frage und Antwort basierend auf dem folgenden Text. "
        "Formatiere jede Karteikarte als 'Frage: [Frage] Antwort: [Antwort]' und trenne "
        f"die Karteikarten mit einer Leerzeile:\n\n{text[:1000]}\n\nKarteikarten:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                num_return_sequences=1,
                do_sample=True,  # required for temperature to have an effect
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        flashcards_text = generated_text.split("Karteikarten:")[-1].strip()
        if not flashcards_text:
            return "Es konnten keine Karteikarten generiert werden. Bitte versuchen Sie es erneut."
        # Parse "Frage: ... Antwort: ..." pairs separated by blank lines
        flashcards = re.findall(r'Frage: (.*?)\s*Antwort: (.*?)(?:\n\n|\Z)', flashcards_text, re.DOTALL)
        return flashcards
    except Exception as e:
        return f"Fehler bei der Generierung der Karteikarten: {str(e)}"


def create_pdf(flashcards):
    """Write the generated flashcards to a PDF file and return its path."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for question, answer in flashcards:
        # The core FPDF fonts only support Latin-1; replace anything outside that range
        question = question.encode("latin-1", "replace").decode("latin-1")
        answer = answer.encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, f"Frage: {question}")
        pdf.multi_cell(0, 10, f"Antwort: {answer}")
        pdf.ln(10)  # Space between flashcards
    # Write to the system temp directory; /mnt/data is not available on Spaces
    pdf_file_path = os.path.join(tempfile.gettempdir(), "flashcards_output.pdf")
    pdf.output(pdf_file_path)
    return pdf_file_path


@spaces.GPU(duration=60)
def process_pdf(pdf_file):
    """Full pipeline: PDF bytes -> extracted text -> flashcards -> downloadable PDF."""
    if pdf_file is None:
        return [], None
    text = extract_text_from_pdf(pdf_file)
    if text.startswith("Fehler beim Lesen der PDF"):
        return [], None
    flashcards = generate_flashcards(text)
    if isinstance(flashcards, list):
        pdf_path = create_pdf(flashcards)
        return flashcards, pdf_path
    return [], None


def update_flashcard(flashcards, index, current_side):
    """Show either the question or the answer of the card at the given index."""
    if not flashcards or index >= len(flashcards):
        return gr.update(visible=False), gr.update(visible=False)
    question, answer = flashcards[index]
    if current_side == "question":
        return gr.update(visible=True, value=question), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True, value=answer)


def flip_card(flashcards, index, current_side):
    """Toggle between the question and answer side of the current card."""
    new_side = "answer" if current_side == "question" else "question"
    return update_flashcard(flashcards, index, new_side) + (new_side,)


with gr.Blocks() as iface:
    gr.Markdown("# FlashcardAI")
    gr.Markdown("Laden Sie eine PDF hoch und klicken Sie auf 'Karteikarten generieren', um zu beginnen.")

    with gr.Row():
        pdf_input = gr.File(label="PDF hochladen", type="binary")
        generate_button = gr.Button("Karteikarten generieren", variant="primary")

    # Session state: parsed flashcards, current card index, and which side is shown
    flashcards_state = gr.State([])
    current_card_index = gr.State(0)
    current_side = gr.State("question")

    with gr.Row():
        prev_button = gr.Button("Vorherige Karte")
        flip_button = gr.Button("Karte umdrehen")
        next_button = gr.Button("Nächste Karte")

    with gr.Row():
        question_box = gr.Textbox(label="Frage", interactive=False)
        answer_box = gr.Textbox(label="Antwort", interactive=False, visible=False)

    pdf_download_button = gr.File(label="Download Flashcards PDF")

    generate_button.click(
        process_pdf,
        inputs=[pdf_input],
        outputs=[flashcards_state, pdf_download_button]
    ).then(
        lambda cards: update_flashcard(cards, 0, "question") + (0, "question"),
        inputs=[flashcards_state],
        outputs=[question_box, answer_box, current_card_index, current_side]
    )

    flip_button.click(
        flip_card,
        inputs=[flashcards_state, current_card_index, current_side],
        outputs=[question_box, answer_box, current_side]
    )

    prev_button.click(
        lambda cards, index, side: update_flashcard(cards, max(0, index - 1), "question") + (max(0, index - 1), "question"),
        inputs=[flashcards_state, current_card_index, current_side],
        outputs=[question_box, answer_box, current_card_index, current_side]
    )

    next_button.click(
        lambda cards, index, side: update_flashcard(cards, min(len(cards) - 1, index + 1), "question") + (min(len(cards) - 1, index + 1), "question"),
        inputs=[flashcards_state, current_card_index, current_side],
        outputs=[question_box, answer_box, current_card_index, current_side]
    )


if __name__ == "__main__":
    load_model()  # Load the model at startup
    iface.launch()