File size: 3,346 Bytes
2267014
f15fdf1
2267014
 
2062332
2267014
 
2062332
 
 
 
 
f15fdf1
 
 
 
 
 
2267014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f61498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import ast
import os
import re

import requests
import streamlit as st
import yaml

from pdfParser import get_pdf_text

# Read the Hugging Face API token from the environment.
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)
if api_key is None:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is correctly located")
    # Without a token every request would carry "Bearer None" and fail, so
    # halt this script run rather than render a UI that cannot work.
    st.stop()

# Load the model settings: the system prompt and the Hugging Face model id.
with open("config/model_config.yml", "r", encoding="utf-8") as file:
    model_config = yaml.safe_load(file)

system_message = model_config["system_message"]
model_id = model_config["model_id"]


def query(payload, model_id, timeout=120):
    """POST ``payload`` to the Hugging Face Inference API for ``model_id``.

    Args:
        payload: JSON-serialisable request body, e.g.
            ``{"inputs": ..., "parameters": {...}}``.
        model_id: Model repository id appended to the inference URL.
        timeout: Seconds to wait for the server. Without a timeout
            ``requests`` may block indefinitely and freeze the app.

    Returns:
        The decoded JSON body. Its shape is not checked here, so callers
        must validate it (presumably an error object on failure).

    Raises:
        requests.exceptions.RequestException: on network failure or timeout.
        ValueError: (a subclass of) if the response body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()


def prompt_generator(system_message, user_message):
    """Wrap the system and user messages in the Llama-2 chat prompt template.

    Args:
        system_message: Instructions placed inside the ``<<SYS>>`` section.
        user_message: The user turn placed before ``[/INST]``.

    Returns:
        The formatted prompt string.
    """
    # Built from explicit pieces (not an indented triple-quoted string) so no
    # source-code indentation or stray blank lines leak into the prompt.
    return (
        "<s>[INST] <<SYS>>\n"
        f"{system_message}\n"
        "<</SYS>>\n\n"
        f"{user_message} [/INST]"
    )


# Regex used to strip the prompt portion from the generated text: group 1
# captures what follows a "[/INST]" marker (the last one when searched with
# re.DOTALL).
pattern = r".*\[/INST\]([\s\S]*)$"

# Chat history is kept in session state so it persists across Streamlit
# reruns (not read anywhere in the visible code).
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# PDF upload widget; the extracted text is what the model is asked to mine.
pdf_upload = st.file_uploader(label="Upload a .PDF here", type=".pdf")

# Parse only when a file has been provided; `pdf_text` stays unassigned
# otherwise.
if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)


# Mapping of column name -> description, kept in session state across reruns.
if "key_inputs" not in st.session_state:
    st.session_state.key_inputs = {}

col1, col2, col3 = st.columns([3, 3, 2])

with col1:
    key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")

with col2:
    key_description = st.text_area(
        "*(Optional) Description of key/column", key="key_description"
    )

with col3:
    if st.button("Extract this column"):
        column_name = key_name.strip()
        if column_name:
            # A blank description gets an explicit placeholder so the prompt
            # always contains "name: description" pairs.
            st.session_state.key_inputs[column_name] = (
                key_description.strip() or "No further description provided"
            )
        else:
            # Reject blank names; storing them would add an empty column to
            # the extraction request.
            st.warning("Please enter a key/column name before adding it.")

if st.session_state.key_inputs:
    st.write("\nKeys/Columns for extraction:")
    st.write(st.session_state.key_inputs)

    if st.button("Extract data!"):
        if pdf_upload is None:
            # `pdf_text` is only assigned after an upload; building the prompt
            # without one would raise NameError.
            st.warning("Please upload a PDF before extracting data.")
        else:
            columns_spec = "; ".join(
                f"{name}: {description}"
                for name, description in st.session_state.key_inputs.items()
            )
            # Built from explicit pieces so source-code indentation does not
            # leak into the text sent to the model.
            user_message = (
                f"Use the text provided and denoted by 3 backticks ```{pdf_text}```.\n"
                "Extract the following columns and return a table that could be "
                "uploaded to an SQL database.\n"
                f"{columns_spec}"
            )
            the_prompt = prompt_generator(
                system_message=system_message, user_message=user_message
            )

            # Show the spinner while the slow model call actually runs.
            with st.spinner("Extracting requested data"):
                try:
                    api_response = query(
                        {
                            "inputs": the_prompt,
                            "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                        },
                        model_id,
                    )
                except (requests.exceptions.RequestException, ValueError):
                    # Network failure, timeout, or a non-JSON response body.
                    api_response = None

            # A successful call is expected to look like
            # [{"generated_text": "..."}]; any other shape counts as failure.
            try:
                generated_text = api_response[0]["generated_text"]
            except (TypeError, KeyError, IndexError):
                generated_text = None

            if not isinstance(generated_text, str):
                st.error("Unable to connect to model. Please try again later.")
            else:
                match = re.search(pattern, generated_text, re.MULTILINE | re.DOTALL)
                answer = match.group(1).strip() if match else generated_text.strip()

                # SECURITY: the model's output is derived from an uploaded PDF
                # and is therefore untrusted (prompt injection). The previous
                # eval() would execute arbitrary code; ast.literal_eval only
                # accepts Python literals and never runs code.
                try:
                    parsed = ast.literal_eval(answer)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Not a Python literal (e.g. a markdown table): show it as-is
                    # instead of misreporting a connection failure.
                    st.warning(
                        "The model's answer is not a Python literal; showing the raw text."
                    )
                    st.write(answer)
                else:
                    st.success("Data Extracted Successfully!")
                    st.write(parsed)