# Blog_generator / app.py
# Author: parjun
# Commit 05411e4 (verified): "Rename bloggenpart2.py to app.py"
import os
import re
from typing import Dict, List, Tuple, Any, Optional

import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
# Pull variables from a local .env file, if one exists; a key found there is
# still picked up when nothing is typed into the sidebar below.
load_dotenv()

# Page-wide settings: browser tab title and wide layout.
st.set_page_config(page_title="AI Blog Generator", layout="wide")

# Sidebar: choose the LLM provider and model, and supply an API key.
# NOTE(review): os.environ is shared by the whole Streamlit server process, so
# a key typed here by one visitor is also visible to every other session.
# Consider passing the key straight to the model constructor instead.
with st.sidebar:
    st.title("Configuration")

    provider = st.radio("LLM Provider", ["OpenAI", "Groq"])

    if provider != "OpenAI":  # Groq
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            help="Enter your Groq API key here",
        )
        model = st.selectbox(
            "Model",
            [
                "llama-3.3-70b-versatile",
                "gemma2-9b-it",
                "qwen-2.5-32b",
                "mistral-saba-24b",
                "deepseek-r1-distill-qwen-32b",
            ],
        )
        if groq_api_key:
            os.environ.update(GROQ_API_KEY=groq_api_key)
    else:
        openai_api_key = st.text_input(
            "OpenAI API Key",
            type="password",
            help="Enter your OpenAI API key here",
        )
        model = st.selectbox("Model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"])
        if openai_api_key:
            os.environ.update(OPENAI_API_KEY=openai_api_key)

    st.divider()

    # Static "about" panel.
    st.write("## About")
    st.write("This app uses LangGraph to generate structured blog posts with a multi-step workflow.")
    st.write("Made with ❤️ using LangGraph and Streamlit")
# Define the state schema
class BlogGeneratorState(BaseModel):
    """State carried between the nodes of the blog-generation graph.

    Each node returns a fresh instance with its own output field filled in.
    A node that fails sets ``error``; later nodes see it and skip their work.
    """

    # Inputs taken from the Streamlit form.
    topic: str = ""
    audience: str = ""
    tone: str = ""
    word_count: int = 500

    # Section headings produced by the outline step.
    outline: List[str] = Field(default_factory=list)
    # Section heading -> generated body text.
    sections: Dict[str, str] = Field(default_factory=dict)
    # Final assembled Markdown post.
    final_blog: str = ""
    # Human-readable error message from the first node that failed, if any.
    error: Optional[str] = None
# Initialize LLM based on selected provider
def get_llm():
    """Build the chat model selected in the sidebar.

    Reads the module-level ``provider`` and ``model`` values chosen by the
    sidebar widgets. They are only read, never assigned, so no ``global``
    statement is needed. If the API-key environment variable for the chosen
    provider is unset, an error is shown on the page and the current script
    run is halted with ``st.stop()``.

    Returns:
        A ``ChatOpenAI`` or ``ChatGroq`` instance at temperature 0.7.
    """
    if provider == "OpenAI":
        if not os.environ.get("OPENAI_API_KEY"):
            st.error("Please enter your OpenAI API key in the sidebar")
            st.stop()
        return ChatOpenAI(model=model, temperature=0.7)

    # Groq
    if not os.environ.get("GROQ_API_KEY"):
        st.error("Please enter your Groq API key in the sidebar")
        st.stop()
    return ChatGroq(model=model, temperature=0.7)
# Create prompt templates

# Outline step. Variables: topic, audience, tone, word_count.
# The reply is split on newlines by get_outline, so ask for one heading per line.
outline_prompt = ChatPromptTemplate.from_template(
    """You are a professional blog writer. Create an outline for a blog post about {topic}.
The audience is {audience} and the tone should be {tone}.
The blog should be approximately {word_count} words.
Return ONLY the outline as a list of section headings (without numbers or bullets).
Put each heading on its own line.
Each heading should be concise and engaging."""
)

# Per-section step. Variables: topic, section, audience, tone, section_word_count.
section_prompt = ChatPromptTemplate.from_template(
    """Write content for the following section of a blog post about {topic}:
Section: {section}
The audience is {audience} and the tone should be {tone}.
Make this section approximately {section_word_count} words.
Make the content engaging, informative, and valuable to the reader.
Return ONLY the content for this section, without the heading."""
)

