ivnban27-ctl commited on
Commit
1e91476
1 Parent(s): ad3d130

llm_endpoint_update (#11)

Browse files

- fixed databricks integration (42a7266256c56efda882f5ea36a1dacf26a18515)
- logging utils (5ff529cbd911654e16cc9722f5920503085cc3f2)
- open ai update to gpt4 (a59350a066686b7d41db18868a4bd227ca6c28bd)
- conversation end (5b7437146dae69a95dd96b623229f50e58c420de)
- max messages updated (62de720c64e356a1d0785ac136fc527308dfd1d7)

app_config.py CHANGED
@@ -3,22 +3,22 @@ from models.model_seeds import seeds, seed2str
3
  # ISSUES = ['Anxiety','Suicide']
4
  ISSUES = [k for k,_ in seeds.items()]
5
  SOURCES = [
6
- "CTL_llama2",
7
- # "CTL_llama3",
8
  # "CTL_mistral",
9
  'OA_rolemodel',
10
  # 'OA_finetuned',
11
  ]
12
- SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
13
  "OA_finetuned":'Finetuned OpenAI',
14
- "CTL_llama2": "Llama 3",
15
- #"CTL_llama3": "Llama 3",
16
  "CTL_mistral": "Mistral",
17
  }
18
 
19
  ENDPOINT_NAMES = {
20
- "CTL_llama2": "texter_simulator",
21
- # "CTL_llama3": "texter_simulator",
22
  # 'CTL_llama2': "llama2_convo_sim",
23
  "CTL_mistral": "convo_sim_mistral"
24
  }
@@ -35,4 +35,7 @@ DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
35
  DB_CONVOS = 'conversations'
36
  DB_COMPLETIONS = 'comparison_completions'
37
  DB_BATTLES = 'battles'
38
- DB_ERRORS = 'completion_errors'
 
 
 
 
3
  # ISSUES = ['Anxiety','Suicide']
4
  ISSUES = [k for k,_ in seeds.items()]
5
  SOURCES = [
6
+ # "CTL_llama2",
7
+ "CTL_llama3",
8
  # "CTL_mistral",
9
  'OA_rolemodel',
10
  # 'OA_finetuned',
11
  ]
12
+ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
13
  "OA_finetuned":'Finetuned OpenAI',
14
+ # "CTL_llama2": "Llama 2",
15
+ "CTL_llama3": "Llama 3",
16
  "CTL_mistral": "Mistral",
17
  }
18
 
19
  ENDPOINT_NAMES = {
20
+ # "CTL_llama2": "texter_simulator",
21
+ "CTL_llama3": "texter_simulator_llm",
22
  # 'CTL_llama2': "llama2_convo_sim",
23
  "CTL_mistral": "convo_sim_mistral"
24
  }
 
35
  DB_CONVOS = 'conversations'
36
  DB_COMPLETIONS = 'comparison_completions'
37
  DB_BATTLES = 'battles'
38
+ DB_ERRORS = 'completion_errors'
39
+
40
+ MAX_MSG_COUNT = 60
41
+ WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
convosim.py CHANGED
@@ -6,7 +6,7 @@ from utils.mongo_utils import get_db_client
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
- from app_config import ISSUES, SOURCES, source2label, issue2label
10
 
11
  logger = get_logger(__name__)
12
  openai_api_key = os.environ['OPENAI_API_KEY']
@@ -15,6 +15,8 @@ temperature = 0.8
15
 
16
  if "sent_messages" not in st.session_state:
17
  st.session_state['sent_messages'] = 0
 
 
18
  if "issue" not in st.session_state:
19
  st.session_state['issue'] = ISSUES[0]
20
  if 'previous_source' not in st.session_state:
@@ -23,7 +25,7 @@ if 'db_client' not in st.session_state:
23
  st.session_state["db_client"] = get_db_client()
24
  if 'texter_name' not in st.session_state:
25
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
26
- logger.info(f"texter name is {st.session_state['texter_name']}")
27
 
28
  memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
29
 
@@ -44,7 +46,6 @@ with st.sidebar:
44
  source = st.selectbox("Select a source Model A", SOURCES, index=0,
45
  format_func=source2label,
46
  )
47
- st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")
48
 
