Ari
committed on
Update app.py
app.py CHANGED
@@ -3,17 +3,15 @@ import streamlit as st
 import pandas as pd
 import sqlite3
 import openai
-from transformers import
+from transformers import pipeline  # Using Hugging Face pipeline for memory-efficient loading
 from langchain import OpenAI
-from
-from
-from
-from
-from
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_community.agent_toolkits.sql.base import create_sql_agent
+from langchain_community.utilities import SQLDatabase
+from langchain_community.document_loaders import CSVLoader
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import OpenAIEmbeddings
 import sqlparse
 
-
 # OpenAI API key (ensure it is securely stored)
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
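A note on the import changes: the commit migrates the LangChain imports to their langchain_community homes. A minimal sketch of how these pieces wire together, assuming langchain, langchain-community, and openai are installed; the toolkit-based agent construction and the sample question are illustrative, not taken from the commit:

    from langchain import OpenAI
    from langchain_community.agent_toolkits import SQLDatabaseToolkit
    from langchain_community.agent_toolkits.sql.base import create_sql_agent
    from langchain_community.utilities import SQLDatabase

    llm = OpenAI(temperature=0)                      # deterministic output for SQL generation
    db = SQLDatabase.from_uri("sqlite:///:memory:")  # in-memory SQLite, as in app.py
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)     # bundles schema-inspection and query tools
    agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
    agent.run("How many tables are in the database?")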
@@ -37,9 +35,8 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
 # SQL table metadata (for validation and schema)
 valid_columns = list(data.columns)
 
-# Step 3: Use LLaMA for context retrieval (RAG)
-
-llama_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
+# Step 3: Use a smaller LLaMA for context retrieval (RAG)
+llama_pipeline = pipeline("text-generation", model="huggyllama/llama-2-3b-hf", device=0)  # Use smaller model
 
 # Step 4: Implement RAG with FAISS for vectorized document retrieval
 embeddings = OpenAIEmbeddings()  # You can use other embeddings if preferred
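On Step 3: the diff swaps a full AutoModelForCausalLM load of llama-7b for a transformers pipeline. One caveat: "huggyllama/llama-2-3b-hf" does not look like a published checkpoint (huggyllama hosts llama-7b and larger), so the load may fail at runtime. A runnable sketch with a small public model standing in:

    from transformers import pipeline

    # "distilgpt2" is a stand-in so the sketch runs anywhere; swap in the
    # intended LLaMA checkpoint once its Hub id is confirmed.
    generator = pipeline("text-generation", model="distilgpt2", device=-1)  # -1 = CPU, 0 = first GPU
    out = generator("Table schema:", max_length=40, num_return_sequences=1)
    print(out[0]["generated_text"])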
@@ -50,8 +47,6 @@ documents = loader.load()
 vector_store = FAISS.from_documents(documents, embeddings)
 retriever = vector_store.as_retriever()
 
-rag_chain = RetrievalQA.from_chain_type(llama_model, retriever=retriever)
-
 # Step 5: OpenAI for SQL query generation based on user prompt and context
 openai_llm = OpenAI(temperature=0)
 db = SQLDatabase.from_uri('sqlite:///:memory:')  # Create an SQLite database for LangChain
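On the removed rag_chain: dropping the RetrievalQA chain means nothing in this hunk consumes the FAISS retriever anymore. If the retrieved documents are still wanted, the retriever can be queried directly; a sketch assuming vector_store exists as built above, with an illustrative query string:

    retriever = vector_store.as_retriever()
    docs = retriever.get_relevant_documents("columns relevant to the user question")
    for doc in docs:
        print(doc.page_content[:200])  # preview each retrieved chunk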
@@ -77,7 +72,7 @@ user_prompt = st.text_input("Enter your natural language prompt:")
 if user_prompt:
     try:
         # Step 9: Retrieve relevant context using LLaMA RAG
-        rag_result =
+        rag_result = llama_pipeline(user_prompt, max_length=200)
         st.write(f"Retrieved Context from LLaMA RAG: {rag_result}")
 
         # Step 10: Generate SQL query with OpenAI based on user prompt and retrieved context
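On the new rag_result: a text-generation pipeline returns a list of dicts, so st.write will render the raw list rather than the generated text. A sketch of unpacking the string first, using the same call as in the diff:

    rag_result = llama_pipeline(user_prompt, max_length=200)
    context_text = rag_result[0]["generated_text"]  # pipeline output is [{"generated_text": ...}]
    st.write(f"Retrieved Context from LLaMA RAG: {context_text}")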
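To try the updated app locally, a plausible setup (the package list is inferred from the imports above, not from the repo's requirements.txt):

    pip install streamlit pandas openai transformers langchain langchain-community faiss-cpu sqlparse
    OPENAI_API_KEY=sk-... streamlit run app.py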