from typing import Optional, Dict
import os
import chainlit as cl
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import LlamaCpp
from chainlit.types import ThreadDict
from langchain.chains import RetrievalQA, ConversationChain
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains.conversation.memory import ConversationBufferMemory


llm = LlamaCpp(
    model_path="Model/llama-2-7b-chat.Q4_K_M.gguf",
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),  # stream generated tokens to the terminal
    n_gpu_layers=0,  # keep every layer on the CPU; raise this if a GPU with enough memory is available
    verbose=True,
    max_tokens=4096,
    n_ctx=3064,
    streaming=True,
    temperature=0.25,  # lower values make replies less random
)

DATA_PATH = 'Data/'
DB_CHROMA_PATH = 'vectorstore/db_chroma'
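# NOTE: DATA_PATH is not referenced below; the Chroma index under DB_CHROMA_PATH is
# assumed to have been built beforehand by a separate ingestion step that embeds the
# documents in Data/ and persists them (that script is not part of this file).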

template = """
You are an AI specialized in the medical domain. 
Your purpose is to provide accurate, clear, and helpful responses to medical-related inquiries. 
You must avoid misinformation at all costs. Do not respond to questions outside of the medical domain. 
You can summarize chapters, paragraphs, and other content that is in the medical domain.
If you are unsure or lack information about a query, you must clearly state that you do not know the answer.

Question: {query}

Answer:
"""

prompt_template = PromptTemplate(input_variables=["query"], template=template)


rag_template = """
You are an AI specialized in the medical domain. 
Your purpose is to provide accurate, clear, and helpful responses to medical-related inquiries. 
You must avoid misinformation at all costs. Do not respond to questions outside of the medical domain. 
If you are unsure or lack information about a query, you must clearly state that you do not know the answer.

Question: {query}

Answer:
"""


rag_prompt_template = PromptTemplate(input_variables=["query"], template=rag_template)

embedding_function = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
    model_kwargs={'device': 'cpu'}
)
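# The embedding model must match the one used when the vector store at DB_CHROMA_PATH was
# built; otherwise similarity search over the persisted index will return poor matches.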


db = Chroma(persist_directory=DB_CHROMA_PATH, embedding_function=embedding_function)

rag_pipeline = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=db.as_retriever(),
    return_source_documents=True
)
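# chain_type='stuff' concatenates ("stuffs") all retrieved chunks into a single prompt, so the
# amount of retrieved context is effectively bounded by the model's n_ctx window.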


conversation_buf = ConversationChain(
    llm=llm,
    memory=ConversationBufferMemory(),
)
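# conversation_buf keeps a ConversationBufferMemory of the dialogue, but it is not wired into
# the handlers below; it appears to be kept for a future chat-history feature.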

@cl.on_chat_start
async def on_chat_start():
    """Runs once when a new chat session starts; no per-session setup is needed here."""


@cl.step()
async def get_response(query):
    """
    Generates a response from the language model based on the user's input. If the input includes
    '-rag', it uses the retrieval-augmented generation pipeline; otherwise, it directly invokes
    the language model.

    Args:
        query (str): The user's input text.

    Returns:
        str: The language model's response, including source documents if '-rag' was used.
    """
    if "-rag" in query.lower():
        response = await cl.make_async(rag_pipeline)(rag_prompt_template.format(query=query))
        result = response["result"]
        sources = response["source_documents"]
        if sources:
            source_details = "\n\nSources:"
            for doc in sources:
                page_content = doc.page_content
                page_number = doc.metadata.get('page', 'N/A')
                source_book = doc.metadata.get('source', 'N/A')
                source_details += f"\n- Page {page_number} from {source_book}: \"{page_content}\""
            result += source_details
        return result

    return await cl.make_async(llm.invoke)(prompt_template.format(query=query))


@cl.on_message
async def on_message(message: cl.Message):
    """
    Fetches the response from the language model and shows it in the web UI.
    """
    response = await get_response(message.content)
    msg = cl.Message(content=response)

    await msg.send()
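
# To try this locally, start the app with the Chainlit CLI, e.g. (assuming this file is saved
# as app.py):
#   chainlit run app.py -w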