Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,24 +6,20 @@ import requests
|
|
6 |
|
7 |
# Paths
|
8 |
PROCESSED_DATA_DIR = "data/preprocessed/"
|
9 |
-
HUGGINGFACE_KEY_FILE = "configs/huggingface_api_key.txt"
|
10 |
API_URL = "https://api-inference.huggingface.co/models/gpt2"
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
return file.readline().strip()
|
17 |
-
except FileNotFoundError:
|
18 |
-
raise FileNotFoundError(f"API key file not found: {key_file}")
|
19 |
-
except Exception as e:
|
20 |
-
raise Exception(f"Error reading API key: {e}")
|
21 |
|
22 |
-
HUGGINGFACE_API_KEY = read_huggingface_api_key(HUGGINGFACE_KEY_FILE)
|
23 |
headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
|
24 |
|
25 |
# Load FAISS Index
|
26 |
def load_faiss_index(processed_data_dir):
|
|
|
|
|
|
|
27 |
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
28 |
vector_store = FAISS.load_local(
|
29 |
processed_data_dir, embedding_model, allow_dangerous_deserialization=True
|
@@ -34,6 +30,9 @@ vector_store = load_faiss_index(PROCESSED_DATA_DIR)
|
|
34 |
|
35 |
# Query GPT-2 Model via Hugging Face API
|
36 |
def query_huggingface_api(prompt):
|
|
|
|
|
|
|
37 |
response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
|
38 |
if response.status_code == 200:
|
39 |
return response.json()[0]["generated_text"]
|
@@ -42,12 +41,16 @@ def query_huggingface_api(prompt):
|
|
42 |
|
43 |
# Generate Response
|
44 |
def generate_response(query):
|
|
|
|
|
|
|
45 |
retriever = vector_store.as_retriever()
|
46 |
retrieved_chunks = retriever.get_relevant_documents(query)
|
47 |
|
48 |
if not retrieved_chunks:
|
49 |
return "No relevant documents found."
|
50 |
|
|
|
51 |
context = "\n\n".join([doc.page_content for doc in retrieved_chunks[:3]])[:1500]
|
52 |
prompt = (
|
53 |
f"You are a legal expert specializing in business laws and the legal environment. "
|
@@ -58,6 +61,9 @@ def generate_response(query):
|
|
58 |
|
59 |
# Gradio Interface for QA Bot
|
60 |
def qa_bot(query):
|
|
|
|
|
|
|
61 |
return generate_response(query)
|
62 |
|
63 |
# Define Gradio Interface
|
|
|
6 |
|
7 |
# Paths
|
8 |
PROCESSED_DATA_DIR = "data/preprocessed/"
|
|
|
9 |
API_URL = "https://api-inference.huggingface.co/models/gpt2"
|
10 |
|
11 |
+
# Load Hugging Face API key from environment variable
|
12 |
+
HUGGINGFACE_API_KEY = os.getenv("HF_API_TOKEN")
|
13 |
+
if not HUGGINGFACE_API_KEY:
|
14 |
+
raise ValueError("Hugging Face API token is not set. Please ensure HF_API_TOKEN is added as a secret.")
|
|
|
|
|
|
|
|
|
|
|
15 |
|
|
|
16 |
headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
|
17 |
|
18 |
# Load FAISS Index
|
19 |
def load_faiss_index(processed_data_dir):
|
20 |
+
"""
|
21 |
+
Load the FAISS index and embedding model.
|
22 |
+
"""
|
23 |
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
24 |
vector_store = FAISS.load_local(
|
25 |
processed_data_dir, embedding_model, allow_dangerous_deserialization=True
|
|
|
30 |
|
31 |
# Query GPT-2 Model via Hugging Face API
|
32 |
def query_huggingface_api(prompt):
|
33 |
+
"""
|
34 |
+
Query the Hugging Face GPT-2 model via the Inference API.
|
35 |
+
"""
|
36 |
response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
|
37 |
if response.status_code == 200:
|
38 |
return response.json()[0]["generated_text"]
|
|
|
41 |
|
42 |
# Generate Response
|
43 |
def generate_response(query):
|
44 |
+
"""
|
45 |
+
Generate a response using FAISS and GPT-2.
|
46 |
+
"""
|
47 |
retriever = vector_store.as_retriever()
|
48 |
retrieved_chunks = retriever.get_relevant_documents(query)
|
49 |
|
50 |
if not retrieved_chunks:
|
51 |
return "No relevant documents found."
|
52 |
|
53 |
+
# Combine retrieved chunks into context
|
54 |
context = "\n\n".join([doc.page_content for doc in retrieved_chunks[:3]])[:1500]
|
55 |
prompt = (
|
56 |
f"You are a legal expert specializing in business laws and the legal environment. "
|
|
|
61 |
|
62 |
# Gradio Interface for QA Bot
|
63 |
def qa_bot(query):
|
64 |
+
"""
|
65 |
+
Gradio wrapper function for the QA Bot.
|
66 |
+
"""
|
67 |
return generate_response(query)
|
68 |
|
69 |
# Define Gradio Interface
|