File size: 5,311 Bytes
fb096d2
099e99c
71fd9c5
7762f99
2b4b309
fb096d2
 
2b4b309
6521775
2b4b309
 
6fc91c7
fb096d2
2b4b309
fd2f716
2995161
6fc91c7
 
 
 
 
 
2995161
14f85b1
 
 
 
 
fb096d2
14f85b1
fb096d2
14f85b1
fb096d2
14f85b1
 
 
 
fb096d2
 
14f85b1
 
 
 
fb096d2
6fc91c7
 
2995161
14f85b1
099e99c
 
 
6fc91c7
0d28c87
 
 
 
099e99c
6fc91c7
40e000b
 
2995161
ff3c0c2
9ac3da0
ff3c0c2
 
7762f99
 
 
2995161
 
3c2fc33
 
099e99c
fb096d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import json
from typing import List, Optional, Union

import argilla as rg
import gradio as gr
import numpy as np
import pandas as pd
from gradio.oauth import (
    OAuthToken,
    get_space,
)
from huggingface_hub import whoami
from jinja2 import Environment, meta

from synthetic_dataset_generator.constants import argilla_client


def get_duplicate_button():
    """Build a large "Duplicate" button when the app runs inside a Space.

    Returns:
        A ``gr.DuplicateButton`` when ``get_space()`` returns a non-None value,
        otherwise ``None``.
    """
    running_in_space = get_space() is not None
    return gr.DuplicateButton(size="lg") if running_in_space else None


def list_orgs(oauth_token: Union[OAuthToken, None] = None):
    """Return the Hugging Face namespaces associated with ``oauth_token``.

    The namespaces are derived from ``whoami(oauth_token.token)``:

    * ``"oauth"`` auth: the user's own name followed by every org in ``orgs``.
    * ``"access_token"`` auth: every org in ``orgs`` (the user's name is not added).
    * any other auth type (fine-grained tokens): the user's name first, then
      every scoped entity whose permissions include ``"repo.write"``.

    Args:
        oauth_token: The logged-in user's token, or ``None`` when logged out.

    Returns:
        A list of namespace names; ``[]`` when no token is given.

    Raises:
        gr.Error: If the ``whoami`` call fails or its response lacks the
            expected keys. The underlying exception is chained as ``__cause__``.
    """
    if oauth_token is None:
        return []
    try:
        data = whoami(oauth_token.token)
        auth_type = data["auth"]["type"]
        if auth_type == "oauth":
            organizations = [data["name"]] + [org["name"] for org in data["orgs"]]
        elif auth_type == "access_token":
            organizations = [org["name"] for org in data["orgs"]]
        else:
            scoped_writable = [
                entry["entity"]["name"]
                for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
                if "repo.write" in entry["permissions"]
            ]
            # Put the user's own namespace first without duplicating it.
            organizations = [data["name"]] + [
                org for org in scoped_writable if org != data["name"]
            ]
    except Exception as e:
        # Chain the original error so the real cause stays visible in logs.
        raise gr.Error(
            f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
        ) from e
    return organizations


def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
    """Build the "Organization" dropdown for the current user.

    Choices come from ``list_orgs`` when a token is present and are empty
    otherwise. The first choice, if any, is preselected, and custom values
    are allowed.
    """
    orgs = [] if oauth_token is None else list_orgs(oauth_token)
    default_org = orgs[0] if orgs else None
    return gr.Dropdown(
        label="Organization",
        choices=orgs,
        value=default_org,
        allow_custom_value=True,
        interactive=True,
    )


def swap_visibility(oauth_token: Union[OAuthToken, None]):
    """Return a Gradio update that sets the main UI's CSS class for the login state.

    A truthy token selects ``main_ui_logged_in``; anything falsy selects
    ``main_ui_logged_out``.
    """
    css_class = "main_ui_logged_in" if oauth_token else "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])


def get_argilla_client() -> Union[rg.Argilla, None]:
    """Return the shared ``argilla_client`` imported from the constants module.

    Per the annotation, this may be ``None``.
    """
    client = argilla_client
    return client


