# NOTE: removed page-scrape artifact ("Spaces / Sleeping / Sleeping") — not part of the source.
# main.py
# stdlib
import json
import logging
import os
from typing import List

# third-party
import autogen
import redis
from crewai import Agent, Task, Crew, Process
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from sqlalchemy.orm import Session

# local
from auth import create_access_token, get_current_user
from database import SessionLocal, engine, Base
from models import User, Query, Response
# Initialize FastAPI app
app = FastAPI(title="Zerodha Support System MVP")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the database tables on startup.  `Base` and `engine` were imported
# but never used: without this call the users/queries/responses tables do not
# exist and the first DB write in process_query fails.
Base.metadata.create_all(bind=engine)

# Initialize LLM client
hf_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")

# Initialize Redis client.  Host/port are environment-configurable so the same
# code runs in containers; the defaults preserve the original localhost:6379.
redis_client = redis.Redis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
    db=0,
)

# AutoGen configuration
# NOTE(review): agents target "gpt-3.5-turbo" while hf_client above targets
# Mistral — presumably two deliberate backends; confirm this is intended.
config_list = [{"model": "gpt-3.5-turbo"}]
# AutoGen Agents -------------------------------------------------------------
# All five assistants share the same llm_config; only the name and the system
# message differ, so build them through one small factory.
def _assistant(name: str, system_message: str) -> autogen.AssistantAgent:
    """Create an AssistantAgent wired to the shared config_list."""
    return autogen.AssistantAgent(
        name=name,
        system_message=system_message,
        llm_config={"config_list": config_list},
    )

# Categorizes incoming queries and judges priority/complexity.
query_analyzer = _assistant("QueryAnalyzer", "Analyze and categorize incoming customer queries for Zerodha support. Determine query priority and complexity.")
# Screens draft answers against regulation and internal policy.
compliance_agent = _assistant("ComplianceAgent", "Ensure all responses comply with financial regulations and Zerodha policies.")
# Keeps the knowledge base current from observed interactions.
kb_manager = _assistant("KnowledgeBaseManager", "Update and organize Zerodha's knowledge base based on customer interactions.")
# Scores customer sentiment for an interaction.
sentiment_analyzer = _assistant("SentimentAnalyzer", "Analyze customer sentiment from interactions.")
# Merges the individual agents' outputs into one reply.
coordinator = _assistant("Coordinator", "Coordinate responses from different agents and synthesize a final response.")
# CrewAI Agents --------------------------------------------------------------
# Every specialist is constructed the same way (verbose=True), so a single
# factory keeps the five definitions compact and uniform.
def _specialist(role: str, goal: str, backstory: str) -> Agent:
    """Create a verbose CrewAI agent from role/goal/backstory."""
    return Agent(role=role, goal=goal, backstory=backstory, verbose=True)

account_specialist = _specialist(
    'Account Specialist',
    'Handle account-related queries and processes',
    "Expert in Zerodha's account management systems and procedures.",
)
trading_expert = _specialist(
    'Trading Expert',
    'Assist with trading-related questions and provide market insights',
    "Seasoned trader with deep knowledge of Zerodha's trading platforms.",
)
technical_support = _specialist(
    'Technical Support',
    'Troubleshoot platform issues and provide technical guidance',
    'Technical expert familiar with all Zerodha platforms and common issues.',
)
learning_dev = _specialist(
    'Learning and Development',
    'Design educational content and trading tutorials',
    'Educational expert specializing in financial literacy and trading education.',
)
product_specialist = _specialist(
    'Product Specialist',
    "Provide information on Zerodha's products and compare with competitors",
    "Expert in Zerodha's product line and the broader financial services market.",
)
# CrewAI Tasks ---------------------------------------------------------------
# One task per specialist; each task is bound to its agent at construction.
# NOTE(review): newer CrewAI releases also require `expected_output` on Task —
# confirm against the pinned crewai version before upgrading.
account_task = Task(agent=account_specialist, description='Handle account-related query and provide detailed guidance')
trading_task = Task(agent=trading_expert, description='Address trading-related question and offer market insights')
tech_support_task = Task(agent=technical_support, description='Troubleshoot technical issue and provide step-by-step guidance')
learning_task = Task(agent=learning_dev, description='Create educational content based on user query and skill level')
product_task = Task(agent=product_specialist, description='Provide product information and recommendations')

