File size: 5,196 Bytes
e2dccc5
 
81cf5f3
e2dccc5
81cf5f3
e2dccc5
 
 
 
 
 
81cf5f3
e2dccc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81cf5f3
 
 
 
 
e2dccc5
81cf5f3
 
e2dccc5
 
 
 
 
 
 
 
 
81cf5f3
e2dccc5
 
 
 
81cf5f3
e2dccc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81cf5f3
e2dccc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81cf5f3
e2dccc5
 
 
 
 
 
81cf5f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2dccc5
 
81cf5f3
e2dccc5
 
 
 
 
 
 
 
 
81cf5f3
 
e2dccc5
 
 
81cf5f3
 
e2dccc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.chat_models import ChatOpenAI

from llama_index import VectorStoreIndex, ServiceContext, StorageContext,  download_loader, SimpleDirectoryReader
from llama_index.vector_stores import FaissVectorStore
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.embeddings import OpenAIEmbedding
from llama_index.schema import Document
from llama_index.node_parser import UnstructuredElementNodeParser
from llama_index.llms import OpenAI

import streamlit as st
import os
import faiss
import time


# Page-level configuration and the introductory instructions shown to the user.
st.set_page_config(page_title="Yield Case Analyzer", page_icon=":card_index_dividers:", initial_sidebar_state="expanded", layout="wide")

st.title(":card_index_dividers: Yield Case Analyzer")
# The sidebar uploader only accepts .pptx files and the query button is
# labelled "Submit", so the instructions name that format and that button.
st.info("""
Begin by uploading the case report in PPTX format. Afterward, click on 'Process Document'. Once the document has been processed, you can enter a question and click 'Submit', and the system will answer your question.
""")


def get_model(model_name):
    """Build the chat model identified by ``model_name``.

    Args:
        model_name: Provider key. Only ``"openai"`` is supported.

    Returns:
        A ``ChatOpenAI`` instance for ``gpt-3.5-turbo``, configured with the
        value of the ``OPENAI_API_KEY`` environment variable (``None`` when
        the variable is unset).

    Raises:
        ValueError: If ``model_name`` is not a supported provider.
    """
    # Reject unknown providers explicitly; previously this path fell through
    # to ``return model`` with ``model`` never assigned (UnboundLocalError).
    if model_name != "openai":
        raise ValueError(f"Unsupported model name: {model_name!r}")

    api_key = os.environ.get("OPENAI_API_KEY")
    return ChatOpenAI(openai_api_key=api_key, model_name="gpt-3.5-turbo")

def get_vector_index(docs, vector_store):
    """Build a vector index over ``docs`` using the requested backend.

    Args:
        docs: Items to index. They are passed unchanged to
            ``VectorStoreIndex(...)`` for ``"faiss"`` and to
            ``VectorStoreIndex.from_documents(...)`` for ``"simple"``.
        vector_store: Backend key, either ``"faiss"`` or ``"simple"``.

    Returns:
        The constructed ``VectorStoreIndex``.

    Raises:
        ValueError: If ``vector_store`` is not a supported backend.
    """
    # Validate up front; previously an unknown backend reached ``return index``
    # with ``index`` never assigned (UnboundLocalError).
    if vector_store not in ("faiss", "simple"):
        raise ValueError(f"Unsupported vector store: {vector_store!r}")

    print(docs)

    if vector_store == "simple":
        # The default in-memory index needs no custom LLM or storage context.
        return VectorStoreIndex.from_documents(docs)

    # presumably 1536 matches the dimension of the default OpenAI embedding
    # model used by the service context -- TODO confirm if the embedder changes.
    embedding_dim = 1536
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    storage_context = StorageContext.from_defaults(
        vector_store=FaissVectorStore(faiss_index=faiss_index)
    )
    # The LLM is only needed by the FAISS-backed service context, so it is
    # created here rather than unconditionally.
    service_context = ServiceContext.from_defaults(llm=get_model("openai"))
    return VectorStoreIndex(
        docs,
        service_context=service_context,
        storage_context=storage_context,
    )



