File size: 5,734 Bytes
59b4d60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ba80b7
bc0779e
1ba80b7
 
 
 
 
 
 
 
 
bc0779e
62b8fca
 
1a8c15c
b8810ee
bc0779e
1ba80b7
 
 
 
 
 
 
dd7662a
 
1ba80b7
 
 
 
 
97b410a
1ba80b7
 
0388a2a
 
97b410a
b8810ee
97b410a
 
 
 
 
 
 
 
 
 
 
b8810ee
97b410a
 
 
1a8c15c
 
 
 
97b410a
 
 
 
 
 
1a8c15c
 
97b410a
c56570d
1ba80b7
97b410a
 
c56570d
1ba80b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a8c15c
bc0779e
1ba80b7
 
bc0779e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ba80b7
bc0779e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97b410a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import subprocess
import sys

# Map each importable top-level module to the pip distribution that provides it.
# The two names differ for several packages (e.g. ``PIL`` ships in ``Pillow``),
# so installing the module name reported by ImportError would fail for those.
REQUIRED_PACKAGES = {
    "langchain": "langchain",
    "langchain_community": "langchain-community",
    "streamlit": "streamlit",
    "fastapi": "fastapi",
    "requests": "requests",
    "datasets": "datasets",
    "pinecone": "pinecone",
    "sentence_transformers": "sentence-transformers",
    "dotenv": "python-dotenv",
    "PIL": "Pillow",
    "anthropic": "anthropic",
}


def install_dependencies(packages=None):
    """Install every required third-party package that is not importable.

    Args:
        packages: Optional mapping of importable top-level module name to pip
            distribution name. Defaults to ``REQUIRED_PACKAGES``, which lists
            the third-party modules this script imports.

    Returns:
        The list of pip distribution names that were installed; empty when
        everything was already available.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
    """
    import importlib
    import importlib.util

    if packages is None:
        packages = REQUIRED_PACKAGES

    # find_spec locates a top-level module without importing it, so heavy
    # packages are not loaded just to probe them, and *every* missing package
    # is collected rather than only the first ImportError.
    missing = [
        dist for module, dist in packages.items()
        if importlib.util.find_spec(module) is None
    ]
    if not missing:
        return []

    print(f"Installing missing dependencies: {', '.join(missing)}")
    # One pip invocation lets pip resolve all missing packages together.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # Packages installed after start-up may be hidden by the import system's
    # cached directory listings until the caches are cleared.
    importlib.invalidate_caches()
    return missing

# Install dependencies before proceeding. Where packages are pre-installed
# (e.g. pinned containers), set SKIP_DEPENDENCY_INSTALL=1 in the real
# environment to avoid invoking pip at import time; note that .env is not
# loaded yet at this point, so the variable must not rely on it.
if os.getenv("SKIP_DEPENDENCY_INSTALL") != "1":
    install_dependencies()

import os
from typing import Any, List, Optional

import requests
import streamlit as st
from fastapi import FastAPI, HTTPException
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Pinecone
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from datasets import load_dataset
from dotenv import load_dotenv
from pinecone import Pinecone
from PIL import Image
# NOTE: the three `Pinecone` imports shadow one another; the last one (the
# LangChain community vector store) is the name used below, so keep this order.
from langchain_community.vectorstores import Pinecone
from pinecone import Pinecone as PineconeClient
from anthropic import Anthropic
from langchain.base_language import BaseLanguageModel
from langchain.llms.base import LLM

# Load environment variables (e.g. from a local .env file) into os.environ.
load_dotenv()

# Initialize FastAPI
app = FastAPI()

# API Keys / index configuration
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
# NOTE(review): PINECONE_ENV appears unused in this file; kept in case other
# code imports it.
PINECONE_ENV = os.getenv("PINECONE_ENV")
# The index name can be overridden from the environment; defaults to "agenticrag".
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "agenticrag")

# Fail fast with a clear message instead of a later, less obvious client error.
if not PINECONE_API_KEY:
    raise ValueError("Pinecone API Key is missing. Please set it in environment variables.")

