File size: 3,897 Bytes
65db96a
 
 
 
 
9f5f200
 
 
 
65db96a
 
 
 
 
 
 
 
9f5f200
 
 
65db96a
9f5f200
65db96a
 
9f5f200
 
 
 
65db96a
 
9f5f200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65db96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import streamlit as st
import yaml
import requests
import re
import os

from langchain_core.prompts import PromptTemplate
import streamlit as st

from src.pdfParser import get_pdf_text

# Read the Hugging Face API token from the environment; `query` sends it as a
# bearer token. A missing token is reported in the UI but does not stop the app.
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)
if api_key is None:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")

# Load the model configuration and check that every required key is present.
model_config_dir = "config/model_config.yml"
config_keys = ["system_message", "model_id", "template"]

with open(model_config_dir, "r", encoding="utf-8") as file:
    # safe_load returns None for an empty document; treat that as "no keys" so
    # the missing-key check below reports it instead of crashing.
    model_config = yaml.safe_load(file) or {}

if not isinstance(model_config, dict):
    raise ValueError(f"`{model_config_dir}` must contain a mapping of configuration keys")

# Iterate over the REQUIRED keys and look each one up in the loaded config, so
# a genuinely missing key is reported and extra (unused) keys are tolerated.
for var in config_keys:
    if var not in model_config:
        raise ValueError(f"`{var}` key missing from `{model_config_dir}`")

system_message = model_config["system_message"]
model_id = model_config["model_id"]
template = model_config["template"]

# NOTE(review): this template is built but the visible code formats prompts
# with `prompt_generator` instead; confirm whether it is used elsewhere.
prompt_template = PromptTemplate(
    template=template,
    input_variables=["system_message", "user_message"]
)
















def query(payload, model_id):
    """Send ``payload`` to the Hugging Face Inference API endpoint for ``model_id``.

    The request is a JSON POST authenticated with the module-level ``api_key``
    as a bearer token. The decoded JSON body is returned as-is; the HTTP status
    code is not inspected, so error responses are returned like any other body.
    """
    # NOTE(review): no `timeout` is passed to requests.post, so a stalled
    # connection blocks this call indefinitely.
    endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    reply = requests.post(endpoint, headers=auth_headers, json=payload)
    return reply.json()


def prompt_generator(system_message, user_message):
    """Wrap a system and a user message in the ``[INST]``/``<<SYS>>`` prompt markup.

    Every content line is indented by eight spaces. The result starts with a
    newline and ends with a newline followed by eight spaces.
    """
    pad = " " * 8
    prompt_lines = [
        "<s>[INST] <<SYS>>",
        f"{system_message}",
        "<</SYS>>",
        f"{user_message} [/INST]",
        "",  # trailing padded line after the closing [/INST]
    ]
    return "\n" + "\n".join(pad + line for line in prompt_lines)


# Regex that isolates the model's answer from the generated text: group 1
# captures everything following the "[/INST]" marker that closes the prompt.
pattern = r".*\[/INST\]([\s\S]*)$"

# Chat history lives in session_state so it survives Streamlit reruns.
st.session_state.setdefault("messages", [])

# PDF upload widget; the text is extracted only once a file has been provided,
# so `pdf_text` exists only when `pdf_upload` is not None.
pdf_upload = st.file_uploader("Upload a .PDF here", type=".pdf")

if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)


# Columns requested for extraction, stored as {column name: description} in
# session_state so they persist across Streamlit reruns.
if "key_inputs" not in st.session_state:
    st.session_state.key_inputs = {}

col1, col2, col3 = st.columns([3, 3, 2])

with col1:
    key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")

with col2:
    key_description = st.text_area(
        "*(Optional) Description of key/column", key="key_description"
    )

with col3:
    if st.button("Extract this column"):
        column_name = key_name.strip()
        if not column_name:
            # A blank name would add an empty ": description" entry to the prompt.
            st.warning("Please enter a key/column name before adding it.")
        else:
            # Fall back to a placeholder when no (non-blank) description is given.
            st.session_state.key_inputs[column_name] = (
                key_description.strip() or "No further description provided"
            )

import ast  # literal_eval: safe parsing of the model's answer (see below)

if st.session_state.key_inputs:
    # Show the columns queued for extraction.
    st.write("\nKeys/Columns for extraction:")
    st.write(st.session_state.key_inputs)

    if st.button("Extract data!"):
        if pdf_upload is None:
            # `pdf_text` is only defined after a PDF has been uploaded.
            st.error("Please upload a PDF before extracting data.")
        else:
            columns_spec = "; ".join(
                f"{key}: {description}"
                for key, description in st.session_state.key_inputs.items()
            )
            user_message = f"""
                Use the text provided and denoted by 3 backticks ```{pdf_text}```. 
                Extract the following columns and return a table that could be uploaded to an SQL database.
                {columns_spec}
            """
            the_prompt = prompt_generator(
                system_message=system_message, user_message=user_message
            )

            # Show the spinner only while the API call is in flight.
            with st.spinner("Extracting requested data"):
                response = query(
                    {
                        "inputs": the_prompt,
                        "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                    },
                    model_id,
                )

            try:
                generated_text = response[0]["generated_text"]
            except (KeyError, IndexError, TypeError):
                # Any other response shape (e.g. an error object instead of a
                # list of generations) means there is no usable output.
                st.error("Unable to connect to model. Please try again later.")
            else:
                # Keep only the text after the prompt's closing [/INST] marker;
                # if the marker is absent, use the generated text as-is.
                match = re.search(
                    pattern, generated_text, re.MULTILINE | re.DOTALL
                )
                answer = match.group(1).strip() if match else generated_text.strip()

                try:
                    # SECURITY: never eval() this text. It is model output, and
                    # eval would execute any code the model emits.
                    # literal_eval accepts only Python literals (lists, dicts,
                    # strings, numbers, ...).
                    extracted = ast.literal_eval(answer)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    st.error("Unable to parse the model's response as structured data.")
                    st.text(answer)
                else:
                    st.success("Data Extracted Successfully!")
                    st.write(extracted)