File size: 4,598 Bytes
f3e0ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import inspect

import pandas as pd

from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX

# Utils to filter convo according to a phase
from .ta_filter_utils import filter_convo


def join_messages(
    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
) -> str:
    """Join the messages of a conversation into one prefixed transcript string.

    Every row is rendered as ``"<prefix>: <message>"`` (both parts stripped of
    surrounding whitespace) and the rows are joined with newlines, in the row
    order of ``grp``.

    Args:
        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
            Must have the following columns:
            - actor_role
            - message

        texter_prefix (str, optional): prefix to use as the texter. Defaults to "texter".
        helper_prefix (str, optional): prefix to use as the counselor (helper). Defaults to "helper".

    Raises:
        Exception: if the ``actor_role`` or ``message`` column is missing.

    Returns:
        str: joined messages string separated by prefixes
    """

    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")

    # "texter" rows get the texter prefix; both "counselor" and "helper" rows
    # get the helper prefix. Any other role value is kept as-is.
    roles = grp["actor_role"].replace(
        {"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
    )

    # A missing message would make the whole concatenated line NaN, and
    # str.join would then fail with an unrelated TypeError. Render it as an
    # empty message instead so the turn is still present in the transcript.
    texts = grp["message"].fillna("")

    lines = roles.str.strip() + ": " + texts.str.strip()
    return "\n".join(lines)


def _get_context(grp: pd.DataFrame, **kwargs) -> str:
    """Render a conversation as a single context string.

    Checks that the required columns are present, then calls ``join_messages``
    with only those keyword arguments that ``join_messages`` declares in its
    signature. Any other keyword argument is ignored.

    Args:
        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
            Must have the following columns:
            - actor_role
            - message
        **kwargs: optional keyword arguments; only the ones matching a
            ``join_messages`` parameter name are forwarded.

    Raises:
        Exception: if the ``actor_role`` or ``message`` column is missing.

    Returns:
        str: joined messages string separated by prefixes
    """

    required_columns = (
        ("actor_role", "Column 'actor_role' not in DataFrame"),
        ("message", "Column 'message' not in DataFrame"),
    )
    for column, error in required_columns:
        if column not in grp:
            raise Exception(error)

    # Keep only the keyword arguments join_messages can actually accept.
    accepted = inspect.signature(join_messages).parameters
    forwarded = {name: value for name, value in kwargs.items() if name in accepted}

    return join_messages(grp, **forwarded)


def load_context(
    messages: pd.DataFrame,
    question: str,
    message_col: str,
    col_type: str,
    inference: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Load and filter conversation from messages given a question (with configured parameters of what phase that question is answered)

    Args:
        messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role, `message_col` and phase prediction.
            The input is copied; the caller's frame is not modified.
        question (str): Question to get context to. Must be a key of QUESTION2FILTERARGS.
        message_col (str): Column where messages are. Only read when `col_type` is "joined".
        col_type (str): type of message_col, can be "individual" or "joined"
        inference (bool, optional): Currently unused by this function. Defaults to False.
        **kwargs: Currently unused by this function; accepted so existing callers keep working.

    Raises:
        Exception: If question is not supported

    Returns:
        pd.DataFrame: filtered messages according to question configuration.
            - "individual": the result of applying the per-conversation filter + join to each
              `conversation_id` group, renamed to "q_context".
            - "joined": a DataFrame indexed by `conversation_id` with a single "q_context" column
              holding the per-conversation max of `message_col`.
            - any other `col_type`: an unmodified copy of `messages`.
    """

    if question not in QUESTION2FILTERARGS:
        raise Exception(f"Question {question} not supported")

    filter_args = QUESTION2FILTERARGS[question]
    context_data = messages.copy()

    def convo_cpc_get_context(grp, **join_kwargs):
        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
        context_ = filter_convo(grp, **filter_args)
        return _get_context(context_, **join_kwargs)

    if col_type == "individual":
        if "actor_role" in context_data:
            context_data.dropna(subset=["actor_role"], inplace=True)
        if "delete_message" in context_data:
            # Normalise the flag to booleans: 1 -> True, missing -> False.
            # Assign the column back explicitly: an in-place call on
            # `context_data.delete_message` operates on a temporary Series and
            # does not update the frame under pandas Copy-on-Write.
            context_data["delete_message"] = (
                context_data["delete_message"].replace({1: True}).fillna(False)
            )

        context_data = (
            context_data.groupby("conversation_id")
            .apply(
                convo_cpc_get_context,
                helper_prefix=HELPER_PREFIX,
                texter_prefix=TEXTER_PREFIX,
            )
            .rename("q_context")
        )
    elif col_type == "joined":
        # Messages are already joined per conversation; keep one value per
        # conversation_id and expose it under the common "q_context" name.
        context_data = context_data.groupby("conversation_id")[[message_col]].max()
        context_data.rename(columns={message_col: "q_context"}, inplace=True)

    return context_data