Spaces:
Build error
Build error
import os  # For file operations
import tempfile  # Unique per-download files for the Gradio download button

import gradio as gr
import torch  # Import torch for device management
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# For larger models, ensure you have sufficient VRAM (GPU memory).
# For CPU, smaller models might be necessary or use quantization.
MODEL_NAME = "google/flan-t5-large"  # Changed to 'large' for slightly better performance than 'base' and still manageable.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl".
# For even larger models, consider bfloat16 weights or bitsandbytes 4-bit loading if available.

# Resolved before the try so `device` is always bound for later
# `.to(device)` calls, even when model loading below fails.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    print(f"Loading model on device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if device == "cuda":
        # Half precision roughly halves VRAM use. T5-family checkpoints are
        # prone to overflowing float16's narrow range (inf/NaN activations);
        # bfloat16 keeps float32's exponent range at the same size, so use it
        # on GPUs with native bf16 (compute capability 8.0+) and fall back to
        # float16 on older cards.
        _load_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=_load_dtype).to(device)
    else:
        # CPU: keep the library's default dtype.
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    # Deliberate best-effort: keep the app importable. generate_content
    # returns a "Model not loaded" message when tokenizer/model are None.
    print(f"Error loading model: {e}")
    print("Please check your internet connection, model name, and available resources (RAM/VRAM).")
    tokenizer, model = None, None
# --- Prompt Engineering Functions (more structured) ---
def create_arabic_prompt(topic, style):
    """Return the Arabic instruction prompt for ``topic`` in the given UI ``style``.

    "Blog Post (Descriptive)" and "Social Media Post (Short & Catchy)" have
    dedicated templates; any other ``style`` value (including
    "Video Script (Storytelling)") receives the video-script template.
    """
    # (style label, template) pairs, checked in order; "{topic}" is filled last.
    styled_templates = (
        (
            "Blog Post (Descriptive)",
            "اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية.",
        ),
        (
            "Social Media Post (Short & Catchy)",
            "اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة.",
        ),
    )
    # Fallback template: Video Script (Storytelling).
    video_template = "اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."
    template = next(
        (text for label, text in styled_templates if style == label),
        video_template,
    )
    return template.format(topic=topic)
def create_english_prompt(topic, style):
    """Return the English instruction prompt for ``topic`` in the given UI ``style``.

    "Blog Post (Descriptive)" and "Social Media Post (Short & Catchy)" have
    dedicated templates; any other ``style`` value (including
    "Video Script (Storytelling)") receives the video-script template.
    """
    # (style label, template) pairs, checked in order; "{topic}" is filled last.
    styled_templates = (
        (
            "Blog Post (Descriptive)",
            "Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings.",
        ),
        (
            "Social Media Post (Short & Catchy)",
            "Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement.",
        ),
    )
    # Fallback template: Video Script (Storytelling).
    video_template = "Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."
    template = next(
        (text for label, text in styled_templates if style == label),
        video_template,
    )
    return template.format(topic=topic)
# --- Content Generation Function ---

# Generation budget per UI "Content Length" choice: (max_new_tokens, min_new_tokens).
# These count newly generated tokens (not words); the prompt fed to the
# encoder is truncated to 512 tokens separately in generate_content.
_LENGTH_BUDGETS = {
    "Short": (150, 50),
    "Medium": (300, 100),
    "Long": (450, 150),
}


def _ngram_block_size(diversity_penalty):
    """Translate the UI repetition slider (1-5) into ``no_repeat_ngram_size``.

    ``no_repeat_ngram_size=n`` forbids any n-gram from occurring twice in the
    output, so a *smaller* ``n`` is a *stronger* constraint (``1`` would ban
    every repeated token) and ``0`` disables the check. The slider is labelled
    "higher = less repetition, 1 = no penalty", so the value is inverted:
    1 -> 0 (off), 2 -> 5, 3 -> 4, 4 -> 3, 5 -> 2.
    """
    level = int(round(diversity_penalty))  # slider may deliver a float
    if level <= 1:
        return 0
    return max(2, 7 - level)


