|
import os |
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, List, Tuple |
|
|
|
import llm_studio.src.datasets.text_causal_classification_ds |
|
import llm_studio.src.plots.text_causal_classification_modeling_plots |
|
from llm_studio.python_configs.base import DefaultConfig, DefaultConfigProblemBase |
|
from llm_studio.python_configs.text_causal_language_modeling_config import ( |
|
ConfigNLPAugmentation, |
|
ConfigNLPCausalLMArchitecture, |
|
ConfigNLPCausalLMDataset, |
|
ConfigNLPCausalLMEnvironment, |
|
ConfigNLPCausalLMLogging, |
|
ConfigNLPCausalLMTokenizer, |
|
ConfigNLPCausalLMTraining, |
|
) |
|
from llm_studio.src import possible_values |
|
from llm_studio.src.losses import text_causal_classification_modeling_losses |
|
from llm_studio.src.metrics import text_causal_classification_modeling_metrics |
|
from llm_studio.src.models import text_causal_classification_modeling_model |
|
from llm_studio.src.utils.modeling_utils import generate_experiment_name |
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationDataset(ConfigNLPCausalLMDataset):
    """Dataset settings for causal text classification experiments.

    Narrows the causal-LM dataset config: the answer column holds class
    labels, separators are disabled, and several LM-only options are hidden
    from the UI.
    """

    dataset_class: Any = (
        llm_studio.src.datasets.text_causal_classification_ds.CustomDataset
    )
    system_column: str = "None"
    prompt_column: Tuple[str, ...] = ("instruction", "input")
    answer_column: str = "label"
    # number of target classes; 1 selects the binary setup (see check())
    num_classes: int = 1
    parent_id_column: str = "None"

    # no special prompt/answer separators for classification
    text_system_start: str = ""
    text_prompt_start: str = ""
    text_answer_separator: str = ""

    add_eos_token_to_system: bool = False
    add_eos_token_to_prompt: bool = False
    add_eos_token_to_answer: bool = False

    _allowed_file_extensions: Tuple[str, ...] = ("csv", "pq", "parquet")

    def __post_init__(self):
        # Normalize prompt_column to a tuple of column names.
        # BUG FIX: the previous code applied tuple() to the string itself,
        # which exploded a single column name into a tuple of characters
        # (e.g. "input" -> ('i', 'n', 'p', 'u', 't')). A lone string must be
        # wrapped in a one-element tuple instead.
        if isinstance(self.prompt_column, str):
            self.prompt_column = (self.prompt_column,)
        else:
            self.prompt_column = tuple(self.prompt_column)
        super().__post_init__()

        # num_classes UI slider: (min, max, step)
        self._possible_values["num_classes"] = (1, 100, 1)

        # Hide causal-LM-only options that do not apply to classification.
        self._visibility["personalize"] = -1
        self._visibility["chatbot_name"] = -1
        self._visibility["chatbot_author"] = -1
        self._visibility["mask_prompt_labels"] = -1
        self._visibility["add_eos_token_to_answer"] = -1
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationTraining(ConfigNLPCausalLMTraining):
    """Training settings for causal text classification.

    Swaps in the classification loss registry and defaults to a lower
    learning rate with a differential rate on the classification head.
    """

    loss_class: Any = text_causal_classification_modeling_losses.Losses
    loss_function: str = "BinaryCrossEntropyLoss"

    learning_rate: float = 0.0001
    differential_learning_rate_layers: Tuple[str, ...] = ("classification_head",)
    differential_learning_rate: float = 0.00001

    def __post_init__(self):
        super().__post_init__()

        # Offer every loss registered for classification.
        self._possible_values["loss_function"] = self.loss_class.names()

        # Differential-LR layers come from a fixed choice list (no free text).
        layer_choices = possible_values.String(
            values=("backbone", "embed", "classification_head"),
            allow_custom=False,
            placeholder="Select optional layers...",
        )
        self._possible_values["differential_learning_rate_layers"] = layer_choices
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
    """Tokenizer settings for causal text classification."""

    max_length_prompt: int = 512
    max_length: int = 512

    def __post_init__(self):
        super().__post_init__()
        # Answers are class labels, so a separate answer-length limit is
        # meaningless here — hide it from the UI.
        self._visibility.update({"max_length_answer": -1})
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
    """Architecture settings; only swaps in the classification model class."""

    model_class: Any = text_causal_classification_modeling_model.Model

    def __post_init__(self):
        # No extra setup beyond the causal-LM architecture config.
        super().__post_init__()
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationPrediction(DefaultConfig):
    """Inference settings for classification experiments."""

    metric_class: Any = text_causal_classification_modeling_metrics.Metrics
    metric: str = "AUC"
    # NOTE(review): 0 presumably means "fall back to the training batch
    # size" — confirm against the inference runner.
    batch_size_inference: int = 0

    def __post_init__(self):
        super().__post_init__()

        # Expose all registered classification metrics; batch size is a
        # (min, max, step) slider.
        self._possible_values.update(
            {
                "metric": self.metric_class.names(),
                "batch_size_inference": (0, 512, 1),
            }
        )

        # Internal plumbing, never user-facing.
        self._visibility["metric_class"] = -1
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationEnvironment(ConfigNLPCausalLMEnvironment):
    """Environment settings; points at classification-specific card templates."""

    _model_card_template: str = "text_causal_classification_model_card_template.md"
    _summary_card_template: str = (
        "text_causal_classification_experiment_summary_card_template.md"
    )

    def __post_init__(self):
        # Inherits all environment behavior unchanged.
        super().__post_init__()
