File size: 2,562 Bytes
fe1de71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ffa965
fe1de71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ffa965
 
fe1de71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a4b8df
a2f2772
fe1de71
 
 
 
 
 
 
 
a2fb176
fe1de71
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Replace the stdlib sqlite3 with pysqlite3 BEFORE anything else imports
# sqlite3: Chroma needs a newer SQLite than some hosts (e.g. Streamlit Cloud)
# ship with.  After the pop/assign, every `import sqlite3` resolves to
# pysqlite3.  This must stay at the very top of the file.
import pysqlite3
import sys, os
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from langchain_community.document_loaders import PyPDFLoader
from langchain.llms import HuggingFaceHub
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.llms import HuggingFaceEndpoint
import streamlit as st

# Hugging Face token comes from Streamlit's secrets store.  It is also
# exported as an environment variable — presumably so LangChain's HF
# integrations can pick it up implicitly; verify which consumers read it.
HF_TOKEN = st.secrets["HF_TOKEN"]
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN


@st.cache_resource()
def retrieve_documents():
    """Build (and cache across reruns) a retriever over the local Chroma store.

    Embeds queries with BAAI/bge-base-en-v1.5 via the HF Inference API and
    searches the vector DB persisted under ./db.

    Returns:
        A LangChain retriever yielding the top-3 most similar chunks.
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_TOKEN,
        model_name="BAAI/bge-base-en-v1.5",
    )
    db = Chroma(persist_directory="./db", embedding_function=embeddings)
    # k=3 keeps the stuffed context small enough for the LLM prompt window.
    return db.as_retriever(search_kwargs={"k": 3})

@st.cache_resource()
def create_chain(_retriever):
    """Assemble the RAG pipeline: retriever -> prompt -> LLM -> plain string.

    The parameter is named ``_retriever`` (leading underscore) so Streamlit's
    cache does not attempt to hash the retriever object.

    Args:
        _retriever: LangChain retriever that supplies the {context} documents.

    Returns:
        A runnable chain; invoke it with the user's question string.
    """
    template = """
    User: You are an AI Assistant that follows instructions well.
    Please be truthful and give direct answers. Please tell 'I don't know' if user query is not in CONTEXT

    Keep in mind, you will lose the job, if you answer out of CONTEXT questions

    CONTEXT: {context}
    Query: {question}

    Remember only return AI answer
    Assistant:
    """

    llm = HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
        max_new_tokens=2048,
        top_k=3,
        top_p=0.9,
        temperature=0.5,
        repetition_penalty=1.1,
        streaming=False,
    )

    rag_prompt = ChatPromptTemplate.from_template(template)
    # The dict head feeds retrieved docs into {context} and passes the raw
    # question through to {question}; StrOutputParser strips to plain text.
    return (
        {
            "context": _retriever.with_config(run_name="Docs"),
            "question": RunnablePassthrough(),
        }
        | rag_prompt
        | llm
        | StrOutputParser()
    )


def main():
    """Streamlit entry point: one-shot Q&A over the personal RAG index."""
    st.title("All About Sungwon")
    st.header("Ask anything about Sungwon. Professional or Personal")
    question = st.text_input("Enter your question")
    answer_container = st.empty()

    # Streamlit reruns this whole script on every interaction; only hit the
    # retriever + LLM once the user has actually typed something (the
    # original invoked the chain with "" on every rerun).
    if question:
        chain = create_chain(retrieve_documents())
        answer_container.write(chain.invoke(question))

    # Fixed typo: "diffuion" -> "diffusion".
    st.write(
        "check out my personalized diffusion model site to see my picture"
        "[link](https://huggingface.co/spaces/sorg20/sorg20-autotrain-sd-pic)"
    )


if __name__ == "__main__":
    main()