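"""Streamlit chat UI for an AI medical consultation bot.

The app wires a Hugging Face-hosted LLM (via the local ``custom_llm`` helpers)
to a LangChain ConversationBufferMemory and renders the exchange with
Streamlit's chat components. It expects the HF_INFER_API environment variable
to hold a Hugging Face Inference API token.
"""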
import os

import streamlit as st
from langchain.memory import ConversationBufferMemory

from custom_llm import CustomLLM, custom_chain_with_history
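# The local custom_llm module is assumed (from how it is used in this file) to expose:
#   CustomLLM(repo_id, model_type, api_token, stop=..., temperature=...) - a wrapper
#     around the Hugging Face Inference API, and
#   custom_chain_with_history(llm, memory) - which returns a chain whose
#     .invoke({"question": ..., "memory": ...}) call yields the model reply as a string.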

# Optional: enable LangSmith tracing by exporting LANGCHAIN_TRACING_V2.
# os.environ['LANGCHAIN_TRACING_V2'] = "true"

# Hugging Face Inference API token used by CustomLLM.
API_TOKEN = os.getenv('HF_INFER_API')

@st.cache_resource
def get_llm_chain():
    """Build the conversational chain once and cache it across reruns."""
    return custom_chain_with_history(
        llm=CustomLLM(
            repo_id="google/gemma-7b",
            model_type='text-generation',
            api_token=API_TOKEN,
            stop=["\n<|", "<|"],
            temperature=0.001,
        ),
        memory=st.session_state.memory,
    )


# Conversation memory for the LLM chain, seeded with an opening greeting.
if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferMemory(return_messages=True)
    st.session_state.memory.chat_memory.add_ai_message(
        "Hello, I'm an AI medical consultant. How can I help you today?"
    )

# Build the LLM chain for this session.
if 'chain' not in st.session_state:
    st.session_state['chain'] = get_llm_chain()


st.title("AI Medical Consultation")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = [{"role":"assistant", "content":"Hello, I'm AI medical consultant. How can I help you today?"}]

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("Ask me anything..."):
    # Display the user message and record it in the visible chat history
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Run the chain and keep only the text before the model's stop marker
    response = st.session_state.chain.invoke(
        {"question": prompt, "memory": st.session_state.memory}
    ).split("\n<|")[0]

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
        
    # Persist the turn in conversation memory and cap it at the last 15 messages
    st.session_state.memory.save_context({"question": prompt}, {"output": response})
    st.session_state.memory.chat_memory.messages = st.session_state.memory.chat_memory.messages[-15:]
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})