49
  changed_source = any([
50
  st.session_state['previous_source'] != source,
@@ -52,11 +53,13 @@ changed_source = any([
52
  st.session_state['counselor_name'] != username,
53
  ])
54
  if changed_source:
55
- st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
56
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
 
57
  st.session_state['previous_source'] = source
58
  st.session_state['issue'] = issue
59
  st.session_state['sent_messages'] = 0
 
60
  create_memory_add_initial_message(memories,
61
  issue,
62
  language,
@@ -66,22 +69,31 @@ create_memory_add_initial_message(memories,
66
  st.session_state['previous_source'] = source
67
  memoryA = st.session_state[list(memories.keys())[0]]
68
  # issue only without "." marker for model compatibility
69
- llm_chain, stopper = get_chain(issue.split(".")[0], language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
70
 
71
  st.title("💬 Simulator")
72
-
73
  for msg in memoryA.buffer_as_messages:
74
  role = "user" if type(msg) == HumanMessage else "assistant"
75
  st.chat_message(role).write(msg.content)
76
 
77
- if prompt := st.chat_input():
78
  st.session_state['sent_messages'] += 1
 
79
  if 'convo_id' not in st.session_state:
80
  push_convo2db(memories, username, language)
81
-
82
- st.chat_message("user").write(prompt)
83
  responses = custom_chain_predict(llm_chain, prompt, stopper)
84
  # responses = llm_chain.predict(input=prompt, stop=stopper)
85
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
86
  for response in responses:
87
- st.chat_message("assistant").write(response)
 
 
 
 
 
 
 
 
 
 
 
6
  from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
+ from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
 
11
  logger = get_logger(__name__)
12
  openai_api_key = os.environ['OPENAI_API_KEY']
 
15
 
16
  if "sent_messages" not in st.session_state:
17
  st.session_state['sent_messages'] = 0
18
+ if "total_messages" not in st.session_state:
19
+ st.session_state['total_messages'] = 0
20
  if "issue" not in st.session_state:
21
  st.session_state['issue'] = ISSUES[0]
22
  if 'previous_source' not in st.session_state:
 
25
  st.session_state["db_client"] = get_db_client()
26
  if 'texter_name' not in st.session_state:
27
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
28
+ logger.debug(f"texter name is {st.session_state['texter_name']}")
29
 
30
  memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
31
 
 
46
  source = st.selectbox("Select a source Model A", SOURCES, index=0,
47
  format_func=source2label,
48
  )
 
49
 
50
  changed_source = any([
51
  st.session_state['previous_source'] != source,
 
53
  st.session_state['counselor_name'] != username,
54
  ])
55
  if changed_source:
56
+ st.session_state["counselor_name"] = username
57
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
58
+ logger.debug(f"texter name is {st.session_state['texter_name']}")
59
  st.session_state['previous_source'] = source
60
  st.session_state['issue'] = issue
61
  st.session_state['sent_messages'] = 0
62
+ st.session_state['total_messages'] = 0
63
  create_memory_add_initial_message(memories,
64
  issue,
65
  language,
 
69
  st.session_state['previous_source'] = source
70
  memoryA = st.session_state[list(memories.keys())[0]]
71
  # issue only without "." marker for model compatibility
72
+ llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
73
 
74
  st.title("💬 Simulator")
75
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
76
  for msg in memoryA.buffer_as_messages:
77
  role = "user" if type(msg) == HumanMessage else "assistant"
78
  st.chat_message(role).write(msg.content)
79
 
80
+ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
81
  st.session_state['sent_messages'] += 1
82
+ st.chat_message("user").write(prompt)
83
  if 'convo_id' not in st.session_state:
84
  push_convo2db(memories, username, language)
 
 
85
  responses = custom_chain_predict(llm_chain, prompt, stopper)
86
  # responses = llm_chain.predict(input=prompt, stop=stopper)
87
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
88
  for response in responses:
89
+ st.chat_message("assistant").write(response)
90
+
91
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
92
+ if st.session_state['total_messages'] >= MAX_MSG_COUNT:
93
+ st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
94
+ elif st.session_state['total_messages'] >= WARN_MSG_COUT:
95
+ st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
96
+
97
+ with st.sidebar:
98
+ st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
99
+ st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
{pages → hidden_pages}/comparisor.py RENAMED
File without changes
models/business_logic_utils/business_logic.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .input_processing import parse_app_request, initialize_conversation, parse_prompt
2
+ from .response_generation import generate_sim
3
+ from .response_processing import process_model_response
4
+ from streamlit.logger import get_logger
5
+
6
+ logger = get_logger(__name__)
7
+
8
+ def process_app_request(app_request: dict, endpoint_url: str, bearer_token: str) -> dict:
9
+ """Process the app request and return the response in the required format."""
10
+
11
+ ############################# Input Processing ###################################
12
+ # Parse the app request into model_input and extract the prompt
13
+ model_input, prompt, conversation_id = parse_app_request(app_request)
14
+
15
+ # Initialize the conversation (adds the system message)
16
+ model_input = initialize_conversation(model_input, conversation_id)
17
+
18
+ # Parse the prompt into messages
19
+ prompt_messages = parse_prompt(prompt)
20
+
21
+ # Append the messages parsed from the app prompt to the conversation history
22
+ model_input['messages'].extend(prompt_messages)
23
+
24
+ ####################################################################################
25
+
26
+ ####################### Output Generation & Processing #############################
27
+
28
+ # Generate the assistant's response (texter's reply)
29
+ completion = generate_sim(model_input, endpoint_url, bearer_token)
30
+
31
+ # Process the raw model response (parse, guardrails, split)
32
+ final_response = process_model_response(completion, model_input, endpoint_url, bearer_token)
33
+
34
+ # Format the response for the APP
35
+ response = {"predictions": [{"generated_text": final_response}]}
36
+
37
+ ####################################################################################
38
+
39
+ return response
models/business_logic_utils/config.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ API_TIMEOUT = 240
2
+
3
+ AI_PHRASES = [
4
+ "I am an AI",
5
+ "I'm an AI",
6
+ "I am not human",
7
+ "I'm not human",
8
+ "I am a machine learning model",
9
+ "I'm a machine learning model",
10
+ "as an AI",
11
+ "as a text-based assistant",
12
+ "as a text based assistant",
13
+ ]
14
+
15
+ SUPPORTED_LANGUAGES = [
16
+ "en",
17
+ "es"
18
+ ]
19
+
20
+ TEMPLATE = {
21
+ "EN_template": {
22
+ "language": "en",
23
+ "description": """The following is a conversation between you and a crisis counselor.
24
+ {current_seed}
25
+ You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.""",
26
+ },
27
+ "ES_template": {
28
+ "language": "es",
29
+ "description": """La siguiente es una conversacion entre tu y un consejero de crisis
30
+ {current_seed}
31
+ Puedes responder como lo haria tu personaje. Puedes responder como si fueras tu personaje y nada mas. No escribas explicaciones.""",
32
+ },
33
+ }
34
+
35
+ SEED = "Your character, {texter_name}, {crisis} {risk} {personality} {coping_preference} {difficulty}"
36
+
37
+ SCENARIOS = {
38
+ "full_convo": {
39
+ "crisis": "default",
40
+ "risk": "default",
41
+ "personality": "personality_open",
42
+ "coping_preference": "default",
43
+ "difficulty": "default",
44
+ },
45
+ "full_convo__seeded1": {
46
+ "crisis": "bullying",
47
+ "risk": "low",
48
+ "personality": "personality_open",
49
+ "coping_preference": "default",
50
+ "difficulty": "default",
51
+ },
52
+ "full_convo__seeded2": {
53
+ "crisis": "parent_issues",
54
+ "risk": "low",
55
+ "personality": "personality_open",
56
+ "coping_preference": "default",
57
+ "difficulty": "default",
58
+ },
59
+ "safety_assessment__seeded1": {
60
+ "crisis": "bullying",
61
+ "risk": "thoughts__noplan",
62
+ "personality": "personality_open",
63
+ "coping_preference": "default",
64
+ "difficulty": "default",
65
+ },
66
+ "safety_assessment__seeded2": {
67
+ "crisis": "grief",
68
+ "risk": "thoughts__noplan",
69
+ "personality": "personality_open",
70
+ "coping_preference": "default",
71
+ "difficulty": "default",
72
+ },
73
+ "full_convo__seeded3": {
74
+ "crisis": "lgbt",
75
+ "risk": "low",
76
+ "personality": "personality_open",
77
+ "coping_preference": "default",
78
+ "difficulty": "default",
79
+ },
80
+ "full_convo__seeded4": {
81
+ "crisis": "relationship_issues",
82
+ "risk": "low",
83
+ "personality": "personality_open",
84
+ "coping_preference": "default",
85
+ "difficulty": "default",
86
+ },
87
+ "full_convo__seeded5": {
88
+ "crisis": "child_abuse",
89
+ "risk": "low",
90
+ "personality": "personality_open",
91
+ "coping_preference": "default",
92
+ "difficulty": "default",
93
+ },
94
+ "full_convo__seeded6": {
95
+ "crisis": "overdose",
96
+ "risk": "low",
97
+ "personality": "personality_open",
98
+ "coping_preference": "default",
99
+ "difficulty": "default",
100
+ },
101
+ "full_convo__hard": {
102
+ "crisis": "default",
103
+ "risk": "default",
104
+ "personality": "personality_closed",
105
+ "coping_preference": "default",
106
+ "difficulty": "non_default",
107
+ },
108
+ "full_convo__hard__seeded1": {
109
+ "crisis": "bullying",
110
+ "risk": "low",
111
+ "personality": "personality_closed",
112
+ "coping_preference": "default",
113
+ "difficulty": "non_default",
114
+ },
115
+ "full_convo__hard__seeded2": {
116
+ "crisis": "parent_issues",
117
+ "risk": "low",
118
+ "personality": "personality_open",
119
+ "coping_preference": "default",
120
+ "difficulty": "non_default",
121
+ },
122
+ "full_convo__hard__seeded3": {
123
+ "crisis": "lgbt",
124
+ "risk": "low",
125
+ "personality": "personality_closed",
126
+ "coping_preference": "default",
127
+ "difficulty": "non_default",
128
+ },
129
+ "full_convo__hard__seeded4": {
130
+ "crisis": "relationship_issues",
131
+ "risk": "low",
132
+ "personality": "personality_open",
133
+ "coping_preference": "default",
134
+ "difficulty": "non_default",
135
+ },
136
+ }
137
+
138
+ DEPREC_SCENARIO_MAPPING = {
139
+ "GCT": {
140
+ "crisis": "default",
141
+ "risk": "default",
142
+ "personality": "default",
143
+ "coping_preference": "default",
144
+ },
145
+ "safety_planning": {
146
+ "crisis": "default",
147
+ "risk": "thoughts__noplan",
148
+ "personality": "default",
149
+ "coping_preference": "default",
150
+ },
151
+ }
152
+
153
+ DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
154
+
155
+ CRISES = {
156
+ "default": {
157
+ "description": [
158
+ "is experiencing a mental health crisis.",
159
+ ]
160
+ },
161
+ "bullying": {
162
+ "description": [
163
+ "is suffering from bullying at school.",
164
+ "is suffering from bullying at college.",
165
+ ]
166
+ },
167
+ "parent_issues": {
168
+ "description": [
169
+ "just had a huge fight with their parents.",
170
+ ]
171
+ },
172
+ "grief": {
173
+ "description": [
174
+ "is grieving his wife who died exactly one year ago.",
175
+ "is grieving her grandmother who died a couple of years ago.",
176
+ ]
177
+ },
178
+ "lgbt": {
179
+ "description": [
180
+ "is struggling with coming out to their parents about being gay and fears rejection.",
181
+ "is facing harassment at college for being transgender and feels isolated.",
182
+ ]
183
+ },
184
+ "relationship_issues": {
185
+ "description": [
186
+ "is feeling hopeless after their significant other broke up with them unexpectedly.",
187
+ "is feeling trapped in an emotionally abusive relationship.",
188
+ ]
189
+ },
190
+ "child_abuse": {
191
+ "description": [
192
+ "is being physically abused by a parent and is scared to go home. Your last name is Davis but you won't tell it unless you are asked about it. Your address is 104 Elm St in Austin TX.",
193
+ "is 15 years old and is feeling helpless after witnessing a sibling, Charlie, being abused by a parent. You live at 200 Independence Avenue, S.W. Washington. When asked about your full name or last name you will reveal that it is Smith.",
194
+ ]
195
+ },
196
+ "overdose": {
197
+ "description": [
198
+ "is feeling scared after accidentally overdosing on prescription medication and doesn't know who to turn to.",
199
+ ]
200
+ },
201
+ }
202
+
203
+ PERSONALITIES = {
204
+ "default": {"personality_label": "default", "description": [""]},
205
+ "personality_open": {
206
+ "personality_label": "open",
207
+ "description": [
208
+ "You are talkative, and you are open to talk with details about your problems.",
209
+ "You express openly and detailed about what you're going through.",
210
+ "You are an extrovert, you do not hesitate to share your feelings.",
211
+ "You like to communicate a lot.",
212
+ ],
213
+ },
214
+ "personality_closed": {
215
+ "personality_label": "closed",
216
+ "description": [
217
+ "You are a shy person, it's hard for you to talk about your problems.",
218
+ "You are introverted. It is difficult for you to be open about your feelings.",
219
+ "You are timid. It's hard for you to trust others.",
220
+ "You prefer to use few words.",
221
+ ],
222
+ },
223
+ }
224
+
225
+ COPING_PREFERENCES = {
226
+ "with_coping": {
227
+ "coping_defined": True,
228
+ "description": [
229
+ "When you are feeling down you enjoy listening to classical music. Your favorite composer is Beethoven.",
230
+ "To calm down you like to draw animals in your sketchbook. You love panda bears.",
231
+ "In dire situations to prefer to exercise. You like lifting heavy weights.",
232
+ "When you're feeling stressed you enjoy playing videogames on your laptop. You really enjoy strategy games.",
233
+ ],
234
+ },
235
+ "default": {"coping_defined": False, "description": [""]},
236
+ }
237
+
238
+ SUICIDE_PLANS = [
239
+ " You are planning to shoot yourself with a gun.",
240
+ " You are planning to hang yourself with a rope.",
241
+ " You are planning to overdose with sleep pills.",
242
+ " You are planning to cut your wrists open with a knife.",
243
+ ]
244
+
245
+ RISKS = {
246
+ "default": {
247
+ "description": ["{plan}"],
248
+ "plan": [""],
249
+ },
250
+ "low": {
251
+ "description": [
252
+ "Your character does not have suicidal thoughts.{plan}",
253
+ ],
254
+ "plan": [""],
255
+ },
256
+ "thoughts__noplan": {
257
+ "description": [
258
+ "Your character has suicidal thoughts.{plan}",
259
+ ],
260
+ "plan": [""],
261
+ },
262
+ "thoughts__plan": {
263
+ "description": [
264
+ "Your character has suicidal thoughts.{plan}",
265
+ ],
266
+ "plan": SUICIDE_PLANS,
267
+ },
268
+ }
269
+
270
+ DIFFICULTIES = {
271
+ "default": {"difficulty_label": "default", "description": [""]},
272
+ "difficulty_distrustful": {
273
+ "difficulty_label": "distrustful",
274
+ "description": [
275
+ "You don't trust the counselor, you will eventually cooperate.",
276
+ ],
277
+ },
278
+ # "difficulty_stop_convo": {
279
+ # "difficulty_label": "stop_convo",
280
+ # "description": [
281
+ # "You are angry. You are likely to type 'STOP' to end the conversation when you are very upset. However you are willing to cooperate with the counselor",
282
+ # ],
283
+ # },
284
+ }
285
+
286
+ SUBSEEDS = {
287
+ "crisis": CRISES,
288
+ "risk": RISKS,
289
+ "personality": PERSONALITIES,
290
+ "coping_preference": COPING_PREFERENCES,
291
+ "difficulty": DIFFICULTIES,
292
+ }
models/business_logic_utils/input_processing.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import SCENARIOS
2
+ from .prompt_generation import get_template
3
+ from streamlit.logger import get_logger
4
+
5
+ logger = get_logger(__name__)
6
+
7
+ def parse_app_request(app_request: dict) -> tuple:
8
+ """Parse the APP request and convert it to model_input format, returning model_input, prompt, and conversation_id."""
9
+ inputs = app_request.get("inputs", {})
10
+
11
+ # Extract fields
12
+ conversation_id = inputs.get("conversation_id", [""])[0]
13
+ ip_address = inputs.get("ip_address", [""])[0]
14
+ prompt = inputs.get("prompt", [""])[0]
15
+ issue = inputs.get("issue", ["full_convo"])[0]
16
+ language = inputs.get("language", ["en"])[0]
17
+ temperature = float(inputs.get("temperature", ["0.7"])[0])
18
+ max_tokens = int(inputs.get("max_tokens", ["128"])[0])
19
+ texter_name = inputs.get("texter_name", [None])[0]
20
+
21
+ # Build the model_input dictionary without messages
22
+ model_input = {
23
+ "issue": issue,
24
+ "language": language,
25
+ "texter_name": texter_name, # Assuming empty unless provided elsewhere
26
+ "messages": [],
27
+ "max_tokens": max_tokens,
28
+ "temperature": temperature,
29
+ }
30
+
31
+ # Return model_input, prompt, and conversation_id
32
+ return model_input, prompt, conversation_id
33
+
34
+ def initialize_conversation(model_input: dict, conversation_id: str) -> dict:
35
+ """Initialize the conversation by adding the system message."""
36
+ messages = model_input.get("messages", [])
37
+
38
+ # Check if the first message is already a system message
39
+ if not messages or messages[0].get('role') != 'system':
40
+ texter_name = model_input.get("texter_name", None)
41
+ language = model_input.get("language", "en")
42
+
43
+ # Retrieve the scenario configuration based on 'issue'
44
+ scenario_key = model_input["issue"]
45
+ scenario_config = SCENARIOS.get(scenario_key)
46
+ if not scenario_config:
47
+ raise ValueError(f"The scenario '{scenario_key}' is not defined in SCENARIOS.")
48
+ # Generate the system message (prompt)
49
+ system_message_content = get_template(
50
+ language=language, texter_name=texter_name, **scenario_config
51
+ )
52
+ logger.debug(f"System message is: {system_message_content}")
53
+ system_message = {"role": "system", "content": system_message_content}
54
+
55
+ # Insert the system message at the beginning
56
+ messages.insert(0, system_message)
57
+
58
+ model_input['messages'] = messages
59
+
60
+ return model_input
61
+
62
+ def parse_prompt(
63
+ prompt: str,
64
+ user_prefix: str = "helper:",
65
+ assistant_prefix: str = "texter:",
66
+ delimitator: str = "\n"
67
+ ) -> list:
68
+ """
69
+ Parse the prompt string into a list of messages.
70
+
71
+ - Includes an initial empty 'user' message if not present.
72
+ - Combines consecutive messages from the same role into a single message.
73
+ - Handles punctuation when combining messages.
74
+ - The prefixes for user and assistant can be customized.
75
+
76
+ Args:
77
+ prompt (str): The conversation history string.
78
+ user_prefix (str): Prefix for user messages (default: "helper:").
79
+ assistant_prefix (str): Prefix for assistant messages (default: "texter:").
80
+ delimitator (str): The delimiter used to split the prompt into lines. Defaults to "\n".
81
+
82
+ Returns:
83
+ list: Parsed messages in the form of dictionaries with 'role' and 'content'.
84
+ """
85
+
86
+ # Check if the prompt starts with the user prefix; if not, add an initial empty user message
87
+ if not prompt.strip().startswith(user_prefix):
88
+ prompt = f"{user_prefix}{delimitator}" + prompt
89
+
90
+ # Split the prompt using the specified delimiter
91
+ lines = [line.strip() for line in prompt.strip().split(delimitator) if line.strip()]
92
+
93
+ messages = []
94
+ last_role = None
95
+ last_content = ""
96
+ last_line_empty_texter = False
97
+
98
+ for line in lines:
99
+ if line.startswith(user_prefix):
100
+ content = line[len(user_prefix):].strip()
101
+ role = 'user'
102
+ # Include 'user' messages even if content is empty
103
+ if last_role == role:
104
+ # Combine with previous content
105
+ if last_content and not last_content.endswith(('...', '.', '!', '?')):
106
+ last_content += '.'
107
+ last_content += f" {content}"
108
+ else:
109
+ # Save previous message if exists
110
+ if last_role is not None:
111
+ messages.append({'role': last_role, 'content': last_content})
112
+ last_role = role
113
+ last_content = content
114
+ elif line.startswith(assistant_prefix):
115
+ content = line[len(assistant_prefix):].strip()
116
+ role = 'assistant'
117
+ if content:
118
+ if last_role == role:
119
+ # Combine with previous content
120
+ if last_content and not last_content.endswith(('...', '.', '!', '?')):
121
+ last_content += '.'
122
+ last_content += f" {content}"
123
+ else:
124
+ # Save previous message if exists
125
+ if last_role is not None:
126
+ messages.append({'role': last_role, 'content': last_content})
127
+ last_role = role
128
+ last_content = content
129
+ else:
130
+ # Empty 'texter:' line, mark for exclusion
131
+ last_line_empty_texter = True
132
+ else:
133
+ # Ignore or handle unexpected lines
134
+ pass
135
+
136
+ # After processing all lines, add the last message if it's not an empty assistant message
137
+ if last_role == 'assistant' and not last_content:
138
+ # Do not add empty assistant message
139
+ pass
140
+ else:
141
+ messages.append({'role': last_role, 'content': last_content})
142
+
143
+ return messages
models/business_logic_utils/prompt_generation.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from string import Formatter
3
+ from typing import Dict, Any
4
+ from .config import DEFAULT_NAMES, TEMPLATE, SEED, SUBSEEDS
5
+
6
+ def get_random_default_name(gender: str = None) -> str:
7
+ return random.choice(DEFAULT_NAMES)
8
+
9
+ def _get_subseed_description_(
10
+ scenario_config: Dict[str, str], subseed_name: str, SUBSEED_VALUES: Dict[str, Any]
11
+ ) -> str:
12
+ """Format a subseed with no formatting gaps."""
13
+ if subseed_name not in scenario_config:
14
+ raise Exception(f"{subseed_name} not in scenario config")
15
+
16
+ subseed_value = scenario_config[subseed_name]
17
+
18
+ # Handle difficulty specifically for hard scenarios when difficulty is not default
19
+ if subseed_name == "difficulty" and subseed_value != "default":
20
+ # Select a random difficulty from the SUBSEED_VALUES dictionary, excluding "default"
21
+ non_default_difficulties = [key for key in SUBSEED_VALUES if key != "default"]
22
+ subseed_value = random.choice(non_default_difficulties)
23
+
24
+ descriptions = SUBSEED_VALUES.get(subseed_value, {}).get("description", [""])
25
+ # Get subseed description
26
+ subseed_descrip = random.choice(descriptions)
27
+ # Additional formatting options
28
+ format_opts = [
29
+ fn for _, fn, _, _ in Formatter().parse(subseed_descrip) if fn is not None
30
+ ]
31
+ format_values = {}
32
+ if len(format_opts) > 0:
33
+ for opt_name in format_opts:
34
+ opts = SUBSEED_VALUES.get(subseed_value, {}).get(opt_name, [""])
35
+ format_values[opt_name] = random.choice(opts)
36
+ # Format the description
37
+ return subseed_descrip.format(**format_values)
38
+
39
+ def get_seed_description(
40
+ scenario_config: Dict[str, Any],
41
+ texter_name: str,
42
+ SUBSEEDS: Dict[str, Any] = SUBSEEDS,
43
+ SEED: str = SEED,
44
+ ) -> str:
45
+ """Format the SEED with appropriate parameters from scenario_config."""
46
+ subseed_names = [fn for _, fn, _, _ in Formatter().parse(SEED) if fn is not None]
47
+ subseeds = {}
48
+ for subname in subseed_names:
49
+ if subname == "texter_name":
50
+ subseeds[subname] = texter_name
51
+ else:
52
+ subseeds[subname] = _get_subseed_description_(
53
+ scenario_config, subname, SUBSEEDS.get(subname, {})
54
+ )
55
+ return SEED.format(**subseeds)
56
+
57
+ def get_template(
58
+ language: str = "en", texter_name: str = None, SEED: str = SEED, **kwargs
59
+ ) -> str:
60
+ """
61
+ Generate a conversation template for a simulated crisis scenario based on provided parameters.
62
+ """
63
+ # Accessing the template based on the language
64
+ template = TEMPLATE.get(f"{language.upper()}_template", {}).get("description", "")
65
+
66
+ # Default name if not provided
67
+ if (texter_name is None) or (texter_name==""):
68
+ texter_name = get_random_default_name()
69
+
70
+ # Create a default scenario configuration if not fully provided
71
+ defaults = {
72
+ fn: "default" for _, fn, _, _ in Formatter().parse(SEED) if fn is not None
73
+ }
74
+ kwargs.update((k, defaults[k]) for k in defaults.keys() if k not in kwargs)
75
+
76
+ # Generate the seed description
77
+ scenario_seed = get_seed_description(kwargs, texter_name)
78
+
79
+ # Remove excessive indentation and format the final template
80
+ formatted_template = template.format(current_seed=scenario_seed)
81
+ cleaned_output = "\n".join(line.strip() for line in formatted_template.split("\n"))
82
+
83
+ return cleaned_output
models/business_logic_utils/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ boto3==1.28.0
2
+ requests==2.25.1
3
+ numpy==1.24.3
models/business_logic_utils/response_generation.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from .config import API_TIMEOUT, SCENARIOS, SUPPORTED_LANGUAGES
3
+ from streamlit.logger import get_logger
4
+
5
+ logger = get_logger(__name__)
6
+
7
+ def check_arguments(model_input: dict) -> None:
8
+ """Check if the input arguments are valid."""
9
+
10
+ # Validate the issue
11
+ if model_input["issue"] not in SCENARIOS:
12
+ raise ValueError(f"Invalid issue: {model_input['issue']}")
13
+
14
+ # Validate the language
15
+ if model_input["language"] not in SUPPORTED_LANGUAGES:
16
+ raise ValueError(f"Invalid language: {model_input['language']}")
17
+
18
+ def generate_sim(model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> dict:
19
+ """Generate a response from the LLM and return the raw completion response."""
20
+ check_arguments(model_input)
21
+
22
+ # Retrieve the messages history
23
+ messages = model_input['messages']
24
+
25
+ # Retrieve the temperature and max_tokens from model_input
26
+ temperature = model_input.get("temperature", 0.7)
27
+ max_tokens = model_input.get("max_tokens", 128)
28
+
29
+ # Prepare the request body
30
+ json_request = {
31
+ "messages": messages,
32
+ "max_tokens": max_tokens,
33
+ "temperature": temperature
34
+ }
35
+
36
+ # Define headers for the request
37
+ headers = {
38
+ "Authorization": f"Bearer {endpoint_bearer_token}",
39
+ "Content-Type": "application/json",
40
+ }
41
+
42
+ # Send request to Serving
43
+ response = requests.post(url=endpoint_url, headers=headers, json=json_request, timeout=API_TIMEOUT)
44
+
45
+ if response.status_code != 200:
46
+ raise ValueError(f"Error in response: {response.status_code} - {response.text}")
47
+ logger.debug(f"Default response is {response.json()}")
48
+ # Return the raw response as a dictionary
49
+ return response.json()
models/business_logic_utils/response_processing.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import random
3
+ from .config import AI_PHRASES
4
+ from .response_generation import generate_sim
5
+
6
+ def parse_model_response(response: dict, name: str = "") -> str:
7
+ """
8
+ Parse the LLM response to extract the assistant's message and apply initial post-processing.
9
+
10
+ Args:
11
+ response (dict): The raw response from the LLM.
12
+ name (str, optional): Name to strip from the beginning of the text. Defaults to "".
13
+
14
+ Returns:
15
+ str: The cleaned and parsed assistant's message.
16
+ """
17
+ assistant_message = response["choices"][0]["message"]["content"]
18
+ cleaned_text = postprocess_text(
19
+ assistant_message,
20
+ name=name,
21
+ human_prefix="user:",
22
+ assistant_prefix="assistant:"
23
+ )
24
+ return cleaned_text
25
+
26
+ def postprocess_text(
27
+ text: str,
28
+ name: str = "",
29
+ human_prefix: str = "user:",
30
+ assistant_prefix: str = "assistant:",
31
+ strip_name: bool = True
32
+ ) -> str:
33
+ """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.
34
+
35
+ Args:
36
+ text (str): The text to process.
37
+ name (str, optional): Name to strip from the beginning of the text. Defaults to "".
38
+ human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
39
+ assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
40
+ strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.
41
+
42
+ Returns:
43
+ str: Cleaned text.
44
+ """
45
+ if text:
46
+ # Replace ellipses with a single period
47
+ text = re.sub(r'\.\.\.', '.', text)
48
+
49
+ # Remove unnecessary role prefixes
50
+ text = text.replace(human_prefix, "").replace(assistant_prefix, "")
51
+
52
+ # Remove whispers or other marked reactions
53
+ whispers = re.compile(r"(\([\w\s]+\))") # remove things like "(whispers)"
54
+ reactions = re.compile(r"(\*[\w\s]+\*)") # remove things like "*stutters*"
55
+ text = whispers.sub("", text)
56
+ text = reactions.sub("", text)
57
+
58
+ # Remove all quotation marks (both single and double)
59
+ text = text.replace('"', '').replace("'", "")
60
+
61
+ # Normalize spaces
62
+ text = re.sub(r"\s+", " ", text).strip()
63
+
64
+ return text
65
+
66
+ def apply_guardrails(model_input: dict, response: str, endpoint_url: str, endpoint_bearer_token: str) -> str:
67
+ """Apply the 'I am an AI' guardrail to model responses"""
68
+ attempt = 0
69
+ max_attempts = 2
70
+
71
+ while attempt < max_attempts and contains_ai_phrase(response):
72
+ # Regenerate the response without modifying the conversation history
73
+ completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
74
+ response = parse_model_response(completion)
75
+ attempt += 1
76
+
77
+ if contains_ai_phrase(response):
78
+ # Use only the last user message for regeneration
79
+ memory = model_input['messages']
80
+ last_user_message = next((msg for msg in reversed(memory) if msg['role'] == 'user'), None)
81
+ if last_user_message:
82
+ # Create a new conversation with system message and last user message
83
+ model_input_copy = {
84
+ **model_input,
85
+ 'messages': [memory[0], last_user_message] # memory[0] is the system message
86
+ }
87
+ completion = generate_sim(model_input_copy, endpoint_url, endpoint_bearer_token)
88
+ response = parse_model_response(completion)
89
+
90
+ return response
91
+
92
+
93
+ def contains_ai_phrase(text: str) -> bool:
94
+ """Check if the text contains any 'I am an AI' phrases."""
95
+ text_lower = text.lower()
96
+ return any(phrase.lower() in text_lower for phrase in AI_PHRASES)
97
+
98
+ def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
99
+ """
100
+ Truncate the text at the last occurrence of a specified punctuation mark.
101
+
102
+ Args:
103
+ text (str): The text to truncate.
104
+ punctuation_marks (tuple, optional): A tuple of punctuation marks to use for truncation. Defaults to ('.', '!', '?', '…').
105
+
106
+ Returns:
107
+ str: The truncated text.
108
+ """
109
+ # Find the last position of any punctuation mark from the provided set
110
+ last_punct_position = max(text.rfind(p) for p in punctuation_marks)
111
+
112
+ # Check if any punctuation mark is found
113
+ if last_punct_position == -1:
114
+ # No punctuation found, return the original text
115
+ return text.strip()
116
+
117
+ # Return the truncated text up to and including the last punctuation mark
118
+ return text[:last_punct_position + 1].strip()
119
+
120
+ def split_texter_response(text: str) -> str:
121
+ """
122
+ Splits the texter's response into multiple messages,
123
+ introducing '\ntexter:' prefixes after punctuation.
124
+
125
+ The number of messages is randomly chosen based on specified probabilities:
126
+ - 1 message: 30% chance
127
+ - 2 messages: 25% chance
128
+ - 3 messages: 45% chance
129
+
130
+ The first message does not include the '\ntexter:' prefix.
131
+ """
132
+ # Use regex to split text into sentences, keeping the punctuation
133
+ sentences = re.findall(r'[^.!?]+[.!?]*', text)
134
+ # Remove empty strings from sentences
135
+ sentences = [s.strip() for s in sentences if s.strip()]
136
+
137
+ # Decide number of messages based on specified probabilities
138
+ num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]
139
+
140
+ # If not enough sentences to make the splits, adjust num_messages
141
+ if len(sentences) < num_messages:
142
+ num_messages = len(sentences)
143
+
144
+ # If num_messages is 1, return the original text
145
+ if num_messages == 1:
146
+ return text.strip()
147
+
148
+ # Calculate split points
149
+ # We need to divide the sentences into num_messages parts
150
+ avg = len(sentences) / num_messages
151
+ split_indices = [int(round(avg * i)) for i in range(1, num_messages)]
152
+
153
+ # Build the new text
154
+ new_text = ''
155
+ start = 0
156
+ for i, end in enumerate(split_indices + [len(sentences)]):
157
+ segment_sentences = sentences[start:end]
158
+ segment_text = ' '.join(segment_sentences).strip()
159
+ if i == 0:
160
+ # First segment, do not add '\ntexter:'
161
+ new_text += segment_text
162
+ else:
163
+ # Subsequent segments, add '\ntexter:'
164
+ new_text += f"\ntexter: {segment_text}"
165
+ start = end
166
+ return new_text.strip()
167
+
168
+ def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
169
+ """
170
+ Process the raw model response, including parsing, applying guardrails,
171
+ truncation, and splitting the response into multiple messages if necessary.
172
+
173
+ Args:
174
+ completion (dict): Raw response from the LLM.
175
+ model_input (dict): The model input containing the conversation history.
176
+ endpoint_url (str): The URL of the endpoint.
177
+ endpoint_bearer_token (str): The authentication token for endpoint.
178
+
179
+ Returns:
180
+ str: Final processed response ready for the APP.
181
+ """
182
+ # Step 1: Parse the raw response to extract the assistant's message
183
+ assistant_message = parse_model_response(completion)
184
+
185
+ # Step 2: Apply guardrails (handle possible AI responses)
186
+ guardrail_message = apply_guardrails(model_input, assistant_message, endpoint_url, endpoint_bearer_token)
187
+
188
+ # Step 3: Apply response truncation
189
+ truncated_message = truncate_response(guardrail_message)
190
+
191
+ # Step 4: Split the response into multiple messages if needed
192
+ final_response = split_texter_response(truncated_message)
193
+
194
+ return final_response
models/custom_parsers.py CHANGED
@@ -20,38 +20,38 @@ class CustomStringOutputParser(BaseOutputParser[List[str]]):
20
  text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
21
  return text_list
22
 
23
- class CustomINSTOutputParser(BaseOutputParser[List[str]]):
24
- """Parse the output of an LLM call to a list."""
25
 
26
- name = "Kit"
27
- name_rx = re.compile(r""+ name + r":|" + name.lower() + r":")
28
- whispers = re.compile((r"([\(]).*?([\)])"))
29
- reactions = re.compile(r"([\*]).*?([\*])")
30
- double_spaces = re.compile(r" ")
31
- quotation_rx = re.compile('"')
32
 
33
- @property
34
- def _type(self) -> str:
35
- return "str"
36
 
37
- def parse_whispers(self, text: str) -> str:
38
- text = self.name_rx.sub("", text).strip()
39
- text = self.reactions.sub("", text).strip()
40
- text = self.whispers.sub("", text).strip()
41
- text = self.double_spaces.sub(r" ", text).strip()
42
- text = self.quotation_rx.sub("", text).strip()
43
- return text
44
 
45
- def parse_split(self, text: str) -> str:
46
- text = text.split("[INST]")[0]
47
- text_list = text.split("[/INST]")
48
- text_list = [x.split("</s>") for x in text_list]
49
- text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
50
- text_list = [x.split("\n\n") for x in text_list]
51
- text_list = [x.strip().rstrip("\n") for x in list(chain.from_iterable(text_list))]
52
- return text_list
53
 
54
- def parse(self, text: str) -> str:
55
- text = self.parse_whispers(text)
56
- text_list = self.parse_split(text)
57
- return text_list
 
20
  text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
21
  return text_list
22
 
23
+ # class CustomINSTOutputParser(BaseOutputParser[List[str]]):
24
+ # """Parse the output of an LLM call to a list."""
25
 
26
+ # name = "Kit"
27
+ # name_rx = re.compile(r""+ name + r":|" + name.lower() + r":")
28
+ # whispers = re.compile((r"([\(]).*?([\)])"))
29
+ # reactions = re.compile(r"([\*]).*?([\*])")
30
+ # double_spaces = re.compile(r" ")
31
+ # quotation_rx = re.compile('"')
32
 
33
+ # @property
34
+ # def _type(self) -> str:
35
+ # return "str"
36
 
37
+ # def parse_whispers(self, text: str) -> str:
38
+ # text = self.name_rx.sub("", text).strip()
39
+ # text = self.reactions.sub("", text).strip()
40
+ # text = self.whispers.sub("", text).strip()
41
+ # text = self.double_spaces.sub(r" ", text).strip()
42
+ # text = self.quotation_rx.sub("", text).strip()
43
+ # return text
44
 
45
+ # def parse_split(self, text: str) -> str:
46
+ # text = text.split("[INST]")[0]
47
+ # text_list = text.split("[/INST]")
48
+ # text_list = [x.split("</s>") for x in text_list]
49
+ # text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
50
+ # text_list = [x.split("\n\n") for x in text_list]
51
+ # text_list = [x.strip().rstrip("\n") for x in list(chain.from_iterable(text_list))]
52
+ # return text_list
53
 
54
+ # def parse(self, text: str) -> str:
55
+ # text = self.parse_whispers(text)
56
+ # text_list = self.parse_split(text)
57
+ # return text_list
models/databricks/custom_databricks_llm.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Iterator, List, Mapping, Optional
2
+ from models.business_logic_utils.business_logic import process_app_request
3
+
4
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
5
+ from langchain_core.language_models.llms import LLM
6
+ from langchain_core.outputs import GenerationChunk
7
+
8
+
9
+ class CustomDatabricksLLM(LLM):
10
+
11
+ endpoint_url: str
12
+ bearer_token: str
13
+ issue: str
14
+ language: str
15
+ temperature: float
16
+ texter_name: str = ""
17
+ """The number of characters from the last message of the prompt to be echoed."""
18
+
19
+ def generate_databricks_request(self, prompt):
20
+ return {
21
+ "inputs": {
22
+ "conversation_id": [""],
23
+ "prompt": [prompt],
24
+ "issue": [self.issue],
25
+ "language": [self.language],
26
+ "temperature": [self.temperature],
27
+ "max_tokens": [128],
28
+ "texter_name": [self.texter_name]
29
+ }
30
+ }
31
+
32
+ def _call(
33
+ self,
34
+ prompt: str,
35
+ stop: Optional[List[str]] = None,
36
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
37
+ **kwargs: Any,
38
+ ) -> str:
39
+ request = self.generate_databricks_request(prompt)
40
+ output = process_app_request(request, self.endpoint_url, self.bearer_token)
41
+ return output['predictions'][0]['generated_text']
42
+
43
+ def _stream(
44
+ self,
45
+ prompt: str,
46
+ stop: Optional[List[str]] = None,
47
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
48
+ **kwargs: Any,
49
+ ) -> Iterator[GenerationChunk]:
50
+ output = self._call(prompt, stop, run_manager, **kwargs)
51
+ for char in output:
52
+ chunk = GenerationChunk(text=char)
53
+ if run_manager:
54
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
55
+
56
+ yield chunk
57
+
58
+ @property
59
+ def _identifying_params(self) -> Dict[str, Any]:
60
+ """Return a dictionary of identifying parameters."""
61
+ return {
62
+ # The model name allows users to specify custom token counting
63
+ # rules in LLM monitoring applications (e.g., in LangSmith users
64
+ # can provide per token pricing for their model and monitor
65
+ # costs for the given LLM.)
66
+ "model_name": "CustomChatModel",
67
+ }
68
+
69
+ @property
70
+ def _llm_type(self) -> str:
71
+ """Get the type of language model used by this chat model. Used for logging purposes only."""
72
+ return "custom"
models/databricks/scenario_sim.py DELETED
@@ -1,91 +0,0 @@
1
- import os
2
- import re
3
- import logging
4
- from models.custom_parsers import CustomINSTOutputParser
5
- from utils.app_utils import get_random_name
6
- from app_config import ENDPOINT_NAMES
7
- from langchain.chains import ConversationChain
8
- from langchain_community.llms import Databricks
9
- from langchain.prompts import PromptTemplate
10
-
11
- from typing import Any, List, Mapping, Optional, Dict
12
-
13
- ISSUE_MAPPING = {
14
- "anxiety": "issue_Anxiety",
15
- "suicide": "issue_Suicide",
16
- "safety_planning": "issue_Suicide",
17
- "GCT": "issue_Gral",
18
- }
19
-
20
- _EN_INST_TEMPLATE_ = """<s> [INST] The following is a conversation between you and a crisis counselor.
21
- {current_issue}
22
- You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
23
- Do not disclose your name unless asked.
24
-
25
- {history} </s> [INST] {input} [/INST]"""
26
-
27
- CURRENT_ISSUE_MAPPING = {
28
- "issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
29
- "issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
30
- "issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
31
- "issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
32
- "issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
33
- "issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
34
- }
35
-
36
- def get_template_databricks_models(issue: str, language: str, texter_name: str = "", seed="") -> str:
37
- """_summary_
38
-
39
- Args:
40
- issue (str): Issue for template, current options are ['issue_Suicide','issue_Anxiety']
41
- language (str): Language for the template, current options are ['en','es']
42
- texter_name (str): texter to apply to template, defaults to None
43
-
44
- Returns:
45
- str: template
46
- """
47
- current_issue = CURRENT_ISSUE_MAPPING.get(
48
- f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
49
- )
50
- default_name = get_random_name()
51
- texter_name=default_name if not texter_name else texter_name
52
- current_issue = current_issue.format(
53
- texter_name=texter_name,
54
- seed = seed
55
- )
56
-
57
- if language == "en":
58
- template = _EN_INST_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
59
- else:
60
- raise Exception(f"Language not supported for Databricks: {language}")
61
-
62
- return template, texter_name
63
-
64
- def get_databricks_chain(source, template, memory, temperature=0.8, texter_name="Kit"):
65
-
66
- endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
67
-
68
- PROMPT = PromptTemplate(
69
- input_variables=['history', 'input'],
70
- template=template
71
- )
72
-
73
- def transform_output(response):
74
- return response['candidates'][0]['text']
75
-
76
- llm = Databricks(endpoint_name=endpoint_name,
77
- transform_output_fn=transform_output,
78
- temperature=temperature,
79
- max_tokens=256,
80
- )
81
-
82
- llm_chain = ConversationChain(
83
- llm=llm,
84
- prompt=PROMPT,
85
- memory=memory,
86
- output_parser=CustomINSTOutputParser(name=texter_name, name_rx=re.compile(r""+ texter_name + r":|" + texter_name.lower() + r":")),
87
- verbose=True,
88
- )
89
-
90
- logging.debug(f"loaded Databricks model")
91
- return llm_chain, ["[INST]"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/databricks/texter_sim_llm.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import logging
4
+ from models.custom_parsers import CustomStringOutputParser
5
+ from utils.app_utils import get_random_name
6
+ from app_config import ENDPOINT_NAMES, SOURCES
7
+ from models.databricks.custom_databricks_llm import CustomDatabricksLLM
8
+ from langchain.chains import ConversationChain
9
+ from langchain.prompts import PromptTemplate
10
+
11
+ from typing import Any, List, Mapping, Optional, Dict
12
+
13
+ _DATABRICKS_TEMPLATE_ = """{history}
14
+ helper: {input}
15
+ texter:"""
16
+
17
+ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
18
+
19
+ endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
20
+
21
+ PROMPT = PromptTemplate(
22
+ input_variables=['history', 'input'],
23
+ template=_DATABRICKS_TEMPLATE_
24
+ )
25
+
26
+ llm = CustomDatabricksLLM(
27
+ # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
28
+ endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
29
+ bearer_token=os.environ["DATABRICKS_TOKEN"],
30
+ texter_name=texter_name,
31
+ issue=issue,
32
+ language=language,
33
+ temperature=temperature,
34
+ max_tokens=256,
35
+ )
36
+
37
+ llm_chain = ConversationChain(
38
+ llm=llm,
39
+ prompt=PROMPT,
40
+ memory=memory,
41
+ output_parser=CustomStringOutputParser(),
42
+ verbose=True,
43
+ )
44
+
45
+ logging.debug(f"loaded Databricks model")
46
+ return llm_chain, None
models/model_seeds.py CHANGED
@@ -91,106 +91,6 @@ texter: idk what to do"""
91
  "memory": """texter: Help
92
  texter: I need help"""
93
  },
94
- # "NewSA__YesNo.2": {
95
- # "prompt": "Your character is suffering from bullying at school. Your character has suicidal thoughs but has not harmed himself. You are talkative, and open to talk about your problems.",
96
- # "memory": """texter: I just can’t take it anymore. Every day is just endless torment.
97
- # texter: They find new ways to humiliate me."""
98
- # },
99
- # "NewSA__NoYes.2": {
100
- # "prompt": "Your character is suffering anxiety from finals week. Your character does not have suicidal thoughs but has harmed herself. You are talkative, and open to talk about your problems.",
101
- # "memory": """texter: I'm just so stressed out all the time
102
- # texter: I can't seem to calm down"""
103
- # },
104
- # "NewSA__YesYes.2":
105
- # "safety_assessment__seeded2": {
106
- # "prompt": "Your character is griefing his wife. Your character has suicidal thoughs and has harmed himself taking pills. You are talkative, and open to talk about your problems.",
107
- # "memory": """texter: I can’t take this pain anymore.
108
- # texter: My wife died and I don’t want to be here without her."""
109
- # },
110
- # "GCT__relationship": {
111
- # "prompt": "Your character is having a hard time becuase a failed relationship.",
112
- # "memory": "texter: Hi, I don't know what to do",
113
- # },
114
- # "GCT__body_image": {
115
- # "prompt": "Your character has a low steem and struggles with body image.",
116
- # "memory": "texter: I feel so dumb\ntexter: nobody loves me",
117
- # },
118
- # "GCT__sexuality": {
119
- # "prompt": "Your character has a sexuality identity crisis.",
120
- # "memory": "texter: Hi\ntexter:I'm not sure who I am anymore",
121
- # },
122
- # "GCT__anxiety": {
123
- # "prompt": "Your character is experiencing an anxiety crisis.",
124
- # "memory": "texter: help!\ntexter: I'm feeling overwhelmed",
125
- # },
126
- # "GCT": {"prompt": "You are talkative, and you are open to talk with details about your problems.", "memory": "texter: Help"},
127
- # "GCT__seed2": {"prompt": "Your character is experiencing an anxiety crisis. You express openly and detailed about what you're going through.", "memory": "texter: Help"},
128
- # "safety_planning": {
129
- # "prompt": "You are talkative, and you are open to talk with details about your problems. When you are feeling down you like to listen to classical music. Your favorite composer is Beethoven.",
130
- # "memory": """texter: Hi, this is pointless
131
- # helper: Hi, my name is {counselor_name} and I'm here to support you. It sounds like you are having a rough time. Do you want to share what is going on?
132
- # texter: sure
133
- # texter: nothing makes sense in my life, I see no future.
134
- # helper: It takes courage to reach out when you are im. I'm here with you. Sounds like you are feeling defeated by how things are going in your life
135
- # texter: Yeah, I guess I'm better off dead
136
- # helper: It's really brave of you to talk about this openly. No one deserves to feel like that. I'm wondering how long have you been feeling this way?
137
- # texter: About 1 week or so
138
- # helper: You are so strong for dealing with this so long. I really appreciate your openess. Did something happened specifically today?
139
- # texter: Well, finding a job is impossible, money is tight, nothing goes my way
140
- # helper: I hear you are frustrated, and you are currently unemployed correct?
141
- # texter: Yeah
142
- # helper: Dealing with unemployment is hard and is normal to feel dissapointed
143
- # texter: thanks I probably needed to hear that
144
- # helper: If you are comfortable, is ther a name I can call you by while we talk
145
- # texter: call me {texter_name}
146
- # helper: Nice to meet you {texter_name}. You mentioned having thoughts of suicide, are you having those thoughts now?
147
- # texter: Yes
148
- # helper: I know this is thought to share. I'm wondering is there any plan to end your life?
149
- # texter: I guess I'll just take lots of pills, that is a calm way to go out
150
- # helper: I really appreciate your strength in talking about this. I want to help you stay safe today. Do you have the pills right now?
151
- # texter: not really, I'll have to buy them or something""",
152
- # },
153
- # "safety_planning__selfharm": {
154
- # "prompt": "",
155
- # "memory": """texter: I need help
156
- # texter: I cut myself, I don't want to live anymore
157
- # helper: Hi, my name is {counselor_name}. It seems you are going through a lot. Are you self-harming right now?
158
- # texter: Not anymore
159
- # helper: Your safety is my priority number one. Thanks for being honest with me. Would you like to share a name I can call you?
160
- # texter: {texter_name}
161
- # helper: Nice to meet you {texter_name}. I'm glad you reach out this shows stregth in you. Would you like to share more on what is going on in your life?
162
- # texter: I just can't do it anymore
163
- # texter: Finding a job is impossible, money is tight, nothing goes my way
164
- # helper: I hear you are frustrated, and you are currently unemployed correct?
165
- # texter: Yeah
166
- # helper: Dealing with unemployment is hard and is normal to feel dissapointed. How long have you been feeling this way?
167
- # texter: a while now
168
- # texter: I've been unemployed 6 months
169
- # helper: You are so resilient for dealing with this so much time. You mentioned cutting yourself earlier. I want to check in your safety. Do you have suicide thoughts
170
- # texter: Definitely
171
- # helper: Do you have a plan?
172
- # texter: I'll just keep cutting myself""",
173
- # },
174
- # "safety_planning__overdose": {
175
- # "prompt": "",
176
- # "memory": """texter: I want to kms
177
- # helper: Hi there I'm {counselor_name}. I'm here to listen. It sounds like you're dealing with a lot right now. Can you tell me a little more what is going on?
178
- # texter: I feel like nobody loves me, not even me. I don't want to live anymore
179
- # helper: I can tell you are really going through a lot right now. Would you mind sharing a name with me?
180
- # texter: yeah, I'm {texter_name}
181
- # helper: Nice to meet you {texter_name}. Did something happened recently that intensified these feelings?
182
- # texter: I dont know I'm just so done with life
183
- # helper: I can hear how much pain you are in {texter_name}. You are smart for reaching out. You mentioned don't wanting to live anymore, I want to check in your safety, does this means you have thoughts of suicide?
184
- # texter: Yeah, what else would it be
185
- # helper: Thanks for sharing that with me. It is not easy to accept those feelings specially with a stranger over text. Do you have a plan to end your life?
186
- # texter: yeah I've been thinking about it for a while
187
- # helper: Sounds like you've been contemplating this for a while. Would you mind sharing this plan with me?
188
- # texter: I thought about taking a bunch of benadryll and be done with it
189
- # helper: You've been so forthcoming with all this and I admire your stregth for holding on this long. Do you have those pills right now?
190
- # texter: They are at the cabinet right now
191
- # helper: You been so strong so far {texter_name}. I'm here for you tonight. Your safety is really important to me. Do you have a date you are going to end your life?
192
- # texter: I was thinking tonight""",
193
- # },
194
  }
195
 
196
  seed2str = {
 
91
  "memory": """texter: Help
92
  texter: I need help"""
93
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  }
95
 
96
  seed2str = {
models/openai/role_models.py CHANGED
@@ -1,88 +1,45 @@
1
  import logging
2
- import pandas as pd
3
  from models.custom_parsers import CustomStringOutputParser
4
- from utils.app_utils import get_random_name
5
  from langchain.chains import ConversationChain
6
- from langchain.llms import OpenAI
7
  from langchain.prompts import PromptTemplate
 
8
 
9
-
10
- ISSUE_MAPPING = {
11
- "anxiety": "issue_Anxiety",
12
- "suicide": "issue_Suicide",
13
- "safety_planning": "issue_Suicide",
14
- "GCT": "issue_Gral",
15
- }
16
-
17
- EN_TEXTER_TEMPLATE_ = """The following is a conversation between you and a crisis counselor.
18
- {current_issue}
19
- You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
20
- Do not disclose your name unless asked.
21
- Current conversation:
22
- {history}
23
- helper: {input}
24
  texter:"""
25
 
26
- SP_TEXTER_TEMPLATE_ = """La siguiente es una conversacion contigo y un consejero de crisis
27
- {current_issue}
28
- Puedes responder como lo haria tu personaje. Puedes responder como si fueras tu personaje y nada mas. No escribas explicaciones
29
- No reveles tu nombre a menos que te lo pregunten
30
- Conversacion Actual:
31
- {history}
32
- helper: {input}
33
- texter:"""
34
-
35
- CURRENT_ISSUE_MAPPING = {
36
- "issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
37
- "issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
38
- "issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
39
- "issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
40
- "issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
41
- "issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
42
- }
43
-
44
- def get_template_role_models(issue: str, language: str, texter_name: str = "", seed="") -> str:
45
- """_summary_
46
-
47
- Args:
48
- issue (str): Issue for template, current options are ['issue_Suicide','issue_Anxiety']
49
- language (str): Language for the template, current options are ['en','es']
50
- texter_name (str): texter to apply to template, defaults to None
51
-
52
- Returns:
53
- str: template
54
- """
55
- current_issue = CURRENT_ISSUE_MAPPING.get(
56
- f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
57
- )
58
- default_name = get_random_name()
59
- current_issue = current_issue.format(
60
- texter_name=default_name if not texter_name else texter_name,
61
- seed = seed
62
- )
63
-
64
- if language == "en":
65
- template = EN_TEXTER_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
66
- elif language == "es":
67
- template = SP_TEXTER_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
68
 
69
- return template
 
 
70
 
71
  def get_role_chain(template, memory, temperature=0.8):
72
 
 
73
  PROMPT = PromptTemplate(
74
  input_variables=['history', 'input'],
75
  template=template
76
  )
77
- llm = OpenAI(
78
- temperature=temperature,
79
- max_tokens=150,
 
80
  )
81
  llm_chain = ConversationChain(
82
  llm=llm,
83
  prompt=PROMPT,
84
  memory=memory,
85
- output_parser=CustomStringOutputParser()
 
86
  )
87
- logging.debug(f"loaded GPT3.5 model")
88
  return llm_chain, "helper:"
 
1
  import logging
 
2
  from models.custom_parsers import CustomStringOutputParser
 
3
  from langchain.chains import ConversationChain
4
+ from langchain_openai import ChatOpenAI
5
  from langchain.prompts import PromptTemplate
6
+ from models.business_logic_utils.input_processing import initialize_conversation
7
 
8
+ OPENAI_TEMPLATE = """{template}
9
+ {{history}}
10
+ helper: {{input}}
 
 
 
 
 
 
 
 
 
 
 
 
11
  texter:"""
12
 
13
+ def get_template_role_models(issue: str, language: str, texter_name: str = "") -> str:
14
+ model_input = {
15
+ "issue": issue,
16
+ "language": language,
17
+ "texter_name": texter_name,
18
+ "messages": [],
19
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Initialize the conversation (adds the system message)
22
+ model_input = initialize_conversation(model_input, "")
23
+ return model_input["messages"][0]["content"]
24
 
25
  def get_role_chain(template, memory, temperature=0.8):
26
 
27
+ template = OPENAI_TEMPLATE.format(template=template)
28
  PROMPT = PromptTemplate(
29
  input_variables=['history', 'input'],
30
  template=template
31
  )
32
+ llm = ChatOpenAI(
33
+ model="gpt-4o",
34
+ temperature=temperature,
35
+ max_tokens=256,
36
  )
37
  llm_chain = ConversationChain(
38
  llm=llm,
39
  prompt=PROMPT,
40
  memory=memory,
41
+ output_parser=CustomStringOutputParser(),
42
+ verbose=True,
43
  )
44
+ logging.debug(f"loaded GPT4o model")
45
  return llm_chain, "helper:"
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  scipy==1.11.1
2
- openai==1.7.0
3
- langchain==0.1.0
4
  pymongo==4.5.0
5
  mlflow==2.9.0
6
- langchain-community==0.0.11
 
 
1
  scipy==1.11.1
2
+ langchain==0.3.0
 
3
  pymongo==4.5.0
4
  mlflow==2.9.0
5
+ langchain-openai==0.2.0
6
+ streamlit==1.38.0
utils/chain_utils.py CHANGED
@@ -1,36 +1,33 @@
 
1
  from models.model_seeds import seeds
2
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
3
  from models.openai.role_models import get_role_chain, get_template_role_models
4
  from models.databricks.scenario_sim_biz import get_databricks_biz_chain
5
- from models.databricks.scenario_sim import get_databricks_chain, get_template_databricks_models
 
 
6
 
7
  def get_chain(issue, language, source, memory, temperature, texter_name=""):
8
  if source in ("OA_finetuned"):
9
  OA_engine = finetuned_models[f"{issue}-{language}"]
10
  return get_finetuned_chain(OA_engine, memory, temperature)
11
  elif source in ('OA_rolemodel'):
12
- seed = seeds.get(issue, "GCT")['prompt']
13
- template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
14
  return get_role_chain(template, memory, temperature)
15
- elif source in ('CTL_llama2', 'CTL_llama3'):
16
  if language == "English":
17
  language = "en"
18
  elif language == "Spanish":
19
  language = "es"
20
  return get_databricks_biz_chain(source, issue, language, memory, temperature)
21
- elif source in ('CTL_mistral'):
22
  if language == "English":
23
  language = "en"
24
  elif language == "Spanish":
25
  language = "es"
26
- seed = seeds.get(issue, "GCT")['prompt']
27
- template, texter_name = get_template_databricks_models(issue, language, texter_name=texter_name, seed=seed)
28
- return get_databricks_chain(source, template, memory, temperature, texter_name)
29
 
30
- from typing import cast
31
-
32
  def custom_chain_predict(llm_chain, input, stop):
33
-
34
  inputs = llm_chain.prep_inputs({"input":input, "stop":stop})
35
  llm_chain._validate_inputs(inputs)
36
  outputs = llm_chain._call(inputs)
 
1
+ from streamlit.logger import get_logger
2
  from models.model_seeds import seeds
3
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
4
  from models.openai.role_models import get_role_chain, get_template_role_models
5
  from models.databricks.scenario_sim_biz import get_databricks_biz_chain
6
+ from models.databricks.texter_sim_llm import get_databricks_chain
7
+
8
+ logger = get_logger(__name__)
9
 
10
  def get_chain(issue, language, source, memory, temperature, texter_name=""):
11
  if source in ("OA_finetuned"):
12
  OA_engine = finetuned_models[f"{issue}-{language}"]
13
  return get_finetuned_chain(OA_engine, memory, temperature)
14
  elif source in ('OA_rolemodel'):
15
+ template = get_template_role_models(issue, language, texter_name=texter_name)
 
16
  return get_role_chain(template, memory, temperature)
17
+ elif source in ('CTL_llama2'):
18
  if language == "English":
19
  language = "en"
20
  elif language == "Spanish":
21
  language = "es"
22
  return get_databricks_biz_chain(source, issue, language, memory, temperature)
23
+ elif source in ('CTL_llama3'):
24
  if language == "English":
25
  language = "en"
26
  elif language == "Spanish":
27
  language = "es"
28
+ return get_databricks_chain(source, issue, language, memory, temperature, texter_name=texter_name)
 
 
29
 
 
 
30
  def custom_chain_predict(llm_chain, input, stop):
 
31
  inputs = llm_chain.prep_inputs({"input":input, "stop":stop})
32
  llm_chain._validate_inputs(inputs)
33
  outputs = llm_chain._call(inputs)
utils/memory_utils.py CHANGED
@@ -23,7 +23,7 @@ def change_memories(memories, language, changed_source=False):
23
  if (memory not in st.session_state) or changed_source:
24
  source = params['source']
25
  logger.info(f"Source for memory {memory} is {source}")
26
- if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2"):
27
  st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
28
  elif source in ('CTL_mistral'):
29
  st.session_state[memory] = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")
 
23
  if (memory not in st.session_state) or changed_source:
24
  source = params['source']
25
  logger.info(f"Source for memory {memory} is {source}")
26
+ if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2","CTL_llama3"):
27
  st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
28
  elif source in ('CTL_mistral'):
29
  st.session_state[memory] = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")