ivnban27-ctl committed
Commit d4599c5 · verified · 1 Parent(s): 2766030

deleted unnecessary file

Files changed (1)
  1. convosim.py +0 -99
convosim.py DELETED
@@ -1,99 +0,0 @@
-import os
-import streamlit as st
-from streamlit.logger import get_logger
-from langchain.schema.messages import HumanMessage
-from utils.mongo_utils import get_db_client
-from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
-from utils.memory_utils import clear_memory, push_convo2db
-from utils.chain_utils import get_chain, custom_chain_predict
-from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
-
-logger = get_logger(__name__)
-openai_api_key = os.environ['OPENAI_API_KEY']
-temperature = 0.8
-# username = "barb-chase" #"ivnban-ctl"
-
-if "sent_messages" not in st.session_state:
-    st.session_state['sent_messages'] = 0
-if "total_messages" not in st.session_state:
-    st.session_state['total_messages'] = 0
-if "issue" not in st.session_state:
-    st.session_state['issue'] = ISSUES[0]
-if 'previous_source' not in st.session_state:
-    st.session_state['previous_source'] = SOURCES[0]
-if 'db_client' not in st.session_state:
-    st.session_state["db_client"] = get_db_client()
-if 'texter_name' not in st.session_state:
-    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
-    logger.debug(f"texter name is {st.session_state['texter_name']}")
-
-memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
-
-with st.sidebar:
-    username = st.text_input("Username", value='Dani', max_chars=30)
-    if 'counselor_name' not in st.session_state:
-        st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
-    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
-    issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
-                         on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
-                         )
-    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
-    language = st.selectbox("Select a Language", supported_languages, index=0,
-                            format_func=lambda x: "English" if x=="en" else "Spanish",
-                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
-                            )
-
-    source = st.selectbox("Select a source Model A", SOURCES, index=0,
-                          format_func=source2label,
-                          )
-
-changed_source = any([
-    st.session_state['previous_source'] != source,
-    st.session_state['issue'] != issue,
-    st.session_state['counselor_name'] != username,
-])
-if changed_source:
-    st.session_state["counselor_name"] = username
-    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
-    logger.debug(f"texter name is {st.session_state['texter_name']}")
-    st.session_state['previous_source'] = source
-    st.session_state['issue'] = issue
-    st.session_state['sent_messages'] = 0
-    st.session_state['total_messages'] = 0
-create_memory_add_initial_message(memories,
-                                  issue,
-                                  language,
-                                  changed_source=changed_source,
-                                  counselor_name=st.session_state["counselor_name"],
-                                  texter_name=st.session_state["texter_name"])
-st.session_state['previous_source'] = source
-memoryA = st.session_state[list(memories.keys())[0]]
-# issue only without "." marker for model compatibility
-llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
-
-st.title("💬 Simulator")
-st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
-for msg in memoryA.buffer_as_messages:
-    role = "user" if type(msg) == HumanMessage else "assistant"
-    st.chat_message(role).write(msg.content)
-
-if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
-    st.session_state['sent_messages'] += 1
-    st.chat_message("user").write(prompt)
-    if 'convo_id' not in st.session_state:
-        push_convo2db(memories, username, language)
-    responses = custom_chain_predict(llm_chain, prompt, stopper)
-    # responses = llm_chain.predict(input=prompt, stop=stopper)
-    # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
-    for response in responses:
-        st.chat_message("assistant").write(response)
-
-st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
-if st.session_state['total_messages'] >= MAX_MSG_COUNT:
-    st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
-elif st.session_state['total_messages'] >= WARN_MSG_COUT:
-    st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
-
-with st.sidebar:
-    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
-    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")