File size: 5,311 Bytes
fb096d2 099e99c 71fd9c5 7762f99 2b4b309 fb096d2 2b4b309 6521775 2b4b309 6fc91c7 fb096d2 2b4b309 fd2f716 2995161 6fc91c7 2995161 14f85b1 fb096d2 14f85b1 fb096d2 14f85b1 fb096d2 14f85b1 fb096d2 14f85b1 fb096d2 6fc91c7 2995161 14f85b1 099e99c 6fc91c7 0d28c87 099e99c 6fc91c7 40e000b 2995161 ff3c0c2 9ac3da0 ff3c0c2 7762f99 2995161 3c2fc33 099e99c fb096d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import json
from typing import List, Optional, Union
import argilla as rg
import gradio as gr
import numpy as np
import pandas as pd
from gradio.oauth import (
OAuthToken,
get_space,
)
from huggingface_hub import whoami
from jinja2 import Environment, meta
from synthetic_dataset_generator.constants import argilla_client
def get_duplicate_button():
    """Return a large ``gr.DuplicateButton`` when ``get_space()`` reports a Space.

    When ``get_space()`` returns ``None`` no button is created and the function
    returns ``None``.
    """
    if get_space() is None:
        return None
    return gr.DuplicateButton(size="lg")
def list_orgs(oauth_token: Union[OAuthToken, None] = None):
    """Return the Hub namespaces offered to the logged-in user.

    The user's own namespace (``whoami()["name"]``) is always first, followed
    by the candidate namespaces for the token's auth type, with any repeat of
    the user's own name removed:

    - ``"oauth"`` / ``"access_token"``: every organization in ``data["orgs"]``.
    - any other auth type: entities listed under
      ``data["auth"]["accessToken"]["fineGrained"]["scoped"]`` whose
      ``permissions`` include ``"repo.write"``.

    Args:
        oauth_token: Token obtained from the Gradio OAuth login, or ``None``
            when the user is not logged in.

    Returns:
        A list of namespace names; ``[]`` when ``oauth_token`` is ``None``.

    Raises:
        gr.Error: If the ``whoami`` call fails or its response is missing an
            expected key. The underlying exception is chained as the cause.
    """
    if oauth_token is None:
        return []
    try:
        data = whoami(oauth_token.token)
        auth_type = data["auth"]["type"]
        if auth_type in ("oauth", "access_token"):
            candidates = [org["name"] for org in data["orgs"]]
        else:
            candidates = [
                entry["entity"]["name"]
                for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
                if "repo.write" in entry["permissions"]
            ]
        user_name = data["name"]
    except Exception as e:
        # Surface a user-facing Gradio error, but keep the real cause in the traceback.
        raise gr.Error(
            f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
        ) from e
    # Same normalization for every auth type: own namespace first, no repeat of it.
    return [user_name] + [org for org in candidates if org != user_name]
def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
    """Build the "Organization" dropdown for the current login state.

    Choices come from ``list_orgs`` when a token is given and are empty
    otherwise. The first choice, if any, is preselected. The dropdown is
    interactive and created with ``allow_custom_value=True``.
    """
    choices = [] if oauth_token is None else list_orgs(oauth_token)
    default_choice = next(iter(choices), None)
    return gr.Dropdown(
        label="Organization",
        choices=choices,
        value=default_choice,
        allow_custom_value=True,
        interactive=True,
    )
def swap_visibility(oauth_token: Union[OAuthToken, None]):
    """Return a ``gr.update`` setting ``elem_classes`` from the login state.

    A truthy ``oauth_token`` selects ``main_ui_logged_in``; a falsy one
    (e.g. ``None``) selects ``main_ui_logged_out``.
    """
    css_class = "main_ui_logged_in" if oauth_token else "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])
