# LegalAssist-RAG / app.py
import os
import subprocess
import sys
# Function to install missing dependencies at runtime.
def install_dependencies():
    # Map importable module names to the pip package names that provide them
    # (they differ for python-dotenv, Pillow, sentence-transformers, etc.).
    required = {
        "langchain": "langchain",
        "streamlit": "streamlit",
        "fastapi": "fastapi",
        "requests": "requests",
        "datasets": "datasets",
        "pinecone": "pinecone-client",
        "sentence_transformers": "sentence-transformers",
        "dotenv": "python-dotenv",
        "PIL": "Pillow",
    }
    for module, package in required.items():
        try:
            __import__(module)
        except ImportError:
            print(f"Installing missing dependency: {package}")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Install dependencies before proceeding
install_dependencies()
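# (Installing at runtime like this is a stopgap for environments that don't
# honor a requirements.txt; pinning dependencies there is the usual fix.)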
from typing import Any

import requests
import streamlit as st
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from datasets import load_dataset
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Pinecone
from pinecone import Pinecone as PineconeClient
# Load environment variables
load_dotenv()
# Initialize FastAPI
app = FastAPI()
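# NOTE: FastAPI and the Streamlit UI live in one file here; the API only
# answers requests if this module is also served separately, e.g.
#   uvicorn app:app --port 8000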
# API Keys
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")
INDEX_NAME = "agenticrag"
if not PINECONE_API_KEY:
    raise ValueError("Pinecone API Key is missing. Please set it in environment variables.")
# Initialize Hugging Face Embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = Pinecone.from_existing_index(index_name=INDEX_NAME, embedding=embeddings)
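# Assumes the "agenticrag" index already exists in Pinecone with dimension 384
# (the embedding size of all-MiniLM-L6-v2).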
# Custom Anthropic LLM wrapper. Subclassing LangChain's concrete `LLM` base
# (rather than the abstract BaseLanguageModel) means only `_call` and
# `_llm_type` need to be implemented.
class AnthropicLLM(LLM):
    model: str
    temperature: float
    anthropic_client: Any = None

    def __init__(self, model: str, temperature: float, api_key: str, **kwargs):
        super().__init__(model=model, temperature=temperature, **kwargs)
        self.anthropic_client = Anthropic(api_key=api_key)

    def _call(self, prompt: str, stop: list = None, run_manager=None, **kwargs) -> str:
        # The legacy Completions API expects the Human/Assistant prompt format.
        response = self.anthropic_client.completions.create(
            model=self.model,
            prompt=f"{HUMAN_PROMPT} {prompt}{AI_PROMPT}",
            temperature=self.temperature,
            max_tokens_to_sample=1000,  # Adjust as needed
            stop_sequences=stop or [],
        )
        return response.completion

    def count_tokens(self, text: str) -> int:
        return self.anthropic_client.count_tokens(text)

    @property
    def _llm_type(self) -> str:
        return "anthropic"
# Initialize Anthropic LLM
llm = AnthropicLLM(
    model="claude-2",
    temperature=0,
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)
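# NOTE: os.getenv returns None when ANTHROPIC_API_KEY is unset; consider
# failing fast here, as is done for the Pinecone key above.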
# Initialize memory. output_key="answer" tells the memory which of the
# chain's two outputs (answer, source_documents) to store.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
# Build RAG Chain
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vector_store.as_retriever(),
    memory=memory,
    return_source_documents=True,
)
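# Each call condenses the question against the chat history, retrieves the
# most similar chunks from Pinecone, and asks Claude to answer from them.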
@app.post("/query/")
async def query_agent(query: str):
    try:
        # .run() only supports single-output chains; with
        # return_source_documents=True the chain returns both the answer and
        # its sources, so invoke it with an input dict instead.
        result = qa_chain({"question": query})
        return {"response": result["answer"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
def read_root():
    return {"message": "Welcome to the Agentic RAG Legal Assistant!"}
# Load and index the corpus.
# NOTE: embedding and upserting the full dataset on every startup is slow and
# costly; in practice this step should run once, offline.
dataset = load_dataset("c4lliope/us-congress")
chunks = [str(text) for text in dataset["train"]["text"]]
embedding_vectors = embeddings.embed_documents(chunks)
pinecone_data = [(str(i), embedding_vectors[i], {"text": chunks[i]}) for i in range(len(chunks))]
# Upsert through the Pinecone client; the LangChain vector-store wrapper does
# not expose upsert() itself.
pc = PineconeClient(api_key=PINECONE_API_KEY)
pc.Index(INDEX_NAME).upsert(vectors=pinecone_data)
# Streamlit UI
st.set_page_config(page_title="LegalAI Assistant", layout="wide")
bg_image = "https://source.unsplash.com/1600x900/?law,court"
sidebar_image = "https://source.unsplash.com/400x600/?law,justice"
st.markdown(
    f"""
    <style>
    .stApp {{
        background: url({bg_image}) no-repeat center center fixed;
        background-size: cover;
    }}
    .sidebar .sidebar-content {{
        background: url({sidebar_image}) no-repeat center center;
        background-size: cover;
    }}
    </style>
    """,
    unsafe_allow_html=True,
)
st.sidebar.title("⚖️ Legal AI Assistant")
st.sidebar.markdown("Your AI-powered legal research assistant.")
st.markdown("# πŸ›οΈ Agentic RAG Legal Assistant")
st.markdown("### Your AI-powered assistant for legal research and case analysis.")
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
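# st.session_state persists across Streamlit reruns, so the chat history
# survives each button click.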
user_query = st.text_input("🔍 Enter your legal question:", "")
API_URL = "http://127.0.0.1:8000/query/"
if st.button("Ask AI") and user_query:
    with st.spinner("Fetching response..."):
        try:
            # The FastAPI endpoint reads a bare `str` parameter from the query
            # string, so send it as params rather than as a JSON body.
            response = requests.post(API_URL, params={"query": user_query})
            response_json = response.json()
            ai_response = response_json.get("response", "Error: No response received.")
        except Exception as e:
            ai_response = f"Error: {e}"
        st.session_state.chat_history.append((user_query, ai_response))
st.markdown("---")
st.markdown("### πŸ“œ Chat History")
for user_q, ai_r in st.session_state.chat_history:
    st.markdown(f"**🧑‍⚖️ You:** {user_q}")
    st.markdown(f"**🤖 AI:** {ai_r}")
    st.markdown("---")
st.markdown("---")
st.markdown("🚀 Powered by Anthropic Claude, Pinecone, and LangChain.")