Spaces:
Running
Running
regraded01
commited on
Commit
•
9f5f200
1
Parent(s):
65db96a
feat: check config keys are set properly
Browse files- app_langchain.py +32 -1
app_langchain.py
CHANGED
@@ -3,6 +3,10 @@ import yaml
|
|
3 |
import requests
|
4 |
import re
|
5 |
import os
|
|
|
|
|
|
|
|
|
6 |
from src.pdfParser import get_pdf_text
|
7 |
|
8 |
# Get HuggingFace API key
|
@@ -11,12 +15,39 @@ api_key = os.getenv(api_key_name)
|
|
11 |
if api_key is None:
|
12 |
st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
|
13 |
|
|
|
|
|
|
|
14 |
|
15 |
-
with open(
|
16 |
model_config = yaml.safe_load(file)
|
17 |
|
|
|
|
|
|
|
|
|
18 |
system_message = model_config["system_message"]
|
19 |
model_id = model_config["model_id"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def query(payload, model_id):
|
|
|
3 |
import requests
|
4 |
import re
|
5 |
import os
|
6 |
+
|
7 |
+
from langchain_core.prompts import PromptTemplate
|
8 |
+
import streamlit as st
|
9 |
+
|
10 |
from src.pdfParser import get_pdf_text
|
11 |
|
12 |
# Get HuggingFace API key
|
|
|
15 |
if api_key is None:
|
16 |
st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
|
17 |
|
18 |
+
# Load in model configuration and check the required keys are present
|
19 |
+
model_config_dir = "config/model_config.yml"
|
20 |
+
config_keys = ["system_message", "model_id", "template"]
|
21 |
|
22 |
+
with open(model_config_dir, "r") as file:
|
23 |
model_config = yaml.safe_load(file)
|
24 |
|
25 |
+
for var in model_config.keys():
|
26 |
+
if var not in config_keys:
|
27 |
+
raise ValueError(f"`{var}` key missing from `{model_config_dir}`")
|
28 |
+
|
29 |
system_message = model_config["system_message"]
|
30 |
model_id = model_config["model_id"]
|
31 |
+
template = model_config["template"]
|
32 |
+
|
33 |
+
prompt_template = PromptTemplate(
|
34 |
+
template=template,
|
35 |
+
input_variables=["system_message", "user_message"]
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
|
52 |
|
53 |
def query(payload, model_id):
|