# Assemble the crew: all five specialists with their tasks, max verbosity.
_crew_agents = [account_specialist, trading_expert, technical_support, learning_dev, product_specialist]
_crew_tasks = [account_task, trading_task, tech_support_task, learning_task, product_task]
zerodha_crew = Crew(agents=_crew_agents, tasks=_crew_tasks, verbose=2)
# Pydantic models
class QueryInput(BaseModel):
    """Request body for the query endpoint: the raw customer question."""

    text: str
class QueryOutput(BaseModel):
    """Response body: the synthesized answer plus a sentiment label."""

    response: str
    sentiment: str
# Dependency to get the database session
def get_db():
    """FastAPI dependency: yield one DB session, always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Helper function to generate LLM response
def generate_llm_response(prompt):
    """Run one text-generation call against the shared HF inference client."""
    # 200 new tokens at temperature 0.7 — the original tuning, unchanged.
    return hf_client.text_generation(
        prompt,
        max_new_tokens=200,
        temperature=0.7,
    )
# Helper function to check cache
def check_cache(query):
    """Return the cached result dict for *query*, or None on a miss.

    Robustness fix: a Redis outage or a corrupt cached payload previously
    raised out of this helper and failed the whole request; both now degrade
    gracefully to a cache miss.
    """
    try:
        cached_response = redis_client.get(query)
        if cached_response:
            return json.loads(cached_response)
        return None
    except (redis.RedisError, json.JSONDecodeError) as exc:
        logger.warning("Cache lookup failed, treating as miss: %s", exc)
        return None
# Helper function to update cache
def update_cache(query, response):
    """Best-effort write of *response* to Redis with a 1-hour TTL.

    Robustness fix: a Redis outage previously raised here and turned an
    otherwise successful query into a 500; cache writes are non-fatal now.
    """
    try:
        redis_client.setex(query, 3600, json.dumps(response))  # Cache for 1 hour
    except redis.RedisError as exc:
        logger.warning("Cache update failed: %s", exc)
# Main query processing function
async def process_query(query: str, db: Session):
    """Run a customer query through the full multi-agent pipeline.

    Pipeline: cache lookup -> analysis -> specialist routing -> compliance
    check -> coordination -> sentiment -> knowledge-base update -> DB
    persistence -> cache write.

    Returns a dict {"response": str, "sentiment": str}.
    Raises HTTPException(500) on any internal failure.
    """
    try:
        # Serve from cache when possible (keyed on the raw query text).
        cached_result = check_cache(query)
        if cached_result:
            logger.info(f"Cache hit for query: {query[:50]}...")
            return cached_result

        # Step 1: Query Analysis — categorize so the query can be routed.
        # NOTE(review): assumes generate_response returns a str — confirm
        # against the pinned autogen version.
        analysis = query_analyzer.generate_response(f"Analyze this query: {query}")

        # Step 2: Route to the specialist agents named in the analysis.
        specialist_responses = {}
        if "account" in analysis.lower():
            specialist_responses['account'] = account_specialist.execute(account_task, {"query": query})
        if "trading" in analysis.lower():
            specialist_responses['trading'] = trading_expert.execute(trading_task, {"query": query})
        if "technical" in analysis.lower():
            specialist_responses['technical'] = technical_support.execute(tech_support_task, {"query": query})
        if "product" in analysis.lower():
            specialist_responses['product'] = product_specialist.execute(product_task, {"query": query})

        # Fix: when the analysis matched no category, the coordinator used to
        # synthesize from an empty dict (an answer about nothing).  Fall back
        # to answering directly from the LLM instead.
        if not specialist_responses:
            specialist_responses['general'] = generate_llm_response(query)

        # Step 3: Compliance check on every specialist answer.
        for key in specialist_responses:
            specialist_responses[key] = compliance_agent.generate_response(f"Ensure this response is compliant: {specialist_responses[key]}")

        # Step 4: Coordinate the final response.
        final_response = coordinator.generate_response(f"Synthesize these responses into a final answer: {specialist_responses}")

        # Step 5: Sentiment analysis of the whole interaction.
        sentiment = sentiment_analyzer.generate_response(f"Analyze the sentiment of this interaction: Query: {query}, Response: {final_response}")

        # Step 6: Knowledge-base update (fire-and-forget; result unused).
        kb_manager.generate_response(f"Update knowledge base based on: Query: {query}, Response: {final_response}")

        # Step 7: Generate learning content only when the analysis asks for it.
        if "educational" in analysis.lower():
            learning_dev.execute(learning_task, {"query": query, "response": final_response})

        # Persist query and response.  The query row is committed/refreshed
        # first so its generated id can serve as the response's foreign key.
        db_query = Query(text=query)
        db.add(db_query)
        db.commit()
        db.refresh(db_query)
        db_response = Response(text=final_response, query_id=db_query.id)
        db.add(db_response)
        db.commit()

        result = {"response": final_response, "sentiment": sentiment}
        update_cache(query, result)
        return result
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred while processing your query")
# API Endpoints
# Bug fix: the handler had no route decorator, so it was never registered
# with the FastAPI app and the service exposed no query endpoint at all.
@app.post("/query", response_model=QueryOutput)
async def handle_query(query: QueryInput, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Authenticated endpoint: run one support query through the pipeline.

    Requires a valid bearer token (current_user dependency); returns the
    synthesized answer and its sentiment.
    """
    result = await process_query(query.text, db)
    return QueryOutput(**result)
