Fix training router to accept form column_mapping and prompt_template format
Browse files- app/routers/training.py +174 -694
app/routers/training.py
CHANGED
|
@@ -23,519 +23,50 @@ router = APIRouter(prefix="/api/training", tags=["Training"])
|
|
| 23 |
|
| 24 |
|
| 25 |
# ============================================
|
| 26 |
-
#
|
| 27 |
# ============================================
|
| 28 |
|
| 29 |
-
class
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
)
|
| 36 |
-
columns: List[str] = Field(
|
| 37 |
-
default_factory=list,
|
| 38 |
-
description="List of dataset columns used in this section"
|
| 39 |
-
)
|
| 40 |
-
prefix: str = Field(default="", description="Prefix before content (e.g., 'System: ')")
|
| 41 |
-
suffix: str = Field(default="", description="Suffix after content (e.g., '\n\n')")
|
| 42 |
-
strip_whitespace: bool = Field(default=True)
|
| 43 |
-
required: bool = Field(default=False, description="Raise error if columns missing")
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class PromptTemplateConfig(BaseModel):
|
| 47 |
-
"""Full prompt template configuration for training data formatting"""
|
| 48 |
-
# Chat-style formatting
|
| 49 |
-
use_chat_format: bool = Field(default=True, description="Use chat-style message format")
|
| 50 |
-
chat_template: Optional[str] = Field(
|
| 51 |
-
default=None,
|
| 52 |
-
description="Jinja2 chat template (auto-detect from tokenizer if None)"
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
# Message sections
|
| 56 |
-
system: Optional[PromptSectionConfig] = Field(
|
| 57 |
-
default=None,
|
| 58 |
-
description="System message configuration"
|
| 59 |
-
)
|
| 60 |
-
user: Optional[PromptSectionConfig] = Field(
|
| 61 |
-
default=None,
|
| 62 |
-
description="User message configuration"
|
| 63 |
-
)
|
| 64 |
-
context: Optional[PromptSectionConfig] = Field(
|
| 65 |
-
default=None,
|
| 66 |
-
description="Context/passages configuration"
|
| 67 |
-
)
|
| 68 |
-
reasoning: Optional[PromptSectionConfig] = Field(
|
| 69 |
-
default=None,
|
| 70 |
-
description="Reasoning/chain-of-thought configuration"
|
| 71 |
-
)
|
| 72 |
-
assistant: Optional[PromptSectionConfig] = Field(
|
| 73 |
-
default=None,
|
| 74 |
-
description="Assistant response configuration (target for training)"
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
# Custom sections for flexibility
|
| 78 |
-
custom_sections: Dict[str, PromptSectionConfig] = Field(
|
| 79 |
-
default_factory=dict,
|
| 80 |
-
description="Additional custom sections"
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
# Section ordering
|
| 84 |
-
section_order: List[str] = Field(
|
| 85 |
-
default=["system", "context", "user", "reasoning", "assistant"],
|
| 86 |
-
description="Order of sections in the prompt"
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# Special tokens
|
| 90 |
-
bos_token: Optional[str] = Field(default=None, description="Beginning of sequence token")
|
| 91 |
-
eos_token: Optional[str] = Field(default=None, description="End of sequence token")
|
| 92 |
-
pad_token: Optional[str] = Field(default=None, description="Padding token")
|
| 93 |
-
|
| 94 |
-
# Separator configuration
|
| 95 |
-
section_separator: str = Field(default="\n\n", description="Separator between sections")
|
| 96 |
-
message_separator: str = Field(default="\n", description="Separator between messages")
|
| 97 |
-
|
| 98 |
-
# Instruction format (for instruction-tuned models)
|
| 99 |
-
instruction_format: str = Field(
|
| 100 |
-
default="none",
|
| 101 |
-
description="Preset format: none, alpaca, chatml, llama3, mistral, vicuna, phi3"
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
def get_template_for_format(self, format_name: str) -> Dict[str, Any]:
|
| 105 |
-
"""Get preset template configuration for known formats"""
|
| 106 |
-
presets = {
|
| 107 |
-
"none": {},
|
| 108 |
-
"alpaca": {
|
| 109 |
-
"system": PromptSectionConfig(
|
| 110 |
-
enabled=True,
|
| 111 |
-
template="Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
|
| 112 |
-
prefix="", suffix="\n\n"
|
| 113 |
-
),
|
| 114 |
-
"user": PromptSectionConfig(
|
| 115 |
-
enabled=True,
|
| 116 |
-
template="### Instruction:\n{instruction}\n\n### Input:\n{input}",
|
| 117 |
-
columns=["instruction", "input"],
|
| 118 |
-
prefix="", suffix="\n\n"
|
| 119 |
-
),
|
| 120 |
-
"assistant": PromptSectionConfig(
|
| 121 |
-
enabled=True,
|
| 122 |
-
template="### Response:\n{output}",
|
| 123 |
-
columns=["output"],
|
| 124 |
-
prefix="", suffix=""
|
| 125 |
-
),
|
| 126 |
-
"section_order": ["system", "user", "assistant"]
|
| 127 |
-
},
|
| 128 |
-
"chatml": {
|
| 129 |
-
"system": PromptSectionConfig(
|
| 130 |
-
enabled=True,
|
| 131 |
-
template="{system_message}",
|
| 132 |
-
columns=["system_message"],
|
| 133 |
-
prefix="<|im_start|>system\n", suffix="<|im_end|>\n"
|
| 134 |
-
),
|
| 135 |
-
"user": PromptSectionConfig(
|
| 136 |
-
enabled=True,
|
| 137 |
-
template="{user_message}",
|
| 138 |
-
columns=["user_message"],
|
| 139 |
-
prefix="<|im_start|>user\n", suffix="<|im_end|>\n"
|
| 140 |
-
),
|
| 141 |
-
"assistant": PromptSectionConfig(
|
| 142 |
-
enabled=True,
|
| 143 |
-
template="{assistant_message}",
|
| 144 |
-
columns=["assistant_message", "output", "response"],
|
| 145 |
-
prefix="<|im_start|>assistant\n", suffix="<|im_end|>"
|
| 146 |
-
),
|
| 147 |
-
"section_order": ["system", "user", "assistant"]
|
| 148 |
-
},
|
| 149 |
-
"llama3": {
|
| 150 |
-
"system": PromptSectionConfig(
|
| 151 |
-
enabled=True,
|
| 152 |
-
template="{system_message}",
|
| 153 |
-
columns=["system_message"],
|
| 154 |
-
prefix="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n",
|
| 155 |
-
suffix="<|eot_id|>"
|
| 156 |
-
),
|
| 157 |
-
"user": PromptSectionConfig(
|
| 158 |
-
enabled=True,
|
| 159 |
-
template="{user_message}",
|
| 160 |
-
columns=["user_message", "question"],
|
| 161 |
-
prefix="<|start_header_id|>user<|end_header_id|>\n\n",
|
| 162 |
-
suffix="<|eot_id|>"
|
| 163 |
-
),
|
| 164 |
-
"assistant": PromptSectionConfig(
|
| 165 |
-
enabled=True,
|
| 166 |
-
template="{assistant_message}",
|
| 167 |
-
columns=["assistant_message", "output", "response"],
|
| 168 |
-
prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
| 169 |
-
suffix="<|eot_id|>"
|
| 170 |
-
),
|
| 171 |
-
"section_order": ["system", "user", "assistant"]
|
| 172 |
-
},
|
| 173 |
-
"mistral": {
|
| 174 |
-
"system": PromptSectionConfig(
|
| 175 |
-
enabled=True,
|
| 176 |
-
template="{system_message}",
|
| 177 |
-
columns=["system_message"],
|
| 178 |
-
prefix="[INST] ", suffix=" "
|
| 179 |
-
),
|
| 180 |
-
"user": PromptSectionConfig(
|
| 181 |
-
enabled=True,
|
| 182 |
-
template="{user_message}",
|
| 183 |
-
columns=["user_message", "question"],
|
| 184 |
-
prefix="", suffix=" [/INST]"
|
| 185 |
-
),
|
| 186 |
-
"assistant": PromptSectionConfig(
|
| 187 |
-
enabled=True,
|
| 188 |
-
template="{assistant_message}",
|
| 189 |
-
columns=["assistant_message", "output"],
|
| 190 |
-
prefix="", suffix="</s>"
|
| 191 |
-
),
|
| 192 |
-
"section_order": ["system", "user", "assistant"]
|
| 193 |
-
},
|
| 194 |
-
"vicuna": {
|
| 195 |
-
"system": PromptSectionConfig(
|
| 196 |
-
enabled=True,
|
| 197 |
-
template="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 198 |
-
prefix="", suffix=" "
|
| 199 |
-
),
|
| 200 |
-
"user": PromptSectionConfig(
|
| 201 |
-
enabled=True,
|
| 202 |
-
template="{user_message}",
|
| 203 |
-
columns=["user_message", "question"],
|
| 204 |
-
prefix="USER: ", suffix=" "
|
| 205 |
-
),
|
| 206 |
-
"assistant": PromptSectionConfig(
|
| 207 |
-
enabled=True,
|
| 208 |
-
template="{assistant_message}",
|
| 209 |
-
columns=["assistant_message", "output"],
|
| 210 |
-
prefix="ASSISTANT: ", suffix=""
|
| 211 |
-
),
|
| 212 |
-
"section_order": ["system", "user", "assistant"]
|
| 213 |
-
},
|
| 214 |
-
"phi3": {
|
| 215 |
-
"system": PromptSectionConfig(
|
| 216 |
-
enabled=True,
|
| 217 |
-
template="{system_message}",
|
| 218 |
-
columns=["system_message"],
|
| 219 |
-
prefix="<|system|>\n", suffix="<|end|>\n"
|
| 220 |
-
),
|
| 221 |
-
"user": PromptSectionConfig(
|
| 222 |
-
enabled=True,
|
| 223 |
-
template="{user_message}",
|
| 224 |
-
columns=["user_message", "question"],
|
| 225 |
-
prefix="<|user|>\n", suffix="<|end|>\n"
|
| 226 |
-
),
|
| 227 |
-
"assistant": PromptSectionConfig(
|
| 228 |
-
enabled=True,
|
| 229 |
-
template="{assistant_message}",
|
| 230 |
-
columns=["assistant_message", "output"],
|
| 231 |
-
prefix="<|assistant|>\n", suffix="<|end|>"
|
| 232 |
-
),
|
| 233 |
-
"section_order": ["system", "user", "assistant"]
|
| 234 |
-
},
|
| 235 |
-
"reasoning": {
|
| 236 |
-
"system": PromptSectionConfig(
|
| 237 |
-
enabled=True,
|
| 238 |
-
template="You are a helpful AI assistant that thinks step by step before responding.",
|
| 239 |
-
prefix="", suffix="\n\n"
|
| 240 |
-
),
|
| 241 |
-
"context": PromptSectionConfig(
|
| 242 |
-
enabled=True,
|
| 243 |
-
template="{context}",
|
| 244 |
-
columns=["context", "passages", "background"],
|
| 245 |
-
prefix="Context:\n", suffix="\n\n"
|
| 246 |
-
),
|
| 247 |
-
"user": PromptSectionConfig(
|
| 248 |
-
enabled=True,
|
| 249 |
-
template="{question}",
|
| 250 |
-
columns=["question", "query", "user_message"],
|
| 251 |
-
prefix="Question: ", suffix="\n\n"
|
| 252 |
-
),
|
| 253 |
-
"reasoning": PromptSectionConfig(
|
| 254 |
-
enabled=True,
|
| 255 |
-
template="{reasoning}",
|
| 256 |
-
columns=["reasoning", "thinking", "chain_of_thought"],
|
| 257 |
-
prefix="Reasoning:\n", suffix="\n\n"
|
| 258 |
-
),
|
| 259 |
-
"assistant": PromptSectionConfig(
|
| 260 |
-
enabled=True,
|
| 261 |
-
template="{answer}",
|
| 262 |
-
columns=["answer", "output", "response"],
|
| 263 |
-
prefix="Answer: ", suffix=""
|
| 264 |
-
),
|
| 265 |
-
"section_order": ["system", "context", "user", "reasoning", "assistant"]
|
| 266 |
-
}
|
| 267 |
-
}
|
| 268 |
-
return presets.get(format_name, {})
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
# ============================================
|
| 272 |
-
# COLUMN MAPPING CONFIGURATION
|
| 273 |
-
# ============================================
|
| 274 |
-
|
| 275 |
-
class ColumnMappingConfig(BaseModel):
|
| 276 |
-
"""Maps dataset columns to training roles"""
|
| 277 |
-
# Primary text columns
|
| 278 |
-
text_column: Optional[str] = Field(None, description="Main text column (for causal LM)")
|
| 279 |
-
input_column: Optional[str] = Field(None, description="Input text column")
|
| 280 |
-
output_column: Optional[str] = Field(None, description="Output/target column")
|
| 281 |
-
|
| 282 |
-
# Multi-column support
|
| 283 |
-
instruction_column: Optional[str] = Field(None, description="Instruction column")
|
| 284 |
-
question_column: Optional[str] = Field(None, description="Question column")
|
| 285 |
-
answer_column: Optional[str] = Field(None, description="Answer column")
|
| 286 |
-
context_column: Optional[str] = Field(None, description="Context/passages column")
|
| 287 |
-
reasoning_column: Optional[str] = Field(None, description="Reasoning/CoT column")
|
| 288 |
-
|
| 289 |
-
# Classification specific
|
| 290 |
-
label_column: Optional[str] = Field(None, description="Label column for classification")
|
| 291 |
-
label_mapping: Optional[Dict[str, int]] = Field(None, description="Label to ID mapping")
|
| 292 |
-
|
| 293 |
-
# NER/Token classification
|
| 294 |
-
tokens_column: Optional[str] = Field(None, description="Tokens column (for NER)")
|
| 295 |
-
tags_column: Optional[str] = Field(None, description="NER tags column")
|
| 296 |
-
ner_tags_mapping: Optional[Dict[str, str]] = Field(None, description="NER tag to label mapping")
|
| 297 |
-
|
| 298 |
-
# QA specific
|
| 299 |
-
title_column: Optional[str] = Field(None, description="Title column for QA context")
|
| 300 |
-
id_column: Optional[str] = Field(None, description="ID column")
|
| 301 |
-
answers_column: Optional[str] = Field(None, description="Answers column (list of answers)")
|
| 302 |
-
start_position_column: Optional[str] = Field(None, description="Start position column")
|
| 303 |
-
end_position_column: Optional[str] = Field(None, description="End position column")
|
| 304 |
-
|
| 305 |
-
# Additional context
|
| 306 |
-
metadata_columns: List[str] = Field(
|
| 307 |
-
default_factory=list,
|
| 308 |
-
description="Additional columns to include as metadata"
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
# Column transformations
|
| 312 |
-
column_transforms: Dict[str, str] = Field(
|
| 313 |
-
default_factory=dict,
|
| 314 |
-
description="Transformations to apply: column -> transform_type"
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
# Custom column aliases
|
| 318 |
-
column_aliases: Dict[str, str] = Field(
|
| 319 |
-
default_factory=dict,
|
| 320 |
-
description="Map custom column names to standard names"
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
def get_effective_column(self, role: str) -> Optional[str]:
|
| 324 |
-
"""Get the effective column name for a given role, considering aliases"""
|
| 325 |
-
role_to_column = {
|
| 326 |
-
"text": self.text_column,
|
| 327 |
-
"input": self.input_column,
|
| 328 |
-
"output": self.output_column,
|
| 329 |
-
"instruction": self.instruction_column,
|
| 330 |
-
"question": self.question_column,
|
| 331 |
-
"answer": self.answer_column,
|
| 332 |
-
"context": self.context_column,
|
| 333 |
-
"reasoning": self.reasoning_column,
|
| 334 |
-
"label": self.label_column,
|
| 335 |
-
"tokens": self.tokens_column,
|
| 336 |
-
"tags": self.tags_column,
|
| 337 |
-
"title": self.title_column,
|
| 338 |
-
"id": self.id_column,
|
| 339 |
-
"answers": self.answers_column,
|
| 340 |
-
}
|
| 341 |
-
column = role_to_column.get(role)
|
| 342 |
-
if column and column in self.column_aliases:
|
| 343 |
-
return self.column_aliases[column]
|
| 344 |
-
return column
|
| 345 |
-
|
| 346 |
|
| 347 |
-
# ============================================
|
| 348 |
-
# DATASET CONFIGURATION (ENHANCED)
|
| 349 |
-
# ============================================
|
| 350 |
-
|
| 351 |
-
class DatasetSplitConfig(BaseModel):
|
| 352 |
-
"""Configuration for a single dataset split"""
|
| 353 |
-
name: str = Field(default="train", description="Split name (train, validation, test)")
|
| 354 |
-
enabled: bool = Field(default=True)
|
| 355 |
-
max_samples: Optional[int] = Field(None, description="Maximum samples to use")
|
| 356 |
-
shuffle: bool = Field(default=True)
|
| 357 |
-
seed: int = Field(default=42)
|
| 358 |
-
stratify: bool = Field(default=False, description="Stratified sampling")
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
class DatasetConfig(BaseModel):
|
| 362 |
-
"""Enhanced dataset configuration with full control"""
|
| 363 |
-
# Source configuration
|
| 364 |
-
source: str = Field(default="huggingface", description="Dataset source: huggingface, local, upload")
|
| 365 |
-
name: Optional[str] = Field(None, description="HuggingFace dataset name")
|
| 366 |
-
config: Optional[str] = Field(None, description="Dataset config/subset name")
|
| 367 |
-
revision: Optional[str] = Field(None, description="Dataset revision/branch")
|
| 368 |
-
|
| 369 |
-
# Split configuration
|
| 370 |
-
splits: List[DatasetSplitConfig] = Field(
|
| 371 |
-
default_factory=lambda: [DatasetSplitConfig(name="train")],
|
| 372 |
-
description="Dataset splits to use"
|
| 373 |
-
)
|
| 374 |
-
train_split: str = Field(default="train", description="Training split name")
|
| 375 |
-
validation_split: Optional[str] = Field(None, description="Validation split name")
|
| 376 |
-
test_split: Optional[str] = Field(None, description="Test split name")
|
| 377 |
-
|
| 378 |
-
# Split generation
|
| 379 |
-
validation_split_ratio: float = Field(
|
| 380 |
-
default=0.1,
|
| 381 |
-
ge=0.0,
|
| 382 |
-
le=0.5,
|
| 383 |
-
description="Ratio for auto-generating validation split"
|
| 384 |
-
)
|
| 385 |
-
generate_validation: bool = Field(
|
| 386 |
-
default=True,
|
| 387 |
-
description="Auto-generate validation split if not provided"
|
| 388 |
-
)
|
| 389 |
-
|
| 390 |
-
# Column mapping
|
| 391 |
-
column_mapping: ColumnMappingConfig = Field(
|
| 392 |
-
default_factory=ColumnMappingConfig,
|
| 393 |
-
description="Map dataset columns to training roles"
|
| 394 |
-
)
|
| 395 |
-
|
| 396 |
-
# Prompt template
|
| 397 |
-
prompt_template: Optional[PromptTemplateConfig] = Field(
|
| 398 |
-
default=None,
|
| 399 |
-
description="Prompt structure configuration"
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
# Data processing
|
| 403 |
-
max_length: int = Field(default=512, description="Max sequence length")
|
| 404 |
-
max_target_length: int = Field(default=128, description="Max target length for seq2seq")
|
| 405 |
-
truncate: bool = Field(default=True, description="Truncate sequences exceeding max_length")
|
| 406 |
-
|
| 407 |
-
# Filtering
|
| 408 |
-
filter_conditions: List[Dict[str, Any]] = Field(
|
| 409 |
-
default_factory=list,
|
| 410 |
-
description="Filter conditions: [{column, operator, value}]"
|
| 411 |
-
)
|
| 412 |
-
min_text_length: Optional[int] = Field(None, description="Minimum text length")
|
| 413 |
-
max_text_length: Optional[int] = Field(None, description="Maximum text length")
|
| 414 |
-
|
| 415 |
-
# Augmentation
|
| 416 |
-
augmentation_enabled: bool = Field(default=False)
|
| 417 |
-
augmentation_config: Dict[str, Any] = Field(
|
| 418 |
-
default_factory=dict,
|
| 419 |
-
description="Data augmentation configuration"
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
# Streaming (for large datasets)
|
| 423 |
-
streaming: bool = Field(
|
| 424 |
-
default=False,
|
| 425 |
-
description="Use streaming mode for large datasets"
|
| 426 |
-
)
|
| 427 |
-
streaming_buffer_size: int = Field(default=10000)
|
| 428 |
-
|
| 429 |
-
# Caching
|
| 430 |
-
cache_dir: Optional[str] = Field(None, description="Cache directory")
|
| 431 |
-
num_proc: int = Field(default=4, description="Number of processes for data loading")
|
| 432 |
-
|
| 433 |
-
# Local file support
|
| 434 |
-
path: Optional[str] = Field(None, description="Local file path")
|
| 435 |
-
file_type: Optional[str] = Field(None, description="File type: json, jsonl, csv, parquet, text")
|
| 436 |
-
|
| 437 |
-
# Data validation
|
| 438 |
-
validate_data: bool = Field(default=True, description="Validate data before training")
|
| 439 |
-
validation_sample_size: int = Field(default=100, description="Samples to validate")
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
# ============================================
|
| 443 |
-
# TRAINING ARGUMENTS CONFIGURATION
|
| 444 |
-
# ============================================
|
| 445 |
|
| 446 |
-
class
|
| 447 |
-
"""
|
| 448 |
-
epochs: int = Field(default=3
|
| 449 |
-
batch_size: int = Field(default=
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
weight_decay: float = Field(default=0.01, ge=0.0, le=1.0)
|
| 453 |
-
warmup_ratio: float = Field(default=0.1, ge=0.0, le=1.0)
|
| 454 |
-
warmup_steps: int = Field(default=0, ge=0)
|
| 455 |
-
max_grad_norm: float = Field(default=1.0)
|
| 456 |
-
logging_steps: int = Field(default=10, ge=1)
|
| 457 |
-
eval_steps: int = Field(default=500, ge=1)
|
| 458 |
-
save_steps: int = Field(default=500, ge=1)
|
| 459 |
-
save_total_limit: int = Field(default=3, ge=1, le=10)
|
| 460 |
-
gradient_accumulation_steps: int = Field(default=1, ge=1, le=128)
|
| 461 |
-
fp16: bool = Field(default=False)
|
| 462 |
-
bf16: bool = Field(default=False)
|
| 463 |
-
gradient_checkpointing: bool = Field(default=False)
|
| 464 |
-
optimizer: str = Field(default="adamw_torch")
|
| 465 |
-
lr_scheduler_type: str = Field(default="cosine")
|
| 466 |
-
report_to: str = Field(default="none")
|
| 467 |
-
seed: int = Field(default=42)
|
| 468 |
-
|
| 469 |
-
# Advanced options
|
| 470 |
-
eval_strategy: str = Field(default="steps")
|
| 471 |
-
load_best_model_at_end: bool = Field(default=True)
|
| 472 |
-
metric_for_best_model: str = Field(default="eval_loss")
|
| 473 |
-
greater_is_better: bool = Field(default=False)
|
| 474 |
-
|
| 475 |
-
# Memory optimization
|
| 476 |
-
optim: str = Field(default="adamw_torch")
|
| 477 |
-
ddp_find_unused_parameters: bool = Field(default=False)
|
| 478 |
-
|
| 479 |
|
| 480 |
|
| 481 |
-
class
|
| 482 |
-
"""PEFT
|
| 483 |
enabled: bool = Field(default=True)
|
| 484 |
-
method: str = Field(default="lora"
|
| 485 |
-
r: int = Field(default=16
|
| 486 |
-
alpha: int = Field(default=32
|
| 487 |
-
dropout: float = Field(default=0.05
|
| 488 |
-
target_modules: List[str] = Field(default=["q_proj", "v_proj"])
|
| 489 |
-
bias: str = Field(default="none")
|
| 490 |
-
modules_to_save: List[str] = Field(default_factory=list)
|
| 491 |
-
|
| 492 |
-
# AdaLoRA specific
|
| 493 |
-
init_r: int = Field(default=12)
|
| 494 |
-
t_init: int = Field(default=200)
|
| 495 |
-
t_final: int = Field(default=1000)
|
| 496 |
-
|
| 497 |
-
# Prefix Tuning specific
|
| 498 |
-
num_virtual_tokens: int = Field(default=20)
|
| 499 |
-
|
| 500 |
-
# Prompt Tuning specific
|
| 501 |
-
num_tokens: int = Field(default=20)
|
| 502 |
-
token_init: bool = Field(default=True)
|
| 503 |
|
| 504 |
|
| 505 |
-
class
|
| 506 |
-
"""
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
private: bool = Field(default=False)
|
| 510 |
-
save_strategy: str = Field(default="steps")
|
| 511 |
-
output_dir: Optional[str] = Field(None)
|
| 512 |
|
| 513 |
|
| 514 |
-
class
|
| 515 |
-
"""
|
| 516 |
-
name: str = Field(
|
| 517 |
-
description: Optional[str] = Field(None)
|
| 518 |
task_type: str = Field(default="causal-lm")
|
| 519 |
base_model: str = Field(..., description="HuggingFace model ID")
|
| 520 |
-
dataset:
|
| 521 |
-
training_args:
|
| 522 |
-
peft_config: Optional[
|
| 523 |
-
|
| 524 |
-
tags: List[str] = Field(default_factory=list)
|
| 525 |
-
priority: int = Field(default=5, ge=1, le=20)
|
| 526 |
-
|
| 527 |
-
@validator('task_type')
|
| 528 |
-
def validate_task_type(cls, v):
|
| 529 |
-
valid_types = [
|
| 530 |
-
"causal-lm", "seq2seq", "token-classification",
|
| 531 |
-
"sequence-classification", "question-answering",
|
| 532 |
-
"summarization", "translation", "text-classification",
|
| 533 |
-
"masked-lm", "vision-classification", "audio-classification",
|
| 534 |
-
"reasoning"
|
| 535 |
-
]
|
| 536 |
-
if v not in valid_types:
|
| 537 |
-
raise ValueError(f"Invalid task_type. Must be one of: {valid_types}")
|
| 538 |
-
return v
|
| 539 |
|
| 540 |
|
| 541 |
class TrainingJobResponse(BaseModel):
|
|
@@ -572,11 +103,11 @@ class DatasetPreviewResponse(BaseModel):
|
|
| 572 |
dataset_name: str
|
| 573 |
config: Optional[str]
|
| 574 |
splits: List[str]
|
| 575 |
-
columns: List[
|
| 576 |
sample_data: List[Dict[str, Any]]
|
| 577 |
total_rows: Optional[int]
|
| 578 |
detected_task_types: List[str]
|
| 579 |
-
suggested_column_mapping:
|
| 580 |
|
| 581 |
|
| 582 |
# Global queue instance
|
|
@@ -615,29 +146,28 @@ async def preview_dataset(
|
|
| 615 |
# Get column info
|
| 616 |
columns = []
|
| 617 |
for col_name, col_type in ds.features.items():
|
| 618 |
-
|
| 619 |
-
"name": col_name,
|
| 620 |
-
"type": str(col_type),
|
| 621 |
-
"dtype": type(col_type).__name__,
|
| 622 |
-
}
|
| 623 |
-
# Detect if it's a label column
|
| 624 |
-
if hasattr(col_type, 'names'):
|
| 625 |
-
col_info["labels"] = col_type.names
|
| 626 |
-
col_info["num_labels"] = len(col_type.names)
|
| 627 |
-
columns.append(col_info)
|
| 628 |
|
| 629 |
# Get sample data
|
| 630 |
sample_data = []
|
| 631 |
for i in range(min(rows, len(ds))):
|
| 632 |
-
sample_data.append(ds[i])
|
| 633 |
|
| 634 |
# Detect task type and suggest column mapping
|
| 635 |
-
detected_tasks, suggested_mapping = detect_task_and_mapping(ds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
return DatasetPreviewResponse(
|
| 638 |
dataset_name=dataset_name,
|
| 639 |
config=config,
|
| 640 |
-
splits=
|
| 641 |
columns=columns,
|
| 642 |
sample_data=sample_data,
|
| 643 |
total_rows=len(ds),
|
|
@@ -646,151 +176,77 @@ async def preview_dataset(
|
|
| 646 |
)
|
| 647 |
|
| 648 |
except Exception as e:
|
|
|
|
| 649 |
raise HTTPException(status_code=400, detail=f"Error loading dataset: {str(e)}")
|
| 650 |
|
| 651 |
|
| 652 |
-
def detect_task_and_mapping(dataset
|
| 653 |
"""Detect suitable task types and suggest column mappings."""
|
| 654 |
-
|
| 655 |
-
col_names_original =
|
| 656 |
detected_tasks = []
|
| 657 |
-
mapping =
|
| 658 |
|
| 659 |
-
#
|
| 660 |
-
|
| 661 |
-
text_cols = [c for c in columns if "text" in c["name"].lower()]
|
| 662 |
|
| 663 |
-
|
| 664 |
-
detected_tasks.append("text-classification")
|
| 665 |
-
mapping.text_column = text_cols[0]["name"]
|
| 666 |
-
mapping.label_column = label_cols[0]["name"]
|
| 667 |
-
if label_cols[0].get("labels"):
|
| 668 |
-
mapping.label_mapping = {name: i for i, name in enumerate(label_cols[0]["labels"])}
|
| 669 |
|
| 670 |
-
#
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
|
|
|
| 674 |
|
| 675 |
-
|
|
|
|
| 676 |
detected_tasks.append("question-answering")
|
| 677 |
-
mapping
|
| 678 |
-
mapping
|
| 679 |
-
if
|
| 680 |
-
mapping
|
| 681 |
|
| 682 |
-
#
|
| 683 |
-
|
| 684 |
-
output_cols = [c for c in columns if "output" in c["name"].lower() or "response" in c["name"].lower()]
|
| 685 |
-
|
| 686 |
-
if instruction_cols and output_cols:
|
| 687 |
detected_tasks.append("causal-lm")
|
| 688 |
-
mapping
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
if
|
| 696 |
-
detected_tasks
|
| 697 |
-
|
| 698 |
-
mapping
|
| 699 |
-
|
| 700 |
-
# Check for NER
|
| 701 |
-
tokens_cols = [c for c in columns if "token" in c["name"].lower() or "word" in c["name"].lower()]
|
| 702 |
-
tags_cols = [c for c in columns if "tag" in c["name"].lower() or "ner" in c["name"].lower()]
|
| 703 |
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
| 708 |
|
| 709 |
-
# Default
|
| 710 |
if not detected_tasks:
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
|
|
|
|
|
|
| 715 |
|
| 716 |
return detected_tasks, mapping
|
| 717 |
|
| 718 |
|
| 719 |
-
# ============================================
|
| 720 |
-
# PROMPT TEMPLATE ENDPOINTS
|
| 721 |
-
# ============================================
|
| 722 |
-
|
| 723 |
-
@router.get("/prompt-templates")
|
| 724 |
-
async def get_prompt_templates():
|
| 725 |
-
"""Get available prompt template presets."""
|
| 726 |
-
return {
|
| 727 |
-
"presets": [
|
| 728 |
-
{
|
| 729 |
-
"id": "none",
|
| 730 |
-
"name": "None (Raw Text)",
|
| 731 |
-
"description": "Use dataset text directly without formatting"
|
| 732 |
-
},
|
| 733 |
-
{
|
| 734 |
-
"id": "alpaca",
|
| 735 |
-
"name": "Alpaca Format",
|
| 736 |
-
"description": "Instruction-Input-Output format for instruction tuning"
|
| 737 |
-
},
|
| 738 |
-
{
|
| 739 |
-
"id": "chatml",
|
| 740 |
-
"name": "ChatML",
|
| 741 |
-
"description": "ChatML format used by various models"
|
| 742 |
-
},
|
| 743 |
-
{
|
| 744 |
-
"id": "llama3",
|
| 745 |
-
"name": "Llama 3",
|
| 746 |
-
"description": "Llama 3 instruction format"
|
| 747 |
-
},
|
| 748 |
-
{
|
| 749 |
-
"id": "mistral",
|
| 750 |
-
"name": "Mistral",
|
| 751 |
-
"description": "Mistral/Vicuna instruction format"
|
| 752 |
-
},
|
| 753 |
-
{
|
| 754 |
-
"id": "vicuna",
|
| 755 |
-
"name": "Vicuna",
|
| 756 |
-
"description": "Vicuna chat format"
|
| 757 |
-
},
|
| 758 |
-
{
|
| 759 |
-
"id": "phi3",
|
| 760 |
-
"name": "Phi-3",
|
| 761 |
-
"description": "Microsoft Phi-3 format"
|
| 762 |
-
},
|
| 763 |
-
{
|
| 764 |
-
"id": "reasoning",
|
| 765 |
-
"name": "Reasoning/CoT",
|
| 766 |
-
"description": "Chain-of-thought reasoning format with explicit thinking"
|
| 767 |
-
}
|
| 768 |
-
]
|
| 769 |
-
}
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
@router.get("/prompt-templates/{template_id}")
|
| 773 |
-
async def get_prompt_template(template_id: str):
|
| 774 |
-
"""Get specific prompt template configuration."""
|
| 775 |
-
config = PromptTemplateConfig(instruction_format=template_id)
|
| 776 |
-
preset = config.get_template_for_format(template_id)
|
| 777 |
-
|
| 778 |
-
if not preset and template_id != "none":
|
| 779 |
-
raise HTTPException(status_code=404, detail="Template not found")
|
| 780 |
-
|
| 781 |
-
return {
|
| 782 |
-
"id": template_id,
|
| 783 |
-
"config": preset
|
| 784 |
-
}
|
| 785 |
-
|
| 786 |
-
|
| 787 |
# ============================================
|
| 788 |
# TRAINING JOB ENDPOINTS
|
| 789 |
# ============================================
|
| 790 |
|
| 791 |
@router.post("/start", response_model=TrainingJobResponse)
|
| 792 |
async def start_training(
|
| 793 |
-
request:
|
| 794 |
db: AsyncSession = Depends(get_db)
|
| 795 |
):
|
| 796 |
"""Start a new training job."""
|
|
@@ -801,37 +257,50 @@ async def start_training(
|
|
| 801 |
training_job = TrainingJob(
|
| 802 |
job_id=job_id,
|
| 803 |
name=request.name,
|
| 804 |
-
description=request.description,
|
| 805 |
task_type=request.task_type,
|
| 806 |
base_model=request.base_model,
|
| 807 |
-
output_model_name=request.output_config.hub_model_id,
|
| 808 |
-
dataset_source=request.dataset.source,
|
| 809 |
dataset_name=request.dataset.name,
|
| 810 |
-
dataset_config=request.dataset.config,
|
| 811 |
dataset_split=request.dataset.train_split,
|
| 812 |
-
training_args=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
peft_config=request.peft_config.dict() if request.peft_config else None,
|
| 814 |
status=JobStatus.PENDING.value,
|
| 815 |
total_epochs=request.training_args.epochs,
|
| 816 |
-
tags=request.tags
|
| 817 |
)
|
| 818 |
|
| 819 |
db.add(training_job)
|
| 820 |
await db.commit()
|
| 821 |
|
| 822 |
-
# Build full config
|
| 823 |
config = {
|
| 824 |
"job_id": job_id,
|
| 825 |
"task_type": request.task_type,
|
| 826 |
"base_model": request.base_model,
|
| 827 |
-
"
|
| 828 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
"peft_config": request.peft_config.dict() if request.peft_config else None,
|
| 830 |
-
"
|
| 831 |
}
|
| 832 |
|
| 833 |
# Submit to queue
|
| 834 |
-
priority = JobPriority
|
| 835 |
await queue.submit(config, priority=priority)
|
| 836 |
|
| 837 |
# Update status
|
|
@@ -854,9 +323,6 @@ async def get_job_status(
|
|
| 854 |
db: AsyncSession = Depends(get_db)
|
| 855 |
):
|
| 856 |
"""Get status of a training job."""
|
| 857 |
-
queue = get_queue()
|
| 858 |
-
queue_status = await queue.get_status(job_id)
|
| 859 |
-
|
| 860 |
result = await db.execute(
|
| 861 |
select(TrainingJob).where(TrainingJob.job_id == job_id)
|
| 862 |
)
|
|
@@ -865,20 +331,15 @@ async def get_job_status(
|
|
| 865 |
if not job:
|
| 866 |
raise HTTPException(status_code=404, detail="Job not found")
|
| 867 |
|
| 868 |
-
if queue_status:
|
| 869 |
-
job.status = queue_status.get("status", job.status)
|
| 870 |
-
if queue_status.get("progress"):
|
| 871 |
-
job.progress = queue_status["progress"]
|
| 872 |
-
|
| 873 |
return JobStatusResponse(
|
| 874 |
job_id=job.job_id,
|
| 875 |
name=job.name,
|
| 876 |
status=job.status,
|
| 877 |
-
progress=job.progress,
|
| 878 |
-
current_epoch=job.current_epoch,
|
| 879 |
-
total_epochs=job.total_epochs,
|
| 880 |
-
current_step=job.current_step,
|
| 881 |
-
total_steps=job.total_steps,
|
| 882 |
train_loss=job.train_loss,
|
| 883 |
eval_loss=job.eval_loss,
|
| 884 |
metrics=job.metrics or {},
|
|
@@ -891,7 +352,7 @@ async def get_job_status(
|
|
| 891 |
)
|
| 892 |
|
| 893 |
|
| 894 |
-
@router.get("/jobs"
|
| 895 |
async def list_jobs(
|
| 896 |
status: Optional[str] = None,
|
| 897 |
limit: int = 50,
|
|
@@ -909,28 +370,30 @@ async def list_jobs(
|
|
| 909 |
result = await db.execute(query)
|
| 910 |
jobs = result.scalars().all()
|
| 911 |
|
| 912 |
-
return
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
|
|
|
|
|
|
| 934 |
|
| 935 |
|
| 936 |
@router.post("/cancel/{job_id}")
|
|
@@ -942,9 +405,6 @@ async def cancel_job(
|
|
| 942 |
queue = get_queue()
|
| 943 |
cancelled = await queue.cancel_job(job_id)
|
| 944 |
|
| 945 |
-
if not cancelled:
|
| 946 |
-
raise HTTPException(status_code=400, detail="Cannot cancel job")
|
| 947 |
-
|
| 948 |
result = await db.execute(
|
| 949 |
select(TrainingJob).where(TrainingJob.job_id == job_id)
|
| 950 |
)
|
|
@@ -957,12 +417,6 @@ async def cancel_job(
|
|
| 957 |
return {"message": f"Job {job_id} cancelled", "success": True}
|
| 958 |
|
| 959 |
|
| 960 |
-
@router.get("/templates")
|
| 961 |
-
async def get_training_templates():
|
| 962 |
-
"""Get available training configuration templates."""
|
| 963 |
-
return TRAINING_TEMPLATES
|
| 964 |
-
|
| 965 |
-
|
| 966 |
@router.get("/queue/status")
|
| 967 |
async def get_queue_status():
|
| 968 |
"""Get current queue status."""
|
|
@@ -974,6 +428,12 @@ async def get_queue_status():
|
|
| 974 |
}
|
| 975 |
|
| 976 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 977 |
@router.get("/metrics/{job_id}")
|
| 978 |
async def get_job_metrics(
|
| 979 |
job_id: str,
|
|
@@ -992,8 +452,7 @@ async def get_job_metrics(
|
|
| 992 |
"job_id": job_id,
|
| 993 |
"train_loss": job.train_loss,
|
| 994 |
"eval_loss": job.eval_loss,
|
| 995 |
-
"metrics": job.metrics
|
| 996 |
-
"learning_rate": job.learning_rate
|
| 997 |
}
|
| 998 |
|
| 999 |
|
|
@@ -1017,4 +476,25 @@ async def delete_job(
|
|
| 1017 |
await db.delete(job)
|
| 1018 |
await db.commit()
|
| 1019 |
|
| 1020 |
-
return {"message": f"Job {job_id} deleted", "success": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
# ============================================
|
| 26 |
+
# SIMPLIFIED REQUEST MODELS (matching dashboard form)
|
| 27 |
# ============================================
|
| 28 |
|
| 29 |
+
class DatasetConfigSimple(BaseModel):
    """Dataset section of the dashboard training form.

    Only ``name`` is required; every other field falls back to a default.
    """

    # Required dataset identifier.
    name: str = Field(..., description="HuggingFace dataset name")
    # Split names; the validation split may be explicitly set to None.
    train_split: str = "train"
    validation_split: Optional[str] = "validation"
    # A fresh dict is created per instance via default_factory.
    column_mapping: Dict[str, str] = Field(
        default_factory=dict,
        description="Maps roles to column names: {text: 'col1', input: 'col2'}",
    )
    max_length: int = 512
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
class TrainingArgsSimple(BaseModel):
    """Training hyperparameters accepted from the dashboard form.

    All fields are optional and carry the defaults shown below.
    """

    epochs: int = 3
    batch_size: int = 1
    learning_rate: float = 5e-5
    warmup_steps: int = 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
+
class PEFTConfigSimple(BaseModel):
    """PEFT settings accepted from the dashboard form.

    All fields are optional and carry the defaults shown below.
    """

    enabled: bool = True
    method: str = "lora"
    r: int = 16
    alpha: int = 32
    dropout: float = 0.05
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
+
class PromptTemplateSimple(BaseModel):
    """Prompt-template section of the dashboard training form."""

    # Identifier of a preset; defaults to "none".
    preset: str = Field(
        default="none",
        description="Template preset: none, alpaca, chatml, llama3, mistral, vicuna, phi3, reasoning",
    )
    # Free-form custom sections; absent (None) unless the client sends them.
    custom: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Custom template sections",
    )
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
+
class TrainingRequestSimple(BaseModel):
    """Top-level request body for starting a training job from the dashboard form.

    ``base_model`` and ``dataset`` are required; the remaining fields have
    defaults.
    """

    name: str = "training-job"
    task_type: str = "causal-lm"
    base_model: str = Field(..., description="HuggingFace model ID")
    dataset: DatasetConfigSimple
    # A fresh TrainingArgsSimple (all defaults) is built when omitted.
    training_args: TrainingArgsSimple = Field(default_factory=TrainingArgsSimple)
    # Both optional sections default to None when the client leaves them out.
    peft_config: Optional[PEFTConfigSimple] = None
    prompt_template: Optional[PromptTemplateSimple] = Field(
        None,
        description="Prompt template configuration",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
class TrainingJobResponse(BaseModel):
|
|
|
|
| 103 |
dataset_name: str
|
| 104 |
config: Optional[str]
|
| 105 |
splits: List[str]
|
| 106 |
+
columns: List[str]
|
| 107 |
sample_data: List[Dict[str, Any]]
|
| 108 |
total_rows: Optional[int]
|
| 109 |
detected_task_types: List[str]
|
| 110 |
+
suggested_column_mapping: Dict[str, str]
|
| 111 |
|
| 112 |
|
| 113 |
# Global queue instance
|
|
|
|
| 146 |
# Get column info
|
| 147 |
columns = []
|
| 148 |
for col_name, col_type in ds.features.items():
|
| 149 |
+
columns.append(col_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
# Get sample data
|
| 152 |
sample_data = []
|
| 153 |
for i in range(min(rows, len(ds))):
|
| 154 |
+
sample_data.append({k: str(v)[:100] if v else None for k, v in ds[i].items()})
|
| 155 |
|
| 156 |
# Detect task type and suggest column mapping
|
| 157 |
+
detected_tasks, suggested_mapping = detect_task_and_mapping(ds)
|
| 158 |
+
|
| 159 |
+
# Get all splits
|
| 160 |
+
try:
|
| 161 |
+
from datasets import load_dataset_builder
|
| 162 |
+
builder = load_dataset_builder(dataset_name, trust_remote_code=True)
|
| 163 |
+
splits = list(builder.info.splits.keys())
|
| 164 |
+
except:
|
| 165 |
+
splits = [split]
|
| 166 |
|
| 167 |
return DatasetPreviewResponse(
|
| 168 |
dataset_name=dataset_name,
|
| 169 |
config=config,
|
| 170 |
+
splits=splits,
|
| 171 |
columns=columns,
|
| 172 |
sample_data=sample_data,
|
| 173 |
total_rows=len(ds),
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
+
logger.error(f"Error loading dataset: {e}")
|
| 180 |
raise HTTPException(status_code=400, detail=f"Error loading dataset: {str(e)}")
|
| 181 |
|
| 182 |
|
| 183 |
+
def detect_task_and_mapping(dataset) -> tuple:
    """Detect suitable task types and suggest column mappings.

    Column names are matched case-insensitively against a few well-known
    layouts; the suggested mapping always refers to the original
    (case-preserved) column names.

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer
            indexing that returns a mapping with a ``.get`` method
            (only ``dataset[0]`` is read, and only in the fallback branch).

    Returns:
        A ``(detected_tasks, mapping)`` pair: a list of task-type strings in
        detection order and a dict mapping role names to column names.
    """
    col_names_original = list(dataset.column_names)
    # Lowercase name -> original name; also used for all membership checks.
    col_map = {c.lower(): c for c in col_names_original}
    detected_tasks = []
    mapping = {}

    # Text classification
    if 'label' in col_map and 'text' in col_map:
        detected_tasks.append("text-classification")
        mapping['label'] = col_map['label']
        mapping['text'] = col_map['text']

    # QA: accept either an 'answer' or an 'answers' column. Previously the
    # 'answers' fallback was unreachable because the guard required 'answer'.
    answer_key = next((k for k in ('answer', 'answers') if k in col_map), None)
    if 'question' in col_map and answer_key is not None:
        detected_tasks.append("question-answering")
        mapping['question'] = col_map['question']
        mapping['answer'] = col_map[answer_key]
        if 'context' in col_map:
            mapping['context'] = col_map['context']

    # Instruction-output
    if 'instruction' in col_map:
        detected_tasks.append("causal-lm")
        mapping['instruction'] = col_map['instruction']
        for role in ('input', 'output'):
            if role in col_map:
                mapping[role] = col_map[role]

    # Input-output
    if 'input' in col_map and 'output' in col_map:
        if 'causal-lm' not in detected_tasks:
            detected_tasks.append("causal-lm")
        mapping['input'] = col_map['input']
        mapping['output'] = col_map['output']

    # Reasoning: prefer a 'reasoning' column, otherwise fall back to
    # 'thinking' so the detected task always comes with a mapping entry.
    reasoning_key = next(
        (k for k in ('reasoning', 'thinking') if k in col_map), None
    )
    if reasoning_key is not None:
        detected_tasks.append("reasoning")
        mapping['reasoning'] = col_map[reasoning_key]

    # Default: plain causal LM on the first column whose first-row value is
    # a string. The first row is fetched once instead of once per column.
    if not detected_tasks:
        detected_tasks.append("causal-lm")
        if len(dataset) > 0:
            first_row = dataset[0]
            for col in col_names_original:
                if isinstance(first_row.get(col), str):
                    mapping['text'] = col
                    break

    return detected_tasks, mapping
|
| 241 |
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
# ============================================
|
| 244 |
# TRAINING JOB ENDPOINTS
|
| 245 |
# ============================================
|
| 246 |
|
| 247 |
@router.post("/start", response_model=TrainingJobResponse)
|
| 248 |
async def start_training(
|
| 249 |
+
request: TrainingRequestSimple,
|
| 250 |
db: AsyncSession = Depends(get_db)
|
| 251 |
):
|
| 252 |
"""Start a new training job."""
|
|
|
|
| 257 |
training_job = TrainingJob(
|
| 258 |
job_id=job_id,
|
| 259 |
name=request.name,
|
|
|
|
| 260 |
task_type=request.task_type,
|
| 261 |
base_model=request.base_model,
|
|
|
|
|
|
|
| 262 |
dataset_name=request.dataset.name,
|
|
|
|
| 263 |
dataset_split=request.dataset.train_split,
|
| 264 |
+
training_args={
|
| 265 |
+
"epochs": request.training_args.epochs,
|
| 266 |
+
"batch_size": request.training_args.batch_size,
|
| 267 |
+
"learning_rate": request.training_args.learning_rate,
|
| 268 |
+
"warmup_steps": request.training_args.warmup_steps,
|
| 269 |
+
},
|
| 270 |
peft_config=request.peft_config.dict() if request.peft_config else None,
|
| 271 |
status=JobStatus.PENDING.value,
|
| 272 |
total_epochs=request.training_args.epochs,
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
db.add(training_job)
|
| 276 |
await db.commit()
|
| 277 |
|
| 278 |
+
# Build full config for training service
|
| 279 |
config = {
|
| 280 |
"job_id": job_id,
|
| 281 |
"task_type": request.task_type,
|
| 282 |
"base_model": request.base_model,
|
| 283 |
+
"model_name": request.base_model,
|
| 284 |
+
"dataset_name": request.dataset.name,
|
| 285 |
+
"dataset": {
|
| 286 |
+
"name": request.dataset.name,
|
| 287 |
+
"train_split": request.dataset.train_split,
|
| 288 |
+
"validation_split": request.dataset.validation_split,
|
| 289 |
+
"column_mapping": request.dataset.column_mapping,
|
| 290 |
+
"max_length": request.dataset.max_length,
|
| 291 |
+
},
|
| 292 |
+
"training_args": {
|
| 293 |
+
"epochs": request.training_args.epochs,
|
| 294 |
+
"batch_size": request.training_args.batch_size,
|
| 295 |
+
"learning_rate": request.training_args.learning_rate,
|
| 296 |
+
"warmup_steps": request.training_args.warmup_steps,
|
| 297 |
+
},
|
| 298 |
"peft_config": request.peft_config.dict() if request.peft_config else None,
|
| 299 |
+
"prompt_template": request.prompt_template.dict() if request.prompt_template else {"preset": "none"},
|
| 300 |
}
|
| 301 |
|
| 302 |
# Submit to queue
|
| 303 |
+
priority = JobPriority.NORMAL
|
| 304 |
await queue.submit(config, priority=priority)
|
| 305 |
|
| 306 |
# Update status
|
|
|
|
| 323 |
db: AsyncSession = Depends(get_db)
|
| 324 |
):
|
| 325 |
"""Get status of a training job."""
|
|
|
|
|
|
|
|
|
|
| 326 |
result = await db.execute(
|
| 327 |
select(TrainingJob).where(TrainingJob.job_id == job_id)
|
| 328 |
)
|
|
|
|
| 331 |
if not job:
|
| 332 |
raise HTTPException(status_code=404, detail="Job not found")
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
return JobStatusResponse(
|
| 335 |
job_id=job.job_id,
|
| 336 |
name=job.name,
|
| 337 |
status=job.status,
|
| 338 |
+
progress=job.progress or 0.0,
|
| 339 |
+
current_epoch=job.current_epoch or 0,
|
| 340 |
+
total_epochs=job.total_epochs or 0,
|
| 341 |
+
current_step=job.current_step or 0,
|
| 342 |
+
total_steps=job.total_steps or 0,
|
| 343 |
train_loss=job.train_loss,
|
| 344 |
eval_loss=job.eval_loss,
|
| 345 |
metrics=job.metrics or {},
|
|
|
|
| 352 |
)
|
| 353 |
|
| 354 |
|
| 355 |
+
@router.get("/jobs")
|
| 356 |
async def list_jobs(
|
| 357 |
status: Optional[str] = None,
|
| 358 |
limit: int = 50,
|
|
|
|
| 370 |
result = await db.execute(query)
|
| 371 |
jobs = result.scalars().all()
|
| 372 |
|
| 373 |
+
return {
|
| 374 |
+
"jobs": [
|
| 375 |
+
{
|
| 376 |
+
"job_id": job.job_id,
|
| 377 |
+
"name": job.name,
|
| 378 |
+
"status": job.status,
|
| 379 |
+
"progress": job.progress or 0.0,
|
| 380 |
+
"current_epoch": job.current_epoch or 0,
|
| 381 |
+
"total_epochs": job.total_epochs or 0,
|
| 382 |
+
"current_step": job.current_step or 0,
|
| 383 |
+
"total_steps": job.total_steps or 0,
|
| 384 |
+
"train_loss": job.train_loss,
|
| 385 |
+
"eval_loss": job.eval_loss,
|
| 386 |
+
"metrics": job.metrics or {},
|
| 387 |
+
"error_message": job.error_message,
|
| 388 |
+
"created_at": job.created_at.isoformat() if job.created_at else None,
|
| 389 |
+
"started_at": job.started_at.isoformat() if job.started_at else None,
|
| 390 |
+
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
| 391 |
+
"model_name": job.base_model,
|
| 392 |
+
"dataset_name": job.dataset_name,
|
| 393 |
+
}
|
| 394 |
+
for job in jobs
|
| 395 |
+
]
|
| 396 |
+
}
|
| 397 |
|
| 398 |
|
| 399 |
@router.post("/cancel/{job_id}")
|
|
|
|
| 405 |
queue = get_queue()
|
| 406 |
cancelled = await queue.cancel_job(job_id)
|
| 407 |
|
|
|
|
|
|
|
|
|
|
| 408 |
result = await db.execute(
|
| 409 |
select(TrainingJob).where(TrainingJob.job_id == job_id)
|
| 410 |
)
|
|
|
|
| 417 |
return {"message": f"Job {job_id} cancelled", "success": True}
|
| 418 |
|
| 419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
@router.get("/queue/status")
|
| 421 |
async def get_queue_status():
|
| 422 |
"""Get current queue status."""
|
|
|
|
| 428 |
}
|
| 429 |
|
| 430 |
|
| 431 |
+
@router.get("/templates")
async def get_training_templates():
    """Return the module-level ``TRAINING_TEMPLATES`` object as the response."""
    templates = TRAINING_TEMPLATES
    return templates
|
| 435 |
+
|
| 436 |
+
|
| 437 |
@router.get("/metrics/{job_id}")
|
| 438 |
async def get_job_metrics(
|
| 439 |
job_id: str,
|
|
|
|
| 452 |
"job_id": job_id,
|
| 453 |
"train_loss": job.train_loss,
|
| 454 |
"eval_loss": job.eval_loss,
|
| 455 |
+
"metrics": job.metrics
|
|
|
|
| 456 |
}
|
| 457 |
|
| 458 |
|
|
|
|
| 476 |
await db.delete(job)
|
| 477 |
await db.commit()
|
| 478 |
|
| 479 |
+
return {"message": f"Job {job_id} deleted", "success": True}
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
# ============================================
|
| 483 |
+
# PROMPT TEMPLATE ENDPOINTS
|
| 484 |
+
# ============================================
|
| 485 |
+
|
| 486 |
+
@router.get("/prompt-templates")
async def get_prompt_templates():
    """List the prompt template presets.

    Returns a dict with a single ``"presets"`` key holding one
    ``{"id", "name", "description"}`` entry per preset, in a fixed order.
    """
    # (id, display name, description) for each preset, in response order.
    preset_rows = (
        ("none", "None (Raw Text)", "Use dataset text directly"),
        ("alpaca", "Alpaca Format", "Instruction-Input-Output"),
        ("chatml", "ChatML", "ChatML format"),
        ("llama3", "Llama 3", "Llama 3 instruction format"),
        ("mistral", "Mistral", "Mistral instruction format"),
        ("vicuna", "Vicuna", "Vicuna chat format"),
        ("phi3", "Phi-3", "Microsoft Phi-3 format"),
        ("reasoning", "Reasoning/CoT", "Chain-of-thought"),
    )
    presets = [
        {"id": preset_id, "name": label, "description": blurb}
        for preset_id, label, blurb in preset_rows
    ]
    return {"presets": presets}
|