# Final assembly step. Variables: sections_content, tone, audience.
# Like the other prompts, it asks for the output alone so no chatty preamble
# ends up in the rendered / downloaded Markdown. The outline may already
# contain an introduction or conclusion, so the prompt asks the model not to
# add a duplicate.
final_assembly_prompt = ChatPromptTemplate.from_template(
    """You have a blog post with the following sections:
{sections_content}
Format this into a complete, professional blog post in Markdown format with:
1. An engaging title at the top as an H1 heading
2. A brief introduction before the first section
3. Each section heading as an H2
4. A conclusion at the end
5. Proper spacing between sections
6. 2-3 relevant markdown formatting elements like bold, italic, blockquotes, or bullet points where appropriate
If the sections already include an introduction or a conclusion, use those instead of adding another.
The blog should maintain the {tone} tone and be targeted at {audience}.
Make it flow naturally between sections.
Return ONLY the finished blog post in Markdown, with no preamble or commentary."""
)
# Define the nodes for the graph

# Leading list markers that models often emit despite being told not to:
# "1." / "1)" numbering, "-" / "*" / "+" / "•" bullets, and "#" headings.
# A trailing space is required so headings such as "2024 Trends" survive.
_OUTLINE_MARKER_RE = re.compile(r"^(?:\d+[.)]|[-*+•]|#+)\s+")


def _parse_outline(text: str) -> List[str]:
    """Split an LLM outline reply into unique, cleaned section headings."""
    headings = []
    for raw_line in text.split("\n"):
        heading = _OUTLINE_MARKER_RE.sub("", raw_line.strip()).strip()
        # Unwrap a heading the model bolded as a whole, e.g. "**Intro**".
        if len(heading) > 4 and heading.startswith("**") and heading.endswith("**"):
            heading = heading[2:-2].strip()
        if heading:
            headings.append(heading)
    # Drop repeats (keeping the first occurrence). Sections are stored in a
    # dict keyed by heading, so a duplicate would only waste an LLM call.
    return list(dict.fromkeys(headings))


def get_outline(state: BlogGeneratorState) -> BlogGeneratorState:
    """Generate an outline for the blog post.

    Sends topic/audience/tone/word_count to ``outline_prompt`` and parses the
    reply into one heading per non-blank line, with list markers removed.

    Returns:
        A new state with ``outline`` populated. Instead, ``error`` is set if
        the LLM call fails or the reply contains no usable headings; an empty
        outline would otherwise break the per-section word budget downstream.
        The error is not shown here: the UI displays the ``error`` field from
        the graph result, as it does for the other nodes.
    """
    try:
        llm = get_llm()
        chain = outline_prompt | llm
        response = chain.invoke({
            "topic": state.topic,
            "audience": state.audience,
            "tone": state.tone,
            "word_count": state.word_count,
        })
        outline = _parse_outline(response.content)
    except Exception as e:
        return BlogGeneratorState(**{**state.model_dump(), "error": f"Error generating outline: {str(e)}"})

    if not outline:
        return BlogGeneratorState(**{
            **state.model_dump(),
            "error": "Error generating outline: the model returned no section headings",
        })
    return BlogGeneratorState(**{**state.model_dump(), "outline": outline})