# Run the application
if __name__ == "__main__":
    # Development entry point; in production launch via `uvicorn main:app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
# models.py
from sqlalchemy import Column, Integer, String, ForeignKey | |
from sqlalchemy.orm import relationship | |
from database import Base | |
class User(Base):
    """Registered support-system user; the password is stored only hashed."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    # Login name; unique and indexed for the auth lookup in get_current_user.
    username = Column(String, unique=True, index=True)
    hashed_password = Column(String)
class Query(Base):
    """A customer query; one query may accumulate several responses."""

    __tablename__ = "queries"

    id = Column(Integer, primary_key=True, index=True)
    text = Column(String)
    # One-to-many link to Response (mirrored by Response.query).
    responses = relationship("Response", back_populates="query")
class Response(Base):
    """One generated answer, linked back to the query it answers."""

    __tablename__ = "responses"

    id = Column(Integer, primary_key=True, index=True)
    text = Column(String)
    query_id = Column(Integer, ForeignKey("queries.id"))
    # Many-to-one link to Query (mirrored by Query.responses).
    query = relationship("Query", back_populates="responses")
# database.py
from sqlalchemy import create_engine | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker | |
SQLALCHEMY_DATABASE_URL = "sqlite:///./zerodha_support.db"

# check_same_thread=False lets the SQLite connection be shared across the
# threads FastAPI uses to serve requests.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


def get_db():
    """Yield a database session and guarantee it is closed afterwards.

    Bug fix: auth.py does `from database import get_db`, but this module never
    defined it (the only copy lived in main.py), so importing auth failed at
    startup.  Defining it here makes the import resolve.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# auth.py
from datetime import datetime, timedelta | |
from jose import JWTError, jwt | |
from passlib.context import CryptContext | |
from fastapi import Depends, HTTPException, status | |
from fastapi.security import OAuth2PasswordBearer | |
from sqlalchemy.orm import Session | |
from models import User | |
from database import get_db | |
import os

# Security fix: the JWT signing key was the hard-coded placeholder
# "your-secret-key".  Read it from the environment instead; the old value
# remains only as a development fallback so existing setups keep working.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# bcrypt hashing context and the OAuth2 password-flow token extractor.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Hash *password* with the configured bcrypt context for storage."""
    return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Return a signed JWT carrying *data* as claims.

    data: claims to embed (callers are expected to put the username under
        "sub", per get_current_user).
    expires_delta: optional custom token lifetime; when omitted the original
        ACCESS_TOKEN_EXPIRE_MINUTES default applies, so existing callers are
        unaffected.
    """
    # Local import: the module header only pulls in datetime/timedelta.
    from datetime import timezone

    to_encode = data.copy()
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now": datetime.utcnow() is naive and deprecated.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer *token* to a User row, or raise HTTP 401.

    The token's "sub" claim is the username; any decode failure, missing
    claim, or unknown user yields the same 401 to avoid leaking which part
    failed.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    username = claims.get("sub")
    if username is None:
        raise unauthorized
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise unauthorized
    return user