# Initialize Hugging Face Embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# `Pinecone` here is the LangChain community vector-store class (the last of the
# `Pinecone` imports), not the pinecone SDK client; it wraps an existing index.
vector_store = Pinecone.from_existing_index(index_name=INDEX_NAME, embedding=embeddings)

# Custom Anthropic LLM Wrapper
class AnthropicLLM(LLM):
    """LangChain LLM backed by Anthropic's (legacy) Text Completions API.

    Subclasses LangChain's ``LLM`` base, which only requires ``_call`` and
    ``_llm_type``. ``BaseLanguageModel`` is an abstract interface with many
    abstract methods (``generate_prompt``, ``predict``, ...) that this class
    does not implement, so a direct subclass of it cannot be instantiated.

    Args:
        model: Anthropic model name passed to ``completions.create``.
        temperature: Sampling temperature.
        api_key: API key forwarded to the ``Anthropic`` client.
        **kwargs: Extra field values, e.g. ``max_tokens_to_sample``.
    """

    # LangChain models are pydantic models: state must be declared as fields,
    # because assigning undeclared attributes in __init__ is rejected.
    model: str
    temperature: float = 0.0
    max_tokens_to_sample: int = 1000
    anthropic_client: Any = None

    def __init__(self, model: str, temperature: float, api_key: Optional[str] = None, **kwargs: Any):
        super().__init__(model=model, temperature=temperature, **kwargs)
        self.anthropic_client = Anthropic(api_key=api_key)

    @staticmethod
    def _format_prompt(prompt: str) -> str:
        """Frame ``prompt`` as the "\\n\\nHuman: ...\\n\\nAssistant:" turn the API requires."""
        human, assistant = "\n\nHuman:", "\n\nAssistant:"
        text = prompt if prompt.startswith(human) else f"{human} {prompt}"
        if not text.rstrip().endswith(assistant.strip()):
            text = f"{text}{assistant}"
        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """Send one completion request and return the generated text.

        ``run_manager`` and ``**kwargs`` are accepted because ``LLM`` forwards
        them to ``_call``; they are not used here.
        """
        response = self.anthropic_client.completions.create(
            model=self.model,
            prompt=self._format_prompt(prompt),
            temperature=self.temperature,
            max_tokens_to_sample=self.max_tokens_to_sample,
            stop_sequences=stop or [],
        )
        return response.completion

    def count_tokens(self, text: str) -> int:
        """Return the token count for ``text`` as reported by the Anthropic client.

        NOTE(review): ``Anthropic.count_tokens`` exists only in older SDK
        releases; confirm it is available in the installed version.
        """
        return self.anthropic_client.count_tokens(text)

    @property
    def _identifying_params(self) -> dict:
        # Parameters that distinguish this LLM configuration (used by LangChain
        # for caching/tracing).
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens_to_sample": self.max_tokens_to_sample,
        }

    @property
    def _llm_type(self) -> str:
        return "anthropic"

# Initialize Anthropic LLM
llm = AnthropicLLM(
    model="claude-2",  # NOTE(review): legacy completions model; confirm it is still served.
    temperature=0,
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)

# Conversation memory for the chain. NOTE(review): this single instance is
# shared by every request the process handles, so all callers share one history.
# With return_source_documents=True the chain outputs both "answer" and
# "source_documents"; output_key tells the memory which one to store (without
# it, saving context fails with "One output key expected").
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    input_key="question",
    output_key="answer",
)

# Build RAG Chain
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vector_store.as_retriever(),
    memory=memory,
    return_source_documents=True,
)