def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Lower-case and strip each label, then drop duplicates.

    Labels keep the order of their first occurrence, so the result is
    deterministic. A plain ``set`` would yield string-hash order, which
    changes between interpreter runs.

    Args:
        labels: Raw label strings, or ``None``.

    Returns:
        The normalized, de-duplicated labels; ``[]`` for ``None`` or empty input.
    """
    if not labels:
        return []
    return list(dict.fromkeys(label.lower().strip() for label in labels))


def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
    """Return the values of ``column_name`` as a plain Python list.

    Raises:
        ValueError: If ``dataframe`` has no column named ``column_name``.
    """
    if column_name not in dataframe.columns:
        raise ValueError(f"Column '{column_name}' does not exist.")
    column = dataframe[column_name]
    return column.tolist()


def _message_contents(messages, role: str) -> list:
    """Return the ``content`` of every dict in ``messages`` whose ``role`` is ``role``."""
    # Non-dict items are skipped so a list that is not a chat transcript
    # (e.g. a list of plain strings) does not raise AttributeError.
    return [
        message["content"]
        for message in messages
        if isinstance(message, dict) and message.get("role") == role
    ]


def _extract_instruction(value) -> str:
    """Return the instruction text held in a single cell.

    * list / ndarray: the last ``"user"`` message content, or ``""``.
    * str: decoded as JSON. A JSON list yields its last ``"user"`` content
      (or ``""``). Any other JSON value, or text that is not JSON, is
      returned verbatim.
    * anything else (e.g. NaN): ``""``.
    """
    if isinstance(value, (list, np.ndarray)):
        user_contents = _message_contents(value, "user")
        return user_contents[-1] if user_contents else ""
    if isinstance(value, str):
        try:
            parsed_message = json.loads(value)
        except json.JSONDecodeError:
            return value
        if isinstance(parsed_message, list):
            user_contents = _message_contents(parsed_message, "user")
            return user_contents[-1] if user_contents else ""
        # Plain text such as "42" or "null" is valid JSON but not a transcript.
        return value
    return ""


def _extract_generations(value) -> list:
    """Return the generations contributed by a single response cell.

    * list / ndarray of role-tagged dicts: the last ``"assistant"`` content,
      if any.
    * any other list / ndarray: all of its items.
    * str: decoded as JSON. A JSON list yields its last ``"assistant"``
      content (if any). Any other JSON value, or text that is not JSON,
      is returned verbatim as a single generation.
    * anything else: no generations.
    """
    if isinstance(value, (list, np.ndarray)):
        if all(isinstance(item, dict) and "role" in item for item in value):
            return _message_contents(value, "assistant")[-1:]
        return list(value)
    if isinstance(value, str):
        try:
            parsed_message = json.loads(value)
        except json.JSONDecodeError:
            return [value]
        if isinstance(parsed_message, list):
            return _message_contents(parsed_message, "assistant")[-1:]
        return [value]
    return []


def process_columns(
    dataframe,
    instruction_column: str,
    response_columns: Union[str, List[str]],
) -> List[dict]:
    """Convert each dataframe row into an ``{"instruction", "generations"}`` record.

    Cells may hold chat transcripts (lists of ``{"role", "content"}`` dicts,
    either as Python objects or JSON-encoded strings) or plain values.

    Args:
        dataframe: Source table, iterated row by row.
        instruction_column: Column holding the prompt / conversation.
        response_columns: One column name or a list of them. Their generations
            are concatenated in the given order.

    Returns:
        One dict per row: ``{"instruction": str, "generations": list}``.
    """
    if isinstance(response_columns, str):
        response_columns = [response_columns]

    data = []
    for _, row in dataframe.iterrows():
        instruction = _extract_instruction(row[instruction_column])
        generations = []
        for col in response_columns:
            generations.extend(_extract_generations(row[col]))
        data.append({"instruction": instruction, "generations": generations})

    return data


def extract_column_names(prompt_template: str) -> List[str]:
    """Return the undeclared variable names referenced by a Jinja2 template.

    The names are sorted so the result is deterministic.
    ``meta.find_undeclared_variables`` returns a set, whose iteration order
    for strings varies between interpreter runs.

    Args:
        prompt_template: Jinja2 template source.

    Returns:
        Sorted list of variable names the template expects from its context.
    """
    env = Environment()
    parsed_content = env.parse(prompt_template)
    variables = meta.find_undeclared_variables(parsed_content)
    return sorted(variables)


def pad_or_truncate_list(lst, target_length):
    """Return a list of exactly ``target_length`` items derived from ``lst``.

    A longer list is truncated to its LAST ``target_length`` items. A shorter
    list is right-padded with ``None``. ``None`` / empty input counts as ``[]``.
    A non-positive ``target_length`` yields ``[]``.
    """
    lst = lst or []
    # Guard needed: lst[-0:] is the whole list, not an empty one.
    if target_length <= 0:
        return []
    lst_length = len(lst)
    if lst_length >= target_length:
        return lst[-target_length:]
    return lst + [None] * (target_length - lst_length)