|
|
|
|
|
@dataclass
class ConfigNLPCausalClassificationLogging(ConfigNLPCausalLMLogging):
    """Logging settings; only swaps in the classification plots class."""

    plots_class: Any = (
        llm_studio.src.plots.text_causal_classification_modeling_plots.Plots
    )
|
) |
|
|
|
|
|
@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
    """Top-level configuration for the causal text classification problem.

    Bundles the per-aspect sub-configs (dataset, tokenizer, architecture,
    training, augmentation, prediction, environment, logging) and validates
    cross-field consistency in :meth:`check`.
    """

    output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
    experiment_name: str = field(default_factory=generate_experiment_name)
    _parent_experiment: str = ""
    llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b"

    dataset: ConfigNLPCausalClassificationDataset = field(
        default_factory=ConfigNLPCausalClassificationDataset
    )
    # BUG FIX: use the classification-specific tokenizer config. Previously
    # this field used the plain causal-LM tokenizer, leaving
    # ConfigNLPCausalClassificationTokenizer (and its max_length defaults /
    # visibility tweaks) defined in this file but never applied.
    tokenizer: ConfigNLPCausalClassificationTokenizer = field(
        default_factory=ConfigNLPCausalClassificationTokenizer
    )
    architecture: ConfigNLPCausalClassificationArchitecture = field(
        default_factory=ConfigNLPCausalClassificationArchitecture
    )
    training: ConfigNLPCausalClassificationTraining = field(
        default_factory=ConfigNLPCausalClassificationTraining
    )
    augmentation: ConfigNLPAugmentation = field(default_factory=ConfigNLPAugmentation)
    prediction: ConfigNLPCausalClassificationPrediction = field(
        default_factory=ConfigNLPCausalClassificationPrediction
    )
    environment: ConfigNLPCausalClassificationEnvironment = field(
        default_factory=ConfigNLPCausalClassificationEnvironment
    )
    logging: ConfigNLPCausalClassificationLogging = field(
        default_factory=ConfigNLPCausalClassificationLogging
    )

    def __post_init__(self):
        super().__post_init__()

        # The output directory is derived from the module name; not editable.
        self._visibility["output_directory"] = -1

        # Curated backbone suggestions; custom model names are allowed.
        self._possible_values["llm_backbone"] = possible_values.String(
            values=(
                "h2oai/h2o-danube2-1.8b-base",
                "h2oai/h2o-danube2-1.8b-chat",
                "h2oai/h2ogpt-4096-llama2-7b",
                "h2oai/h2ogpt-4096-llama2-7b-chat",
                "h2oai/h2ogpt-4096-llama2-13b",
                "h2oai/h2ogpt-4096-llama2-13b-chat",
                "h2oai/h2ogpt-4096-llama2-70b",
                "h2oai/h2ogpt-4096-llama2-70b-chat",
                "tiiuae/falcon-7b",
                "mistralai/Mistral-7B-v0.1",
                "HuggingFaceH4/zephyr-7b-beta",
                "google/gemma-2b",
                "google/gemma-7b",
                "stabilityai/stablelm-3b-4e1t",
                "microsoft/phi-2",
                "facebook/opt-125m",
            ),
            allow_custom=True,
        )

    def check(self) -> Dict[str, List]:
        """Validate cross-field configuration consistency.

        Returns:
            Dict with parallel "title" and "message" lists, one entry per
            detected problem; both lists are empty when the config is valid.
        """
        errors: Dict[str, List] = {"title": [], "message": []}

        # Loss function and class count must agree.
        if self.training.loss_function == "CrossEntropyLoss":
            if self.dataset.num_classes == 1:
                errors["title"] += ["CrossEntropyLoss requires num_classes > 1"]
                errors["message"] += [
                    "CrossEntropyLoss requires num_classes > 1, "
                    "but num_classes is set to 1."
                ]
        elif self.training.loss_function == "BinaryCrossEntropyLoss":
            if self.dataset.num_classes != 1:
                errors["title"] += ["BinaryCrossEntropyLoss requires num_classes == 1"]
                errors["message"] += [
                    "BinaryCrossEntropyLoss requires num_classes == 1, "
                    "but num_classes is set to {}.".format(self.dataset.num_classes)
                ]
        # Conversation chaining has no meaning for classification datasets.
        if self.dataset.parent_id_column not in ["None", None]:
            errors["title"] += ["Parent ID column is not supported for classification"]
            errors["message"] += [
                "Parent ID column is not supported for classification datasets."
            ]

        return errors
|
|