def generate_sections(state: BlogGeneratorState) -> BlogGeneratorState:
    """Generate content for each section in the outline.

    The total word budget is split evenly across sections, and one LLM call is
    made per heading. A Streamlit progress bar and status line are shown while
    sections are generated. They are removed afterwards, including when a
    section fails.

    Returns:
        ``state`` unchanged if an earlier node already set ``error``;
        otherwise a new state with ``sections`` (heading -> text) filled in,
        or with ``error`` set if the outline is empty or an LLM call fails.
    """
    if state.error:
        return state

    section_total = len(state.outline)
    if section_total == 0:
        # Guard the division below; without this an empty outline raised
        # ZeroDivisionError outside the try block and escaped the node.
        return BlogGeneratorState(**{
            **state.model_dump(),
            "error": "Error generating sections: the outline is empty",
        })
    # Never ask for fewer than one word per section.
    section_word_count = max(1, state.word_count // section_total)

    sections: Dict[str, str] = {}
    try:
        llm = get_llm()
        chain = section_prompt | llm
        # Show progress
        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            for i, section in enumerate(state.outline, start=1):
                status_text.text(f"Generating section {i}/{section_total}: {section}")
                response = chain.invoke({
                    "topic": state.topic,
                    "section": section,
                    "audience": state.audience,
                    "tone": state.tone,
                    "section_word_count": section_word_count,
                })
                sections[section] = response.content
                progress_bar.progress(i / section_total)
        finally:
            # Clear the progress widgets even if a section failed, so a stale
            # bar is not left on the page next to the error message.
            status_text.empty()
            progress_bar.empty()
        return BlogGeneratorState(**{**state.model_dump(), "sections": sections})
    except Exception as e:
        return BlogGeneratorState(**{**state.model_dump(), "error": f"Error generating sections: {str(e)}"})
def assemble_blog(state: BlogGeneratorState) -> BlogGeneratorState:
    """Merge the generated sections into one finished Markdown post.

    When an earlier node already recorded an error, the state is passed
    through untouched. Otherwise every section is listed as
    "Section: <heading>" / "Content: <text>", and ``final_assembly_prompt``
    asks the LLM to produce the complete post.

    Returns:
        A new state with ``final_blog`` set, or with ``error`` set if the LLM
        call (or rebuilding the state) fails.
    """
    if state.error:
        return state

    try:
        assembler = final_assembly_prompt | get_llm()
        section_blocks = (
            f"Section: {heading}\nContent: {content}"
            for heading, content in state.sections.items()
        )
        reply = assembler.invoke(
            {
                "sections_content": "\n\n".join(section_blocks),
                "tone": state.tone,
                "audience": state.audience,
            }
        )
        return BlogGeneratorState(**dict(state.model_dump(), final_blog=reply.content))
    except Exception as exc:
        return BlogGeneratorState(
            **dict(state.model_dump(), error=f"Error assembling blog: {exc}")
        )
# Define the workflow graph
def create_blog_generator_graph():
    """Build and compile the linear three-step blog workflow.

    Steps run strictly in order:
    get_outline -> generate_sections -> assemble_blog -> END,
    with ``get_outline`` as the entry point.
    """
    pipeline = [
        ("get_outline", get_outline),
        ("generate_sections", generate_sections),
        ("assemble_blog", assemble_blog),
    ]

    graph = StateGraph(BlogGeneratorState)
    for step_name, step_fn in pipeline:
        graph.add_node(step_name, step_fn)

    # Chain each step to the next one, and the last step to END.
    step_names = [step_name for step_name, _ in pipeline]
    for upstream, downstream in zip(step_names, step_names[1:] + [END]):
        graph.add_edge(upstream, downstream)

    graph.set_entry_point(step_names[0])
    return graph.compile()
def _download_file_name(blog_topic: str) -> str:
    """Return a filesystem-safe ``<slug>_blog.md`` name for *blog_topic*."""
    # Collapse every run of characters that are not word characters or
    # hyphens (spaces, slashes, colons, question marks, ...) into one "_".
    slug = re.sub(r"[^\w-]+", "_", blog_topic.strip().lower()).strip("_")
    return f"{slug or 'blog'}_blog.md"


# Create the Streamlit app main content
st.title("AI Blog Generator")
st.write("Generate professional blog posts with a structured workflow")

with st.form("blog_generator_form"):
    topic = st.text_input("Blog Topic", placeholder="E.g., Sustainable Living in Urban Environments")
    col1, col2 = st.columns(2)
    with col1:
        audience = st.text_input("Target Audience", placeholder="E.g., Young professionals")
        tone = st.selectbox("Tone", ["Informative", "Conversational", "Professional", "Inspirational", "Technical"])
    with col2:
        word_count = st.slider("Approximate Word Count", min_value=300, max_value=2000, value=800, step=100)
    submit_button = st.form_submit_button("Generate Blog")

if submit_button:
    # Treat whitespace-only input as missing.
    topic = topic.strip()
    audience = audience.strip()

    if (provider == "OpenAI" and not os.environ.get("OPENAI_API_KEY")) or \
            (provider == "Groq" and not os.environ.get("GROQ_API_KEY")):
        st.error(f"Please enter your {provider} API key in the sidebar before generating a blog")
    elif not topic or not audience:
        st.error("Please fill out all required fields.")
    else:
        try:
            # Only the generation itself runs under the spinner; results are
            # rendered after it finishes.
            with st.spinner(f"Initializing blog generation using {provider} {model}..."):
                blog_generator = create_blog_generator_graph()
                initial_state = BlogGeneratorState(
                    topic=topic,
                    audience=audience,
                    tone=tone,
                    word_count=word_count,
                )
                result = blog_generator.invoke(initial_state)

            # The compiled graph is expected to return the final state as a dict.
            if not isinstance(result, dict):
                st.error(f"Unexpected result type: {type(result)}")
            elif result.get("error"):
                st.error(f"Error: {result.get('error')}")
            elif not result.get("final_blog", ""):
                st.error("Blog generation completed but no content was produced")
            else:
                final_blog = result["final_blog"]
                outline = result.get("outline", [])

                st.success("Blog post generated successfully!")
                st.subheader("Generated Blog Post")
                st.markdown(final_blog)

                st.download_button(
                    label="Download Blog as Markdown",
                    data=final_blog,
                    # Sanitized so characters like "/" or ":" in the topic
                    # cannot produce an invalid or path-like file name.
                    file_name=_download_file_name(topic),
                    mime="text/markdown",
                )

                st.info(f"Generated using {provider} {model}")

                with st.expander("View Blog Outline"):
                    for i, section in enumerate(outline, 1):
                        st.write(f"{i}. {section}")
        except Exception as e:
            # Top-level boundary: surface any unexpected failure to the user.
            st.error(f"An error occurred: {str(e)}")
            st.info("Please check your API key and try again.")