def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    """Generate text for the Gradio UI and return it, or a "⚠️" status message.

    Args:
        topic: Free-text subject; empty or whitespace-only input is rejected.
        style_choice: UI style label, forwarded to the prompt builder.
        lang_choice: "Arabic" selects the Arabic prompt; anything else English.
        length_choice: "Short", "Medium", or anything else (treated as "Long");
            selects the min/max number of newly generated tokens.
        creativity: Sampling temperature; above 0.8 also appends a creativity hint.
        detail_level: Top-p sampling value; above 0.7 also appends a detail hint.
        diversity_penalty: UI repetition level (1 = off, higher = less
            repetition), converted by ``_ngram_block_size``.

    Returns:
        The decoded model output, or a message starting with "⚠️" when the
        model is not loaded, the topic is blank, or generation raises.
    """
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."
    if not topic or not topic.strip():
        return "⚠️ Please enter a topic to generate content."
    topic = topic.strip()

    max_new_tokens, min_new_tokens = _LENGTH_BUDGETS.get(length_choice, _LENGTH_BUDGETS["Long"])
    temperature = creativity  # Direct mapping
    top_p = detail_level  # Direct mapping, higher means more detail/diversity
    no_repeat_ngram_size = _ngram_block_size(diversity_penalty)

    build_prompt = create_arabic_prompt if lang_choice == "Arabic" else create_english_prompt
    prompt = build_prompt(topic, style_choice)
    # NOTE(review): these steering hints are English even when the Arabic
    # prompt is used; confirm the language mix is intended.
    if detail_level > 0.7:  # Only if user explicitly wants high detail
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if creativity > 0.8:
        prompt += " Be highly creative and imaginative in your writing."

    try:
        # Inference only: no autograd bookkeeping, to save memory.
        with torch.no_grad():
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                num_beams=5,  # Beam search for better quality
                do_sample=True,  # Enable sampling for creativity
                temperature=temperature,
                top_p=top_p,
                top_k=50,  # Consider top 50 tokens
                no_repeat_ngram_size=no_repeat_ngram_size,
                length_penalty=1.0,  # Adjust to control output length
                early_stopping=True,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        if "out of memory" in str(e):
            # Return cached blocks so the next request has a chance to fit.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return "⚠️ Generation failed: Out of memory. Try a shorter length, a less complex model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {e}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {e}"
# --- Gradio Interface ---
# Custom CSS for a more polished look.
# NOTE(review): the `.gr-*` / `.gradio-radio` selectors look like Gradio 3.x
# class names; confirm they still match the DOM of the installed Gradio version.
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
    background-color: #F8F0FF; /* Light Lavender */
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #8A2BE2; /* Blue Violet */
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-size: 1.1em;
}
.gr-button:hover {
    background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
    border: 1px solid #DDA0DD; /* Plum */
    border-radius: 8px;
    padding: 10px;
}
.gradio-radio input:checked + label {
    background-color: #DA70D6 !important; /* Orchid */
    color: white !important;
}
.gradio-radio label {
    border: 1px solid #DDA0DD;
    border-radius: 8px;
    padding: 8px 15px;
}
"""


def download_file(content):
    """Prepare the download button for the current generated ``content``.

    When ``content`` is real output, write it to a fresh UTF-8 temporary
    ``.txt`` file and return a button update that points at it and is enabled.
    When ``content`` is empty or a "⚠️" status/error message, return an
    update that clears and disables the button.

    A unique temp file per call keeps concurrent users from overwriting each
    other's download, which a single fixed filename in the working directory
    would allow.
    """
    if not content or content.startswith("⚠️"):
        return gr.DownloadButton(value=None, interactive=False)
    with tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", suffix=".txt", prefix="generated_content_", delete=False
    ) as handle:
        handle.write(content)
    return gr.DownloadButton(value=handle.name, interactive=True)


with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")
    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused.",
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow for more diverse and detailed vocabulary. Lower values prune less likely words.",
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Repetition Penalty (N-gram)",
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty.",
                    )
        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True,
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True,
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True,
                )
            # The "Long" budget is 450 newly generated *tokens*, not words.
            gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 tokens)*")
    btn = gr.Button("🚀 Generate Content", variant="primary")
    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)
    # Starts disabled; download_file enables it once there is valid content.
    # (DownloadButton takes `value=`; it has no `file_path` keyword.)
    download_button = gr.DownloadButton("⬇️ Download Content", value=None, interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output,
    )
    # Refresh the download target (and its enabled state) whenever the text changes.
    output.change(fn=download_file, inputs=[output], outputs=[download_button])

if __name__ == "__main__":
    iface.launch()