@app.post("/query/")
async def query_agent(query: str):
    """Answer a legal question with the conversational RAG chain.

    ``query`` is a bare ``str`` parameter, so FastAPI reads it from the URL
    query string (``POST /query/?query=...``).

    Returns:
        ``{"response": answer}`` where ``answer`` is the chain's "answer" output.

    Raises:
        HTTPException: status 500 carrying the error text if the chain fails.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # `Chain.run` refuses chains with more than one output key, and this
        # chain also returns "source_documents", so invoke it and pick "answer".
        # The chain call is synchronous; running it in a worker thread keeps it
        # from blocking the event loop.
        # NOTE(review): qa_chain's memory is shared by all requests.
        result = await run_in_threadpool(qa_chain.invoke, {"question": query})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": result["answer"]}

@app.get("/")
def read_root():
    """Landing endpoint: return a fixed JSON welcome message."""
    welcome = "Welcome to the Agentic RAG Legal Assistant!"
    return dict(message=welcome)

def _ingest_dataset():
    """Load the US Congress dataset and upsert its texts into the Pinecone index.

    Returns:
        A ``(dataset, chunks)`` tuple: the loaded dataset and the list of text
        strings that were indexed.
    """
    ds = load_dataset("c4lliope/us-congress")
    # Assumes a "train" split with a "text" column; confirm against the dataset card.
    texts = [str(text) for text in ds["train"]["text"]]
    # LangChain's Pinecone vector store has no `upsert` method. `add_texts`
    # embeds the texts and upserts them in batches (instead of one oversized
    # request) and stores each text in metadata under the store's text key
    # ("text" by default). Row-index ids make re-runs overwrite, not duplicate.
    vector_store.add_texts(texts, ids=[str(i) for i in range(len(texts))])
    return ds, texts


# Load dataset and index it. This runs at import time in every process that
# loads this module (including each Streamlit rerun); set SKIP_DATASET_INGEST=1
# once the index is populated.
if os.getenv("SKIP_DATASET_INGEST") != "1":
    dataset, chunks = _ingest_dataset()

# Streamlit UI
# NOTE(review): the UI lives in the same module as the FastAPI app, so
# `streamlit run` also executes all of the model and index setup above.
st.set_page_config(page_title="LegalAI Assistant", layout="wide")

# NOTE(review): source.unsplash.com random-image URLs may no longer resolve;
# confirm they still serve images or switch to fixed asset URLs.
bg_image = "https://source.unsplash.com/1600x900/?law,court"
sidebar_image = "https://source.unsplash.com/400x600/?law,justice"

# Inject CSS for page and sidebar background images (doubled braces are
# literal braces inside the f-string).
st.markdown(
    f"""
    <style>
    .stApp {{
        background: url({bg_image}) no-repeat center center fixed;
        background-size: cover;
    }}
    .sidebar .sidebar-content {{
        background: url({sidebar_image}) no-repeat center center;
        background-size: cover;
    }}
    </style>
    """,
    unsafe_allow_html=True,
)

st.sidebar.title("⚖️ Legal AI Assistant")
st.sidebar.markdown("Your AI-powered legal research assistant.")

st.markdown("# 🏛️ Agentic RAG Legal Assistant")
st.markdown("### Your AI-powered assistant for legal research and case analysis.")

# List of (question, answer) tuples kept in session_state across reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_query = st.text_input("🔍 Enter your legal question:", "")
API_URL = "http://127.0.0.1:8000/query/"
# Seconds to wait for the backend before reporting an error.
REQUEST_TIMEOUT_SECONDS = 120

if st.button("Ask AI") and user_query:
    with st.spinner("Fetching response..."):
        try:
            # The /query/ endpoint declares `query: str`, which FastAPI reads
            # from the URL query string; a JSON body is rejected with HTTP 422.
            response = requests.post(
                API_URL,
                params={"query": user_query},
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
            if response.ok:
                ai_response = response.json().get("response", "Error: No response received.")
            else:
                # Surface the server's status and error detail.
                ai_response = f"Error: HTTP {response.status_code}: {response.text}"
        except Exception as e:  # UI boundary: show any failure to the user
            ai_response = f"Error: {e}"

    st.session_state.chat_history.append((user_query, ai_response))

st.markdown("---")
st.markdown("### 📜 Chat History")
for user_q, ai_r in st.session_state.chat_history:
    # Person + zero-width joiner + scales = "judge" emoji.
    st.markdown(f"**\U0001F9D1\u200d\u2696\ufe0f You:** {user_q}")
    st.markdown(f"**🤖 AI:** {ai_r}")
    st.markdown("---")

st.markdown("---")
st.markdown("🚀 Powered by Anthropic Claude, Pinecone, and LangChain.")