import os
from operator import itemgetter

from dotenv import load_dotenv

load_dotenv()

import chainlit as cl

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.vectorstores import Pinecone

import pinecone
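
# Expected environment variables (loaded from .env): OPENAI_API_KEY for the
# chat model and embeddings, plus PINECONE_API_KEY and PINECONE_ENV for the
# Pinecone connection below.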

# =============================================================================
# Retrieval Chain
# =============================================================================
def load_llm():
    # Deterministic chat model (temperature 0) used to answer questions.
    llm = ChatOpenAI(
        model='gpt-3.5-turbo',
        temperature=0.0,
    )
    return llm


def load_vectorstore():
    # Connect to the existing Pinecone index holding the YouTube transcript chunks.
    pinecone.init(
        api_key=os.getenv('PINECONE_API_KEY'),
        environment=os.getenv('PINECONE_ENV')
    )

    # index = pinecone.GRPCIndex("youtube-index")  # gRPC client alternative
    index = pinecone.Index("youtube-index")

    # Cache embeddings on local disk so repeated texts are not re-embedded
    # through the OpenAI API; the namespace keys the cache by model name.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()

    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model,
        store,
        namespace=core_embeddings_model.model
    )

    # "text" is the metadata field in which each vector's raw text is stored.
    text_field = "text"

    vectorstore = Pinecone(
        index,
        embedder,
        text_field
    )

    return vectorstore
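
# Quick local sanity check (a hedged sketch; assumes the "youtube-index"
# index is already populated and the keys above are set):
#   vs = load_vectorstore()
#   print(vs.similarity_search("What is retrieval augmented generation?", k=2))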


def qa_chain():
    vectorstore = load_vectorstore()
    llm = load_llm()
    retriever = vectorstore.as_retriever()

    template = """You are a helpful assistant that answers questions using the provided context. If the question is not answered within the context, respond with "This query is not directly mentioned by AI Makerspace" and then answer to the best of your ability.
    Additionally, the context includes a specific integer formatted as <int>, representing a timestamp.
    In your response, cite this integer as a YouTube video link, formatted as "https://www.youtube.com/watch?v=[video_id]&t=<int>s", with the video's title as the link text.

    ### CONTEXT
    {context}

    ### QUESTION
    {question}
    """

    prompt = ChatPromptTemplate.from_template(template)

    # LCEL pipeline: retrieve context for the question, pass both through,
    # then return the model's response together with the retrieved context.
    retrieval_augmented_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": prompt | llm,
            "context": itemgetter("context"),
        }
    )

    return retrieval_augmented_qa_chain
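
# Hedged usage sketch outside Chainlit (the question string is illustrative):
#   chain = qa_chain()
#   result = chain.invoke({"question": "What is LangChain?"})
#   print(result["response"].content)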

# =============================================================================
# Chainlit
# =============================================================================
@cl.on_chat_start
async def on_chat_start():
    # Build the chain once per session and store it for reuse across messages.
    chain = qa_chain()
    cl.user_session.set("chain", chain)
    msg = cl.Message(content="What is your question about AI Makerspace?")
    await msg.send()

@cl.on_message
async def on_message(message: cl.Message):
    chain = cl.user_session.get("chain")
    # ainvoke keeps the event loop responsive while the LLM call runs.
    res = await chain.ainvoke({"question": message.content})

    answer = res['response'].content
    await cl.Message(content=answer).send()
    
    '''
    source_documents = set()

    for document in res['context']:
        source_url = document.metadata['source_document']
        source_documents.add(source_url)

    combined_message = answer + "\n\nSource Documents:\n" + "\n".join(source_documents)

    await cl.Message(content=combined_message).send()
    '''
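
# To run locally (assuming this file is saved as app.py):
#   chainlit run app.py -w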