# reddit-chatbot / chatbot.py
# (GitHub metadata: user ahmadgenus, branch new_chatbot_minor_chnages_2, commit a1a35ae)
from langchain.chains import LLMChain
import os
import sqlite3
import praw
import json
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferMemory
load_dotenv()
# Initialize the LLM via LangChain (using Groq)
llm = ChatGroq(
groq_api_key=os.getenv("GROQ_API_KEY"),
model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
temperature=0.2
)
# Embedding Model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Reddit API Setup
reddit = praw.Reddit(
client_id=os.getenv("REDDIT_CLIENT_ID"),
client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
user_agent=os.getenv("REDDIT_USER_AGENT")
)
# SQLite DB Connection
def get_db_conn():
return sqlite3.connect("reddit_data.db", check_same_thread=False)
# Set up the database schema
def setup_db():
conn = get_db_conn()
cur = conn.cursor()
try:
cur.execute("""
CREATE TABLE IF NOT EXISTS reddit_posts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
reddit_id TEXT UNIQUE,
keyword TEXT,
title TEXT,
post_text TEXT,
comments TEXT,
created_at TEXT,
embedding TEXT,
metadata TEXT
);
""")
conn.commit()
except Exception as e:
print("DB Setup Error:", e)
finally:
cur.close()
conn.close()
# Keyword filter
def keyword_in_post_or_comments(post, keyword):
keyword_lower = keyword.lower()
combined_text = (post.title + " " + post.selftext).lower()
if keyword_lower in combined_text:
return True
post.comments.replace_more(limit=None)
for comment in post.comments.list():
if keyword_lower in comment.body.lower():
return True
return False
# Fetch and process Reddit data
def fetch_reddit_data(keyword, days=7, limit=None):
end_time = datetime.utcnow()
start_time = end_time - timedelta(days=days)
subreddit = reddit.subreddit("all")
posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
data = []
for post in posts_generator:
created = datetime.utcfromtimestamp(post.created_utc)
if created < start_time:
break
if not keyword_in_post_or_comments(post, keyword):
continue
post.comments.replace_more(limit=None)
comments = [comment.body for comment in post.comments.list()]
combined_text = f"{post.title}\n{post.selftext}\n{' '.join(comments)}"
embedding = embedder.encode(combined_text).tolist()
metadata = {
"url": post.url,
"subreddit": post.subreddit.display_name,
"comments_count": len(comments)
}
data.append({
"reddit_id": post.id,
"keyword": keyword,
"title": post.title,
"post_text": post.selftext,
"comments": comments,
"created_at": created.isoformat(),
"embedding": embedding,
"metadata": metadata
})
if data:
save_to_db(data)
# Save data into SQLite
def save_to_db(posts):
conn = get_db_conn()
cur = conn.cursor()
for post in posts:
try:
cur.execute("""
INSERT OR IGNORE INTO reddit_posts
(reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
""", (
post["reddit_id"],
post["keyword"],
post["title"],
post["post_text"],
json.dumps(post["comments"]),
post["created_at"],
json.dumps(post["embedding"]),
json.dumps(post["metadata"])
))
except Exception as e:
print("Insert Error:", e)
conn.commit()
cur.close()
conn.close()
# Retrieve similar context from DB
def retrieve_context(question, keyword, reddit_id=None, top_k=10):
lower_q = question.lower()
requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
conn = get_db_conn()
cur = conn.cursor()
if reddit_id:
cur.execute("""
SELECT title, post_text, comments FROM reddit_posts
WHERE reddit_id = ?;
""", (reddit_id,))
else:
cur.execute("""
SELECT title, post_text, comments FROM reddit_posts
WHERE keyword = ? ORDER BY datetime(created_at) DESC LIMIT ?;
""", (keyword, requested_top_k))
results = cur.fetchall()
cur.close()
conn.close()
return results
# Summarizer
summarize_prompt = ChatPromptTemplate.from_template("""
You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.
Context:
{context}
Summary:
""")
summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
# Chatbot memory and prompt
memory = ConversationBufferMemory(memory_key="chat_history")
chat_prompt = ChatPromptTemplate.from_template("""
Chat History:
{chat_history}
Context from Reddit and User Question:
{input}
Act as a Professional Assistant as an incremental chat agent. Provide reasoning and answer clearly based on the context and chat history. Your response should be valid, concise, Attractive and relevant.
""")
chat_chain = LLMChain(
llm=llm,
prompt=chat_prompt,
memory=memory,
verbose=True
)
# Chatbot response
def get_chatbot_response(question, keyword, reddit_id=None):
context_posts = retrieve_context(question, keyword, reddit_id)
context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
if len(context) > 3000:
context = summarize_chain.invoke({"context": context})
combined_input = f"Context:\n{context}\n\nUser Question: {question}"
response = chat_chain.invoke({"input": combined_input})
return response, context_posts