def generate_insight(engine, search_string):
    """Answer ``search_string`` by querying ``engine`` with the main prompt.

    The template stored in ``prompts/main.prompt`` is filled in with the
    question, echoed to stdout, and sent to ``engine.query``.

    Args:
        engine: Object exposing ``query(str)`` whose result has a
            ``response`` attribute.
        search_string: The user's question, substituted into the template.

    Returns:
        The ``response`` attribute of the query result.
    """
    with open("prompts/main.prompt", "r") as prompt_file:
        prompt = PromptTemplate(
            template=prompt_file.read(),
            input_variables=['search_string'],
        )

    query_text = prompt.format(search_string=search_string)
    print(query_text)
    return engine.query(query_text).response


def get_query_engine(engine):
    """Wrap ``engine`` in a ``SubQuestionQueryEngine`` with a single tool.

    Args:
        engine: The query engine to expose as the "Alert Report" tool.

    Returns:
        A ``SubQuestionQueryEngine`` built from that one tool and a service
        context backed by the "openai" chat model.
    """
    service_context = ServiceContext.from_defaults(llm=get_model("openai"))

    alert_tool = QueryEngineTool(
        query_engine=engine,
        metadata=ToolMetadata(
            name="Alert Report",
            description="Provides information about the alerts from alerts files uploaded.",
        ),
    )

    return SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[alert_tool],
        service_context=service_context,
    )


# Tracks whether a document set has been indexed during this session.
if "process_doc" not in st.session_state:
    st.session_state.process_doc = False


# SECURITY: the API key is read from the environment and must never be
# committed to source control. A key was previously hard-coded on this line;
# it should be treated as compromised and revoked.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")


if OPENAI_API_KEY:
    # accept_multiple_files=True makes the uploader return a list, which is
    # what the processing loop below iterates over.
    files_uploaded = st.sidebar.file_uploader(
        "Upload the case report in PPTX format",
        type="pptx",
        accept_multiple_files=True,
    )
    st.sidebar.info("""
    Example PPTX reports you can upload here:
    """)

    if st.sidebar.button("Process Document"):
        if not files_uploaded:
            # Nothing to index; avoid building an empty index.
            st.sidebar.warning("Please upload at least one .pptx file first.")
        else:
            with st.spinner("Processing Document..."):
                data_dir = "./data"
                os.makedirs(data_dir, exist_ok=True)

                for file in files_uploaded:
                    # The name is client-supplied: keep only its final
                    # component so it cannot escape data_dir.
                    safe_name = os.path.basename(file.name)
                    print(f'file named {safe_name}')
                    with open(os.path.join(data_dir, safe_name), 'wb') as f:
                        # getvalue() returns the full content regardless of
                        # the buffer's current read position.
                        f.write(file.getvalue())

                def fmetadata(dummy: str):
                    """Return blank file metadata so local paths are not stored."""
                    return {"file_path": ""}

                # NOTE(review): the reader is pointed at the whole data_dir, so
                # files left over from earlier uploads are presumably indexed
                # as well -- confirm whether that is intended.
                PptxReader = download_loader("PptxReader")
                loader = SimpleDirectoryReader(
                    input_dir=data_dir,
                    file_extractor={".pptx": PptxReader()},
                    file_metadata=fmetadata,
                )

                documents = loader.load_data()
                for doc in documents:
                    doc.metadata["file_path"] = ""

                st.session_state.index = get_vector_index(documents, vector_store="faiss")
                st.session_state.process_doc = True
            st.toast("Document Processed!")

    if st.session_state.process_doc:
        search_text = st.text_input("Enter your question")
        if st.button("Submit"):
            if not search_text.strip():
                st.warning("Please enter a question before submitting.")
            else:
                engine = get_query_engine(
                    st.session_state.index.as_query_engine(similarity_top_k=3)
                )

                st.write("Alert search result...")
                response = generate_insight(engine, search_text)
                st.write(response)

                st.toast("Report Analysis Complete!")
else:
    st.error("The OPENAI_API_KEY environment variable is not set. Set it and restart the app.")