Spaces:
Sleeping
Sleeping
import streamlit as st | |
import yaml | |
import requests | |
import re | |
import os | |
from langchain_core.prompts import PromptTemplate | |
import streamlit as st | |
from src.pdfParser import get_pdf_text | |
# Read the Hugging Face API token from the environment.
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)
# An unset OR empty variable is equally unusable as a bearer token.
if not api_key:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
    # `query` interpolates this value into its Authorization header, so carrying
    # on would send "Bearer None" and fail later with a misleading message;
    # halt the script run here instead.
    st.stop()
# Load the model configuration and check that every required key is present.
model_config_dir = "config/model_config.yml"
config_keys = ["system_message", "model_id", "template"]
with open(model_config_dir, "r", encoding="utf-8") as file:
    # safe_load only builds plain YAML types; an empty file yields None.
    model_config = yaml.safe_load(file)
if not isinstance(model_config, dict):
    raise ValueError(f"`{model_config_dir}` must contain a YAML mapping of config keys")
# Check the *required* keys against the file (not the file's keys against the
# required list), so a missing key is reported here rather than surfacing as a
# bare KeyError below, while extra keys in the file are tolerated.
missing_keys = [key for key in config_keys if key not in model_config]
if missing_keys:
    missing = ", ".join(f"`{key}`" for key in missing_keys)
    raise ValueError(f"{missing} key(s) missing from `{model_config_dir}`")

system_message = model_config["system_message"]
model_id = model_config["model_id"]
template = model_config["template"]
# NOTE(review): `prompt_template` is not used in the visible code; prompts are
# built by `prompt_generator`, which hard-codes its own format and ignores
# `template`. Confirm whether the configured template should be used instead.
prompt_template = PromptTemplate(
    template=template,
    input_variables=["system_message", "user_message"],
)
def query(payload, model_id, timeout=60):
    """POST *payload* to the Hugging Face Inference API for *model_id*.

    Authenticates with the module-level ``api_key`` as a bearer token.

    Args:
        payload: JSON-serialisable request body, e.g.
            ``{"inputs": ..., "parameters": {...}}``.
        model_id: Model repository id, interpolated into the endpoint URL.
        timeout: Seconds ``requests`` waits on connect/read before raising
            ``requests.Timeout``. Without a timeout, a stalled server would
            block the Streamlit script indefinitely.

    Returns:
        The decoded JSON body. HTTP error statuses are not raised; the caller
        receives whatever JSON the server sent and must check its shape.

    Raises:
        requests.RequestException: On connection failures or timeouts.
        ValueError: (via ``response.json()``) if the body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
def prompt_generator(system_message, user_message):
    """Wrap a system and a user message in an ``[INST]``/``<<SYS>>`` chat prompt.

    The system message sits between ``<<SYS>>`` tags inside the opening
    ``<s>[INST]`` turn, and the user message closes that turn with
    ``[/INST]`` (the Llama-2 chat layout). The result starts and ends with a
    newline.

    Args:
        system_message: Instructions placed inside the ``<<SYS>>`` section.
        user_message: The request placed immediately before ``[/INST]``.

    Returns:
        The assembled prompt string.
    """
    prompt_lines = [
        "",
        "<s>[INST] <<SYS>>",
        f"{system_message}",
        "<</SYS>>",
        f"{user_message} [/INST]",
        "",
    ]
    return "\n".join(prompt_lines)
# Regex used to trim the model's reply. Group 1 is the text after an `[/INST]`
# tag; with re.DOTALL the greedy `.*` spans lines, so it is the text after the
# LAST tag. Kept as a plain string rather than a compiled pattern because
# re.search() refuses a separate flags argument for a compiled pattern.
pattern = r".*\[/INST\]([\s\S]*)$"

# Chat history, kept in session state so it survives Streamlit reruns.
# NOTE(review): not read anywhere in the visible code; confirm it is still needed.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Used to parse the model's reply as a Python literal (never eval()).
import ast


def _generated_text(api_response):
    """Return ``api_response[0]["generated_text"]`` if it is a string, else None.

    The extraction step expects a reply shaped like
    ``[{"generated_text": "..."}]``; any other shape (presumably an error object
    from the inference API; confirm against the API docs) yields None so the
    caller can report a model failure instead of crashing.
    """
    try:
        text = api_response[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        return None
    return text if isinstance(text, str) else None


def _build_user_message(pdf_text, key_inputs):
    """Build the user turn asking the model to extract *key_inputs* from *pdf_text*.

    *key_inputs* maps column name -> description; the pairs are rendered as
    ``name: description`` and joined with ``"; "``.
    """
    columns_spec = "; ".join(
        f"{name}: {description}" for name, description in key_inputs.items()
    )
    return (
        "\n"
        f"Use the text provided and denoted by 3 backticks ```{pdf_text}```.\n"
        "Extract the following columns and return a table that could be "
        "uploaded to an SQL database.\n"
        f"{columns_spec}\n"
    )


# PDF upload widget; everything below runs only once a file has been provided.
pdf_upload = st.file_uploader(
    "Upload a .PDF here",
    type=".pdf",
)

if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)

    # Column name -> description, persisted across Streamlit reruns.
    if "key_inputs" not in st.session_state:
        st.session_state.key_inputs = {}

    col1, col2, col3 = st.columns([3, 3, 2])
    with col1:
        key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")
    with col2:
        key_description = st.text_area(
            "*(Optional) Description of key/column", key="key_description"
        )
    with col3:
        if st.button("Extract this column"):
            # Reject blank names so an empty "" column is never requested.
            if (key_name or "").strip():
                st.session_state.key_inputs[key_name] = (
                    key_description or "No further description provided"
                )
            else:
                st.warning("Please enter a key/column name first.")

    if st.session_state.key_inputs:
        st.write("\nKeys/Columns for extraction:")
        st.write(st.session_state.key_inputs)

    if st.button("Extract data!"):
        if not st.session_state.key_inputs:
            st.warning("Add at least one key/column before extracting data.")
        else:
            user_message = _build_user_message(pdf_text, st.session_state.key_inputs)
            the_prompt = prompt_generator(
                system_message=system_message, user_message=user_message
            )
            # Show the spinner only while the request is actually in flight.
            with st.spinner("Extracting requested data"):
                response = query(
                    {
                        "inputs": the_prompt,
                        "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                    },
                    model_id,
                )

            generated = _generated_text(response)
            if generated is None:
                st.error("Unable to connect to model. Please try again later.")
            else:
                # Keep only the text after the last `[/INST]` tag when present,
                # otherwise fall back to the whole reply.
                match = re.search(pattern, generated, re.MULTILINE | re.DOTALL)
                answer = (match.group(1) if match else generated).strip()
                try:
                    # SECURITY: the reply is model output derived from an
                    # uploaded (untrusted) PDF. eval() would execute arbitrary
                    # code from it, so only Python literals are accepted.
                    data = ast.literal_eval(answer)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    st.warning(
                        "Could not parse the model's reply as structured data; "
                        "showing the raw text instead."
                    )
                    st.write(answer)
                else:
                    st.success("Data Extracted Successfully!")
                    st.write(data)