def get_argilla_client() -> Union[rg.Argilla, None]:
    """Return the module-level ``argilla_client`` from ``synthetic_dataset_generator.constants``.

    Per the return annotation the value may be ``None``; callers should check.
    """
    client = argilla_client
    return client
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Normalize labels: lowercase, strip surrounding whitespace, drop duplicates.

    The first-seen order of the normalized labels is preserved. A plain
    ``set`` would give an order that changes between interpreter runs,
    because string hashing is randomized.

    Args:
        labels: Raw label strings, or ``None``.

    Returns:
        Unique normalized labels in first-seen order; ``[]`` for ``None`` or
        an empty list.
    """
    if not labels:
        return []
    # dict preserves insertion order, so fromkeys de-duplicates deterministically.
    return list(dict.fromkeys(label.lower().strip() for label in labels))
def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
    """Return the values of ``dataframe[column_name]`` as a Python list.

    Raises:
        ValueError: If ``column_name`` is not one of ``dataframe.columns``.
    """
    if column_name not in dataframe.columns:
        raise ValueError(f"Column '{column_name}' does not exist.")
    series = dataframe[column_name]
    return series.tolist()
def _role_contents(messages, role):
    """Return the ``content`` of every dict in *messages* whose ``role`` is *role*.

    Non-dict items are skipped rather than raising ``AttributeError``.
    """
    return [
        message["content"]
        for message in messages
        if isinstance(message, dict) and message.get("role") == role
    ]


def _parse_json_messages(value: str):
    """Decode *value* as a JSON list of dicts; return ``None`` for plain text.

    ``None`` is returned when the string is not valid JSON, or when it decodes
    to anything other than a list of dicts. Examples are ``"42"``, ``"true"``,
    ``'"hello"'`` and ``'["a", "b"]'``. Such values cannot be iterated as chat
    messages, so the caller treats the original string as plain text.
    """
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
        return parsed
    return None


def _extract_instruction(value):
    """Return the instruction for one cell of the instruction column.

    - list / ndarray: content of the last ``"user"`` message, else ``""``.
    - str holding a JSON message list: same rule as above.
    - any other str: the string itself.
    - anything else (e.g. ``None``): ``""``.
    """
    if isinstance(value, (list, np.ndarray)):
        contents = _role_contents(value, "user")
        return contents[-1] if contents else ""
    if isinstance(value, str):
        messages = _parse_json_messages(value)
        if messages is None:
            return value
        contents = _role_contents(messages, "user")
        return contents[-1] if contents else ""
    return ""


def _extract_generations(value) -> list:
    """Return the generations contributed by one cell of a response column.

    - list / ndarray of dicts that all carry ``"role"``: the content of the
      last ``"assistant"`` message (nothing if there is none).
    - any other list / ndarray: all of its items.
    - str holding a JSON message list: last ``"assistant"`` content, if any.
    - any other str: the string itself.
    - anything else (e.g. ``None``): nothing.
    """
    if isinstance(value, (list, np.ndarray)):
        if all(isinstance(item, dict) and "role" in item for item in value):
            # [-1:] yields [last] or [] without a separate emptiness check.
            return _role_contents(value, "assistant")[-1:]
        return list(value)
    if isinstance(value, str):
        messages = _parse_json_messages(value)
        if messages is None:
            return [value]
        return _role_contents(messages, "assistant")[-1:]
    return []


def process_columns(
    dataframe,
    instruction_column: str,
    response_columns: Union[str, List[str]],
) -> List[dict]:
    """Turn dataframe rows into ``{"instruction", "generations"}`` records.

    Each cell may hold a list/ndarray of chat messages (dicts with ``role`` /
    ``content``), a JSON string encoding such a list, or plain text. See
    ``_extract_instruction`` and ``_extract_generations`` for the per-cell
    rules.

    Args:
        dataframe: Source table; must contain ``instruction_column`` and every
            response column.
        instruction_column: Column the instruction is read from.
        response_columns: One column name or a list of names. Generations from
            every listed column are concatenated in the given order.

    Returns:
        One dict per row: ``{"instruction": ..., "generations": [...]}``.
    """
    if isinstance(response_columns, str):
        response_columns = [response_columns]
    data = []
    for _, row in dataframe.iterrows():
        instruction = _extract_instruction(row[instruction_column])
        generations = []
        for col in response_columns:
            generations.extend(_extract_generations(row[col]))
        data.append({"instruction": instruction, "generations": generations})
    return data
def extract_column_names(prompt_template: str) -> List[str]:
    """Return the undeclared variable names referenced by a Jinja2 template.

    ``meta.find_undeclared_variables`` returns a set. The iteration order of a
    set of strings changes between interpreter runs, so the names are sorted
    to give callers a stable order.

    Args:
        prompt_template: Jinja2 template source.

    Returns:
        Sorted list of variable names the template looks up from its context.
    """
    env = Environment()
    parsed_content = env.parse(prompt_template)
    variables = meta.find_undeclared_variables(parsed_content)
    return sorted(variables)
def pad_or_truncate_list(lst, target_length):
    """Resize *lst* to exactly *target_length* items.

    - Longer input keeps its **last** ``target_length`` items.
    - Shorter input is right-padded with ``None``.
    - ``None`` or empty input is treated as ``[]``.
    - A non-positive ``target_length`` yields ``[]``.

    Args:
        lst: Sequence to resize, or ``None``.
        target_length: Desired number of items.

    Returns:
        A new sequence of length ``max(target_length, 0)``.
    """
    lst = lst or []
    if target_length <= 0:
        # lst[-0:] is the whole list, so zero (and negatives) must be special-cased.
        return []
    if len(lst) >= target_length:
        return lst[-target_length:]
    return lst + [None] * (target_length - len(lst))
|