Spaces:
Sleeping
Sleeping
File size: 4,598 Bytes
f3e0ba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import inspect
import pandas as pd
from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
# Utils to filter convo according to a phase
from .ta_filter_utils import filter_convo
def join_messages(
    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
) -> str:
    """Join a conversation's messages into a single prefixed transcript string.

    Each row becomes one line ``"<prefix>: <message>"`` and lines are joined
    with newlines. Role values "texter" map to ``texter_prefix``; "counselor"
    and "helper" map to ``helper_prefix``; any other role value is used as-is.
    Leading/trailing whitespace is stripped from both prefix and message.

    Rows whose ``actor_role`` or ``message`` is missing (NaN/None) are skipped,
    since they cannot be rendered as a line.

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
                - actor_role
                - message
        texter_prefix (str, optional): prefix used for the texter. Defaults to "texter".
        helper_prefix (str, optional): prefix used for the counselor (helper).
            Defaults to "helper".

    Returns:
        str: the joined transcript; an empty string if no row can be rendered.

    Raises:
        Exception: if the ``actor_role`` or ``message`` column is absent.
    """
    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")
    # A missing value would propagate as NaN through the string concatenation
    # below and make str.join raise a TypeError, so drop such rows up front.
    rows = grp.dropna(subset=["actor_role", "message"])
    roles = rows["actor_role"].replace(
        {"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
    )
    # astype(str) keeps the .str accessor valid for non-object columns, e.g. an
    # empty float column left over after dropping an all-NaN message column.
    lines = (
        roles.astype(str).str.strip()
        + ": "
        + rows["message"].astype(str).str.strip()
    )
    return "\n".join(lines)
def _get_context(grp: pd.DataFrame, **kwargs) -> str:
    """Render one conversation as a prefixed transcript string.

    Only the keyword arguments that ``join_messages`` accepts (e.g.
    ``texter_prefix`` / ``helper_prefix``) are forwarded to it. Every other
    keyword argument is ignored; such arguments were previously routed to a
    now-deprecated context-marker helper.

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
                - actor_role
                - message
        **kwargs: options; the subset matching ``join_messages``'s parameters
            is passed through.

    Returns:
        str: messages joined into a single string, one prefixed line per message.

    Raises:
        Exception: if the ``actor_role`` or ``message`` column is absent.
    """
    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")

    # Introspect the target so new join_messages options are forwarded
    # automatically without touching this function.
    accepted = inspect.signature(join_messages).parameters
    forwarded = {name: value for name, value in kwargs.items() if name in accepted}
    return join_messages(grp, **forwarded)
def load_context(
    messages: pd.DataFrame,
    question: str,
    message_col: str,
    col_type: str,
    inference: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Load and filter conversations to build the context used to answer ``question``.

    The per-question filtering parameters come from ``QUESTION2FILTERARGS``.

    Args:
        messages (pd.DataFrame): messages with ``conversation_id``, ``actor_role``,
            ``message_col`` and phase-prediction columns.
        question (str): question to build context for; must be a key of
            ``QUESTION2FILTERARGS``.
        message_col (str): column holding the message text. Only read on the
            "joined" path; the "individual" path renders the ``message`` column
            through ``_get_context`` (NOTE(review): confirm this is intended).
        col_type (str): layout of ``messages``: "individual" (one row per
            message) or "joined" (conversation text already joined).
        inference (bool, optional): currently unused; kept for backward
            compatibility. Defaults to False.
        **kwargs: currently unused; kept for backward compatibility.

    Raises:
        Exception: if ``question`` is not supported.
        ValueError: if ``col_type`` is neither "individual" nor "joined".

    Returns:
        Context indexed by ``conversation_id``. For "individual", the result of
        a groupby-apply renamed "q_context" (a Series of transcript strings);
        for "joined", a DataFrame with a single "q_context" column.
    """
    if question not in QUESTION2FILTERARGS:
        raise Exception(f"Question {question} not supported")
    # Previously an unknown col_type silently returned the raw messages
    # without any "q_context"; fail loudly instead.
    if col_type not in ("individual", "joined"):
        raise ValueError(
            f"col_type must be 'individual' or 'joined', got {col_type!r}"
        )

    texter_prefix = TEXTER_PREFIX
    helper_prefix = HELPER_PREFIX
    context_data = messages.copy()

    def convo_cpc_get_context(grp, **join_kwargs):
        """Filter convo according to Convo Phase Classifier (CPC) predictions, then render it."""
        context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
        return _get_context(context_, **join_kwargs)

    if col_type == "joined":
        context_data = context_data.groupby("conversation_id")[[message_col]].max()
        return context_data.rename(columns={message_col: "q_context"})

    # col_type == "individual"
    if "actor_role" in context_data:
        context_data = context_data.dropna(subset=["actor_role"])
    if "delete_message" in context_data:
        # Assign back instead of `context_data.delete_message.<op>(inplace=True)`:
        # that is chained assignment, which under pandas Copy-on-Write only
        # modifies a temporary copy and never reaches context_data (pandas 2.2
        # already emits a warning for it).
        context_data["delete_message"] = (
            context_data["delete_message"].replace({1: True}).fillna(False)
        )
    return (
        context_data.groupby("conversation_id")
        .apply(
            convo_cpc_get_context,
            helper_prefix=helper_prefix,
            texter_prefix=texter_prefix,
        )
        .rename("q_context")
    )