# H2OTest/llm_studio/python_configs/text_causal_classification_modeling_config.py
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple
import llm_studio.src.datasets.text_causal_classification_ds
import llm_studio.src.plots.text_causal_classification_modeling_plots
from llm_studio.python_configs.base import DefaultConfig, DefaultConfigProblemBase
from llm_studio.python_configs.text_causal_language_modeling_config import (
ConfigNLPAugmentation,
ConfigNLPCausalLMArchitecture,
ConfigNLPCausalLMDataset,
ConfigNLPCausalLMEnvironment,
ConfigNLPCausalLMLogging,
ConfigNLPCausalLMTokenizer,
ConfigNLPCausalLMTraining,
)
from llm_studio.src import possible_values
from llm_studio.src.losses import text_causal_classification_modeling_losses
from llm_studio.src.metrics import text_causal_classification_modeling_metrics
from llm_studio.src.models import text_causal_classification_modeling_model
from llm_studio.src.utils.modeling_utils import generate_experiment_name
@dataclass
class ConfigNLPCausalClassificationDataset(ConfigNLPCausalLMDataset):
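    """Dataset settings for causal text classification.

    Reuses the causal LM dataset config, but points at the classification
    dataset class and expects the answer column to hold class labels.
    """
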
dataset_class: Any = (
llm_studio.src.datasets.text_causal_classification_ds.CustomDataset
)
system_column: str = "None"
prompt_column: Tuple[str, ...] = ("instruction", "input")
answer_column: str = "label"
num_classes: int = 1
parent_id_column: str = "None"
text_system_start: str = ""
text_prompt_start: str = ""
text_answer_separator: str = ""
add_eos_token_to_system: bool = False
add_eos_token_to_prompt: bool = False
add_eos_token_to_answer: bool = False
_allowed_file_extensions: Tuple[str, ...] = ("csv", "pq", "parquet")
    def __post_init__(self):
        # Normalize prompt_column to a tuple; wrap a single string in a
        # one-element tuple so it is not split into individual characters.
        self.prompt_column = (
            (self.prompt_column,)
            if isinstance(self.prompt_column, str)
            else tuple(self.prompt_column)
        )
super().__post_init__()
self._possible_values["num_classes"] = (1, 100, 1)
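        # Hide options that do not apply to classification
        # (chat personalization, prompt-label masking, answer EOS token).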
self._visibility["personalize"] = -1
self._visibility["chatbot_name"] = -1
self._visibility["chatbot_author"] = -1
self._visibility["mask_prompt_labels"] = -1
self._visibility["add_eos_token_to_answer"] = -1
@dataclass
class ConfigNLPCausalClassificationTraining(ConfigNLPCausalLMTraining):
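    """Training settings for classification.

    Defaults to BinaryCrossEntropyLoss and applies a differential learning
    rate to the classification head.
    """
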
loss_class: Any = text_causal_classification_modeling_losses.Losses
loss_function: str = "BinaryCrossEntropyLoss"
learning_rate: float = 0.0001
differential_learning_rate_layers: Tuple[str, ...] = ("classification_head",)
differential_learning_rate: float = 0.00001
def __post_init__(self):
super().__post_init__()
self._possible_values["loss_function"] = self.loss_class.names()
self._possible_values["differential_learning_rate_layers"] = (
possible_values.String(
values=("backbone", "embed", "classification_head"),
allow_custom=False,
placeholder="Select optional layers...",
)
)
@dataclass
class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
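    """Tokenizer settings for classification; the answer max-length option is hidden."""
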
max_length_prompt: int = 512
max_length: int = 512
def __post_init__(self):
super().__post_init__()
self._visibility["max_length_answer"] = -1
@dataclass
class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
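    """Architecture config pointing at the classification model class."""
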
model_class: Any = text_causal_classification_modeling_model.Model
def __post_init__(self):
super().__post_init__()
@dataclass
class ConfigNLPCausalClassificationPrediction(DefaultConfig):
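    """Inference settings; metric choices come from the classification metrics module."""
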
metric_class: Any = text_causal_classification_modeling_metrics.Metrics
metric: str = "AUC"
batch_size_inference: int = 0
def __post_init__(self):
super().__post_init__()
self._possible_values["metric"] = self.metric_class.names()
self._possible_values["batch_size_inference"] = (0, 512, 1)
self._visibility["metric_class"] = -1
@dataclass
class ConfigNLPCausalClassificationEnvironment(ConfigNLPCausalLMEnvironment):
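    """Environment settings with classification-specific model and summary card templates."""
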
_model_card_template: str = "text_causal_classification_model_card_template.md"
_summary_card_template: str = (
"text_causal_classification_experiment_summary_card_template.md"
)
def __post_init__(self):
super().__post_init__()
@dataclass
class ConfigNLPCausalClassificationLogging(ConfigNLPCausalLMLogging):
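    """Logging config using the classification-specific plots class."""
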
plots_class: Any = (
llm_studio.src.plots.text_causal_classification_modeling_plots.Plots
)
@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
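    """Top-level problem config for causal text classification.

    Wires together dataset, tokenizer, architecture, training, augmentation,
    prediction, environment, and logging configs.
    """
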
output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
experiment_name: str = field(default_factory=generate_experiment_name)
_parent_experiment: str = ""
llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b"
dataset: ConfigNLPCausalClassificationDataset = field(
default_factory=ConfigNLPCausalClassificationDataset
)
    tokenizer: ConfigNLPCausalClassificationTokenizer = field(
        default_factory=ConfigNLPCausalClassificationTokenizer
    )
architecture: ConfigNLPCausalClassificationArchitecture = field(
default_factory=ConfigNLPCausalClassificationArchitecture
)
training: ConfigNLPCausalClassificationTraining = field(
default_factory=ConfigNLPCausalClassificationTraining
)
augmentation: ConfigNLPAugmentation = field(default_factory=ConfigNLPAugmentation)
prediction: ConfigNLPCausalClassificationPrediction = field(
default_factory=ConfigNLPCausalClassificationPrediction
)
environment: ConfigNLPCausalClassificationEnvironment = field(
default_factory=ConfigNLPCausalClassificationEnvironment
)
logging: ConfigNLPCausalClassificationLogging = field(
default_factory=ConfigNLPCausalClassificationLogging
)
def __post_init__(self):
super().__post_init__()
self._visibility["output_directory"] = -1
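        # Backbone choices offered by default; allow_custom permits any
        # Hugging Face model name.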
self._possible_values["llm_backbone"] = possible_values.String(
values=(
"h2oai/h2o-danube2-1.8b-base",
"h2oai/h2o-danube2-1.8b-chat",
"h2oai/h2ogpt-4096-llama2-7b",
"h2oai/h2ogpt-4096-llama2-7b-chat",
"h2oai/h2ogpt-4096-llama2-13b",
"h2oai/h2ogpt-4096-llama2-13b-chat",
"h2oai/h2ogpt-4096-llama2-70b",
"h2oai/h2ogpt-4096-llama2-70b-chat",
"tiiuae/falcon-7b",
"mistralai/Mistral-7B-v0.1",
"HuggingFaceH4/zephyr-7b-beta",
"google/gemma-2b",
"google/gemma-7b",
"stabilityai/stablelm-3b-4e1t",
"microsoft/phi-2",
"facebook/opt-125m",
),
allow_custom=True,
)
def check(self) -> Dict[str, List]:
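        """Cross-field validation; returns parallel 'title' and 'message' error lists."""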
errors: Dict[str, List] = {"title": [], "message": []}
if self.training.loss_function == "CrossEntropyLoss":
if self.dataset.num_classes == 1:
errors["title"] += ["CrossEntropyLoss requires num_classes > 1"]
errors["message"] += [
"CrossEntropyLoss requires num_classes > 1, "
"but num_classes is set to 1."
]
elif self.training.loss_function == "BinaryCrossEntropyLoss":
if self.dataset.num_classes != 1:
errors["title"] += ["BinaryCrossEntropyLoss requires num_classes == 1"]
errors["message"] += [
"BinaryCrossEntropyLoss requires num_classes == 1, "
"but num_classes is set to {}.".format(self.dataset.num_classes)
]
if self.dataset.parent_id_column not in ["None", None]:
errors["title"] += ["Parent ID column is not supported for classification"]
errors["message"] += [
"Parent ID column is not supported for classification datasets."
]
return errors