vectorplasticity commited on
Commit
f676ff7
·
verified ·
1 Parent(s): c30b193

Fix training router to accept form column_mapping and prompt_template format

Browse files
Files changed (1) hide show
  1. app/routers/training.py +174 -694
app/routers/training.py CHANGED
@@ -23,519 +23,50 @@ router = APIRouter(prefix="/api/training", tags=["Training"])
23
 
24
 
25
  # ============================================
26
- # PROMPT TEMPLATE CONFIGURATION
27
  # ============================================
28
 
29
- class PromptSectionConfig(BaseModel):
30
- """Configuration for a single prompt section (system, user, assistant, etc.)"""
31
- enabled: bool = Field(default=True, description="Whether to include this section")
32
- template: str = Field(
33
- default="",
34
- description="Template string with {column} placeholders"
35
- )
36
- columns: List[str] = Field(
37
- default_factory=list,
38
- description="List of dataset columns used in this section"
39
- )
40
- prefix: str = Field(default="", description="Prefix before content (e.g., 'System: ')")
41
- suffix: str = Field(default="", description="Suffix after content (e.g., '\n\n')")
42
- strip_whitespace: bool = Field(default=True)
43
- required: bool = Field(default=False, description="Raise error if columns missing")
44
-
45
-
46
- class PromptTemplateConfig(BaseModel):
47
- """Full prompt template configuration for training data formatting"""
48
- # Chat-style formatting
49
- use_chat_format: bool = Field(default=True, description="Use chat-style message format")
50
- chat_template: Optional[str] = Field(
51
- default=None,
52
- description="Jinja2 chat template (auto-detect from tokenizer if None)"
53
- )
54
-
55
- # Message sections
56
- system: Optional[PromptSectionConfig] = Field(
57
- default=None,
58
- description="System message configuration"
59
- )
60
- user: Optional[PromptSectionConfig] = Field(
61
- default=None,
62
- description="User message configuration"
63
- )
64
- context: Optional[PromptSectionConfig] = Field(
65
- default=None,
66
- description="Context/passages configuration"
67
- )
68
- reasoning: Optional[PromptSectionConfig] = Field(
69
- default=None,
70
- description="Reasoning/chain-of-thought configuration"
71
- )
72
- assistant: Optional[PromptSectionConfig] = Field(
73
- default=None,
74
- description="Assistant response configuration (target for training)"
75
- )
76
-
77
- # Custom sections for flexibility
78
- custom_sections: Dict[str, PromptSectionConfig] = Field(
79
- default_factory=dict,
80
- description="Additional custom sections"
81
- )
82
-
83
- # Section ordering
84
- section_order: List[str] = Field(
85
- default=["system", "context", "user", "reasoning", "assistant"],
86
- description="Order of sections in the prompt"
87
- )
88
-
89
- # Special tokens
90
- bos_token: Optional[str] = Field(default=None, description="Beginning of sequence token")
91
- eos_token: Optional[str] = Field(default=None, description="End of sequence token")
92
- pad_token: Optional[str] = Field(default=None, description="Padding token")
93
-
94
- # Separator configuration
95
- section_separator: str = Field(default="\n\n", description="Separator between sections")
96
- message_separator: str = Field(default="\n", description="Separator between messages")
97
-
98
- # Instruction format (for instruction-tuned models)
99
- instruction_format: str = Field(
100
- default="none",
101
- description="Preset format: none, alpaca, chatml, llama3, mistral, vicuna, phi3"
102
- )
103
-
104
- def get_template_for_format(self, format_name: str) -> Dict[str, Any]:
105
- """Get preset template configuration for known formats"""
106
- presets = {
107
- "none": {},
108
- "alpaca": {
109
- "system": PromptSectionConfig(
110
- enabled=True,
111
- template="Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
112
- prefix="", suffix="\n\n"
113
- ),
114
- "user": PromptSectionConfig(
115
- enabled=True,
116
- template="### Instruction:\n{instruction}\n\n### Input:\n{input}",
117
- columns=["instruction", "input"],
118
- prefix="", suffix="\n\n"
119
- ),
120
- "assistant": PromptSectionConfig(
121
- enabled=True,
122
- template="### Response:\n{output}",
123
- columns=["output"],
124
- prefix="", suffix=""
125
- ),
126
- "section_order": ["system", "user", "assistant"]
127
- },
128
- "chatml": {
129
- "system": PromptSectionConfig(
130
- enabled=True,
131
- template="{system_message}",
132
- columns=["system_message"],
133
- prefix="<|im_start|>system\n", suffix="<|im_end|>\n"
134
- ),
135
- "user": PromptSectionConfig(
136
- enabled=True,
137
- template="{user_message}",
138
- columns=["user_message"],
139
- prefix="<|im_start|>user\n", suffix="<|im_end|>\n"
140
- ),
141
- "assistant": PromptSectionConfig(
142
- enabled=True,
143
- template="{assistant_message}",
144
- columns=["assistant_message", "output", "response"],
145
- prefix="<|im_start|>assistant\n", suffix="<|im_end|>"
146
- ),
147
- "section_order": ["system", "user", "assistant"]
148
- },
149
- "llama3": {
150
- "system": PromptSectionConfig(
151
- enabled=True,
152
- template="{system_message}",
153
- columns=["system_message"],
154
- prefix="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n",
155
- suffix="<|eot_id|>"
156
- ),
157
- "user": PromptSectionConfig(
158
- enabled=True,
159
- template="{user_message}",
160
- columns=["user_message", "question"],
161
- prefix="<|start_header_id|>user<|end_header_id|>\n\n",
162
- suffix="<|eot_id|>"
163
- ),
164
- "assistant": PromptSectionConfig(
165
- enabled=True,
166
- template="{assistant_message}",
167
- columns=["assistant_message", "output", "response"],
168
- prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
169
- suffix="<|eot_id|>"
170
- ),
171
- "section_order": ["system", "user", "assistant"]
172
- },
173
- "mistral": {
174
- "system": PromptSectionConfig(
175
- enabled=True,
176
- template="{system_message}",
177
- columns=["system_message"],
178
- prefix="[INST] ", suffix=" "
179
- ),
180
- "user": PromptSectionConfig(
181
- enabled=True,
182
- template="{user_message}",
183
- columns=["user_message", "question"],
184
- prefix="", suffix=" [/INST]"
185
- ),
186
- "assistant": PromptSectionConfig(
187
- enabled=True,
188
- template="{assistant_message}",
189
- columns=["assistant_message", "output"],
190
- prefix="", suffix="</s>"
191
- ),
192
- "section_order": ["system", "user", "assistant"]
193
- },
194
- "vicuna": {
195
- "system": PromptSectionConfig(
196
- enabled=True,
197
- template="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
198
- prefix="", suffix=" "
199
- ),
200
- "user": PromptSectionConfig(
201
- enabled=True,
202
- template="{user_message}",
203
- columns=["user_message", "question"],
204
- prefix="USER: ", suffix=" "
205
- ),
206
- "assistant": PromptSectionConfig(
207
- enabled=True,
208
- template="{assistant_message}",
209
- columns=["assistant_message", "output"],
210
- prefix="ASSISTANT: ", suffix=""
211
- ),
212
- "section_order": ["system", "user", "assistant"]
213
- },
214
- "phi3": {
215
- "system": PromptSectionConfig(
216
- enabled=True,
217
- template="{system_message}",
218
- columns=["system_message"],
219
- prefix="<|system|>\n", suffix="<|end|>\n"
220
- ),
221
- "user": PromptSectionConfig(
222
- enabled=True,
223
- template="{user_message}",
224
- columns=["user_message", "question"],
225
- prefix="<|user|>\n", suffix="<|end|>\n"
226
- ),
227
- "assistant": PromptSectionConfig(
228
- enabled=True,
229
- template="{assistant_message}",
230
- columns=["assistant_message", "output"],
231
- prefix="<|assistant|>\n", suffix="<|end|>"
232
- ),
233
- "section_order": ["system", "user", "assistant"]
234
- },
235
- "reasoning": {
236
- "system": PromptSectionConfig(
237
- enabled=True,
238
- template="You are a helpful AI assistant that thinks step by step before responding.",
239
- prefix="", suffix="\n\n"
240
- ),
241
- "context": PromptSectionConfig(
242
- enabled=True,
243
- template="{context}",
244
- columns=["context", "passages", "background"],
245
- prefix="Context:\n", suffix="\n\n"
246
- ),
247
- "user": PromptSectionConfig(
248
- enabled=True,
249
- template="{question}",
250
- columns=["question", "query", "user_message"],
251
- prefix="Question: ", suffix="\n\n"
252
- ),
253
- "reasoning": PromptSectionConfig(
254
- enabled=True,
255
- template="{reasoning}",
256
- columns=["reasoning", "thinking", "chain_of_thought"],
257
- prefix="Reasoning:\n", suffix="\n\n"
258
- ),
259
- "assistant": PromptSectionConfig(
260
- enabled=True,
261
- template="{answer}",
262
- columns=["answer", "output", "response"],
263
- prefix="Answer: ", suffix=""
264
- ),
265
- "section_order": ["system", "context", "user", "reasoning", "assistant"]
266
- }
267
- }
268
- return presets.get(format_name, {})
269
-
270
-
271
- # ============================================
272
- # COLUMN MAPPING CONFIGURATION
273
- # ============================================
274
-
275
- class ColumnMappingConfig(BaseModel):
276
- """Maps dataset columns to training roles"""
277
- # Primary text columns
278
- text_column: Optional[str] = Field(None, description="Main text column (for causal LM)")
279
- input_column: Optional[str] = Field(None, description="Input text column")
280
- output_column: Optional[str] = Field(None, description="Output/target column")
281
-
282
- # Multi-column support
283
- instruction_column: Optional[str] = Field(None, description="Instruction column")
284
- question_column: Optional[str] = Field(None, description="Question column")
285
- answer_column: Optional[str] = Field(None, description="Answer column")
286
- context_column: Optional[str] = Field(None, description="Context/passages column")
287
- reasoning_column: Optional[str] = Field(None, description="Reasoning/CoT column")
288
-
289
- # Classification specific
290
- label_column: Optional[str] = Field(None, description="Label column for classification")
291
- label_mapping: Optional[Dict[str, int]] = Field(None, description="Label to ID mapping")
292
-
293
- # NER/Token classification
294
- tokens_column: Optional[str] = Field(None, description="Tokens column (for NER)")
295
- tags_column: Optional[str] = Field(None, description="NER tags column")
296
- ner_tags_mapping: Optional[Dict[str, str]] = Field(None, description="NER tag to label mapping")
297
-
298
- # QA specific
299
- title_column: Optional[str] = Field(None, description="Title column for QA context")
300
- id_column: Optional[str] = Field(None, description="ID column")
301
- answers_column: Optional[str] = Field(None, description="Answers column (list of answers)")
302
- start_position_column: Optional[str] = Field(None, description="Start position column")
303
- end_position_column: Optional[str] = Field(None, description="End position column")
304
-
305
- # Additional context
306
- metadata_columns: List[str] = Field(
307
- default_factory=list,
308
- description="Additional columns to include as metadata"
309
- )
310
-
311
- # Column transformations
312
- column_transforms: Dict[str, str] = Field(
313
- default_factory=dict,
314
- description="Transformations to apply: column -> transform_type"
315
- )
316
-
317
- # Custom column aliases
318
- column_aliases: Dict[str, str] = Field(
319
- default_factory=dict,
320
- description="Map custom column names to standard names"
321
- )
322
-
323
- def get_effective_column(self, role: str) -> Optional[str]:
324
- """Get the effective column name for a given role, considering aliases"""
325
- role_to_column = {
326
- "text": self.text_column,
327
- "input": self.input_column,
328
- "output": self.output_column,
329
- "instruction": self.instruction_column,
330
- "question": self.question_column,
331
- "answer": self.answer_column,
332
- "context": self.context_column,
333
- "reasoning": self.reasoning_column,
334
- "label": self.label_column,
335
- "tokens": self.tokens_column,
336
- "tags": self.tags_column,
337
- "title": self.title_column,
338
- "id": self.id_column,
339
- "answers": self.answers_column,
340
- }
341
- column = role_to_column.get(role)
342
- if column and column in self.column_aliases:
343
- return self.column_aliases[column]
344
- return column
345
-
346
 
347
- # ============================================
348
- # DATASET CONFIGURATION (ENHANCED)
349
- # ============================================
350
-
351
- class DatasetSplitConfig(BaseModel):
352
- """Configuration for a single dataset split"""
353
- name: str = Field(default="train", description="Split name (train, validation, test)")
354
- enabled: bool = Field(default=True)
355
- max_samples: Optional[int] = Field(None, description="Maximum samples to use")
356
- shuffle: bool = Field(default=True)
357
- seed: int = Field(default=42)
358
- stratify: bool = Field(default=False, description="Stratified sampling")
359
-
360
-
361
- class DatasetConfig(BaseModel):
362
- """Enhanced dataset configuration with full control"""
363
- # Source configuration
364
- source: str = Field(default="huggingface", description="Dataset source: huggingface, local, upload")
365
- name: Optional[str] = Field(None, description="HuggingFace dataset name")
366
- config: Optional[str] = Field(None, description="Dataset config/subset name")
367
- revision: Optional[str] = Field(None, description="Dataset revision/branch")
368
-
369
- # Split configuration
370
- splits: List[DatasetSplitConfig] = Field(
371
- default_factory=lambda: [DatasetSplitConfig(name="train")],
372
- description="Dataset splits to use"
373
- )
374
- train_split: str = Field(default="train", description="Training split name")
375
- validation_split: Optional[str] = Field(None, description="Validation split name")
376
- test_split: Optional[str] = Field(None, description="Test split name")
377
-
378
- # Split generation
379
- validation_split_ratio: float = Field(
380
- default=0.1,
381
- ge=0.0,
382
- le=0.5,
383
- description="Ratio for auto-generating validation split"
384
- )
385
- generate_validation: bool = Field(
386
- default=True,
387
- description="Auto-generate validation split if not provided"
388
- )
389
-
390
- # Column mapping
391
- column_mapping: ColumnMappingConfig = Field(
392
- default_factory=ColumnMappingConfig,
393
- description="Map dataset columns to training roles"
394
- )
395
-
396
- # Prompt template
397
- prompt_template: Optional[PromptTemplateConfig] = Field(
398
- default=None,
399
- description="Prompt structure configuration"
400
- )
401
-
402
- # Data processing
403
- max_length: int = Field(default=512, description="Max sequence length")
404
- max_target_length: int = Field(default=128, description="Max target length for seq2seq")
405
- truncate: bool = Field(default=True, description="Truncate sequences exceeding max_length")
406
-
407
- # Filtering
408
- filter_conditions: List[Dict[str, Any]] = Field(
409
- default_factory=list,
410
- description="Filter conditions: [{column, operator, value}]"
411
- )
412
- min_text_length: Optional[int] = Field(None, description="Minimum text length")
413
- max_text_length: Optional[int] = Field(None, description="Maximum text length")
414
-
415
- # Augmentation
416
- augmentation_enabled: bool = Field(default=False)
417
- augmentation_config: Dict[str, Any] = Field(
418
- default_factory=dict,
419
- description="Data augmentation configuration"
420
- )
421
-
422
- # Streaming (for large datasets)
423
- streaming: bool = Field(
424
- default=False,
425
- description="Use streaming mode for large datasets"
426
- )
427
- streaming_buffer_size: int = Field(default=10000)
428
-
429
- # Caching
430
- cache_dir: Optional[str] = Field(None, description="Cache directory")
431
- num_proc: int = Field(default=4, description="Number of processes for data loading")
432
-
433
- # Local file support
434
- path: Optional[str] = Field(None, description="Local file path")
435
- file_type: Optional[str] = Field(None, description="File type: json, jsonl, csv, parquet, text")
436
-
437
- # Data validation
438
- validate_data: bool = Field(default=True, description="Validate data before training")
439
- validation_sample_size: int = Field(default=100, description="Samples to validate")
440
-
441
-
442
- # ============================================
443
- # TRAINING ARGUMENTS CONFIGURATION
444
- # ============================================
445
 
446
- class TrainingArgsConfig(BaseModel):
447
- """Training arguments configuration."""
448
- epochs: int = Field(default=3, ge=1, le=100)
449
- batch_size: int = Field(default=8, ge=1, le=256)
450
- eval_batch_size: int = Field(default=16, ge=1, le=512)
451
- learning_rate: float = Field(default=5e-5, ge=1e-7, le=1.0)
452
- weight_decay: float = Field(default=0.01, ge=0.0, le=1.0)
453
- warmup_ratio: float = Field(default=0.1, ge=0.0, le=1.0)
454
- warmup_steps: int = Field(default=0, ge=0)
455
- max_grad_norm: float = Field(default=1.0)
456
- logging_steps: int = Field(default=10, ge=1)
457
- eval_steps: int = Field(default=500, ge=1)
458
- save_steps: int = Field(default=500, ge=1)
459
- save_total_limit: int = Field(default=3, ge=1, le=10)
460
- gradient_accumulation_steps: int = Field(default=1, ge=1, le=128)
461
- fp16: bool = Field(default=False)
462
- bf16: bool = Field(default=False)
463
- gradient_checkpointing: bool = Field(default=False)
464
- optimizer: str = Field(default="adamw_torch")
465
- lr_scheduler_type: str = Field(default="cosine")
466
- report_to: str = Field(default="none")
467
- seed: int = Field(default=42)
468
-
469
- # Advanced options
470
- eval_strategy: str = Field(default="steps")
471
- load_best_model_at_end: bool = Field(default=True)
472
- metric_for_best_model: str = Field(default="eval_loss")
473
- greater_is_better: bool = Field(default=False)
474
-
475
- # Memory optimization
476
- optim: str = Field(default="adamw_torch")
477
- ddp_find_unused_parameters: bool = Field(default=False)
478
-
479
 
480
 
481
- class PEFTConfig(BaseModel):
482
- """PEFT/LoRA configuration."""
483
  enabled: bool = Field(default=True)
484
- method: str = Field(default="lora", description="lora, adalora, ia3, prefix_tuning, prompt_tuning")
485
- r: int = Field(default=16, ge=1, le=256, description="LoRA rank")
486
- alpha: int = Field(default=32, ge=1, description="LoRA alpha")
487
- dropout: float = Field(default=0.05, ge=0.0, le=0.5)
488
- target_modules: List[str] = Field(default=["q_proj", "v_proj"])
489
- bias: str = Field(default="none")
490
- modules_to_save: List[str] = Field(default_factory=list)
491
-
492
- # AdaLoRA specific
493
- init_r: int = Field(default=12)
494
- t_init: int = Field(default=200)
495
- t_final: int = Field(default=1000)
496
-
497
- # Prefix Tuning specific
498
- num_virtual_tokens: int = Field(default=20)
499
-
500
- # Prompt Tuning specific
501
- num_tokens: int = Field(default=20)
502
- token_init: bool = Field(default=True)
503
 
504
 
505
- class OutputConfig(BaseModel):
506
- """Output configuration."""
507
- push_to_hub: bool = Field(default=False)
508
- hub_model_id: Optional[str] = Field(None)
509
- private: bool = Field(default=False)
510
- save_strategy: str = Field(default="steps")
511
- output_dir: Optional[str] = Field(None)
512
 
513
 
514
- class TrainingRequest(BaseModel):
515
- """Full training request with all configuration options."""
516
- name: str = Field(..., min_length=1, max_length=255)
517
- description: Optional[str] = Field(None)
518
  task_type: str = Field(default="causal-lm")
519
  base_model: str = Field(..., description="HuggingFace model ID")
520
- dataset: DatasetConfig = Field(default_factory=DatasetConfig)
521
- training_args: TrainingArgsConfig = Field(default_factory=TrainingArgsConfig)
522
- peft_config: Optional[PEFTConfig] = Field(None)
523
- output_config: OutputConfig = Field(default_factory=OutputConfig)
524
- tags: List[str] = Field(default_factory=list)
525
- priority: int = Field(default=5, ge=1, le=20)
526
-
527
- @validator('task_type')
528
- def validate_task_type(cls, v):
529
- valid_types = [
530
- "causal-lm", "seq2seq", "token-classification",
531
- "sequence-classification", "question-answering",
532
- "summarization", "translation", "text-classification",
533
- "masked-lm", "vision-classification", "audio-classification",
534
- "reasoning"
535
- ]
536
- if v not in valid_types:
537
- raise ValueError(f"Invalid task_type. Must be one of: {valid_types}")
538
- return v
539
 
540
 
541
  class TrainingJobResponse(BaseModel):
@@ -572,11 +103,11 @@ class DatasetPreviewResponse(BaseModel):
572
  dataset_name: str
573
  config: Optional[str]
574
  splits: List[str]
575
- columns: List[Dict[str, Any]]
576
  sample_data: List[Dict[str, Any]]
577
  total_rows: Optional[int]
578
  detected_task_types: List[str]
579
- suggested_column_mapping: ColumnMappingConfig
580
 
581
 
582
  # Global queue instance
@@ -615,29 +146,28 @@ async def preview_dataset(
615
  # Get column info
616
  columns = []
617
  for col_name, col_type in ds.features.items():
618
- col_info = {
619
- "name": col_name,
620
- "type": str(col_type),
621
- "dtype": type(col_type).__name__,
622
- }
623
- # Detect if it's a label column
624
- if hasattr(col_type, 'names'):
625
- col_info["labels"] = col_type.names
626
- col_info["num_labels"] = len(col_type.names)
627
- columns.append(col_info)
628
 
629
  # Get sample data
630
  sample_data = []
631
  for i in range(min(rows, len(ds))):
632
- sample_data.append(ds[i])
633
 
634
  # Detect task type and suggest column mapping
635
- detected_tasks, suggested_mapping = detect_task_and_mapping(ds, columns)
 
 
 
 
 
 
 
 
636
 
637
  return DatasetPreviewResponse(
638
  dataset_name=dataset_name,
639
  config=config,
640
- splits=list(ds.info.splits.keys()) if hasattr(ds, 'info') and ds.info.splits else [split],
641
  columns=columns,
642
  sample_data=sample_data,
643
  total_rows=len(ds),
@@ -646,151 +176,77 @@ async def preview_dataset(
646
  )
647
 
648
  except Exception as e:
 
649
  raise HTTPException(status_code=400, detail=f"Error loading dataset: {str(e)}")
650
 
651
 
652
- def detect_task_and_mapping(dataset, columns: List[Dict]) -> tuple:
653
  """Detect suitable task types and suggest column mappings."""
654
- col_names = [c["name"].lower() for c in columns]
655
- col_names_original = [c["name"] for c in columns]
656
  detected_tasks = []
657
- mapping = ColumnMappingConfig()
658
 
659
- # Check for text classification
660
- label_cols = [c for c in columns if "label" in c["name"].lower() or c.get("labels")]
661
- text_cols = [c for c in columns if "text" in c["name"].lower()]
662
 
663
- if label_cols and text_cols:
664
- detected_tasks.append("text-classification")
665
- mapping.text_column = text_cols[0]["name"]
666
- mapping.label_column = label_cols[0]["name"]
667
- if label_cols[0].get("labels"):
668
- mapping.label_mapping = {name: i for i, name in enumerate(label_cols[0]["labels"])}
669
 
670
- # Check for QA
671
- question_cols = [c for c in columns if "question" in c["name"].lower()]
672
- answer_cols = [c for c in columns if "answer" in c["name"].lower()]
673
- context_cols = [c for c in columns if "context" in c["name"].lower()]
 
674
 
675
- if question_cols and answer_cols:
 
676
  detected_tasks.append("question-answering")
677
- mapping.question_column = question_cols[0]["name"]
678
- mapping.answer_column = answer_cols[0]["name"]
679
- if context_cols:
680
- mapping.context_column = context_cols[0]["name"]
681
 
682
- # Check for instruction/output
683
- instruction_cols = [c for c in columns if "instruction" in c["name"].lower() or "prompt" in c["name"].lower()]
684
- output_cols = [c for c in columns if "output" in c["name"].lower() or "response" in c["name"].lower()]
685
-
686
- if instruction_cols and output_cols:
687
  detected_tasks.append("causal-lm")
688
- mapping.instruction_column = instruction_cols[0]["name"]
689
- mapping.output_column = output_cols[0]["name"]
690
-
691
- # Check for input/target (seq2seq)
692
- input_cols = [c for c in columns if "input" in c["name"].lower() or "source" in c["name"].lower()]
693
- target_cols = [c for c in columns if "target" in c["name"].lower() or "summary" in c["name"].lower()]
694
-
695
- if input_cols and target_cols:
696
- detected_tasks.append("seq2seq")
697
- mapping.input_column = input_cols[0]["name"]
698
- mapping.output_column = target_cols[0]["name"]
699
-
700
- # Check for NER
701
- tokens_cols = [c for c in columns if "token" in c["name"].lower() or "word" in c["name"].lower()]
702
- tags_cols = [c for c in columns if "tag" in c["name"].lower() or "ner" in c["name"].lower()]
703
 
704
- if tokens_cols and tags_cols:
705
- detected_tasks.append("token-classification")
706
- mapping.tokens_column = tokens_cols[0]["name"]
707
- mapping.tags_column = tags_cols[0]["name"]
 
708
 
709
- # Default to causal LM if we have any text
710
  if not detected_tasks:
711
- text_like = [c for c in columns if c["dtype"] in ["Value", "LargeString"]]
712
- if text_like:
713
- detected_tasks.append("causal-lm")
714
- mapping.text_column = text_like[0]["name"]
 
 
715
 
716
  return detected_tasks, mapping
717
 
718
 
719
- # ============================================
720
- # PROMPT TEMPLATE ENDPOINTS
721
- # ============================================
722
-
723
- @router.get("/prompt-templates")
724
- async def get_prompt_templates():
725
- """Get available prompt template presets."""
726
- return {
727
- "presets": [
728
- {
729
- "id": "none",
730
- "name": "None (Raw Text)",
731
- "description": "Use dataset text directly without formatting"
732
- },
733
- {
734
- "id": "alpaca",
735
- "name": "Alpaca Format",
736
- "description": "Instruction-Input-Output format for instruction tuning"
737
- },
738
- {
739
- "id": "chatml",
740
- "name": "ChatML",
741
- "description": "ChatML format used by various models"
742
- },
743
- {
744
- "id": "llama3",
745
- "name": "Llama 3",
746
- "description": "Llama 3 instruction format"
747
- },
748
- {
749
- "id": "mistral",
750
- "name": "Mistral",
751
- "description": "Mistral/Vicuna instruction format"
752
- },
753
- {
754
- "id": "vicuna",
755
- "name": "Vicuna",
756
- "description": "Vicuna chat format"
757
- },
758
- {
759
- "id": "phi3",
760
- "name": "Phi-3",
761
- "description": "Microsoft Phi-3 format"
762
- },
763
- {
764
- "id": "reasoning",
765
- "name": "Reasoning/CoT",
766
- "description": "Chain-of-thought reasoning format with explicit thinking"
767
- }
768
- ]
769
- }
770
-
771
-
772
- @router.get("/prompt-templates/{template_id}")
773
- async def get_prompt_template(template_id: str):
774
- """Get specific prompt template configuration."""
775
- config = PromptTemplateConfig(instruction_format=template_id)
776
- preset = config.get_template_for_format(template_id)
777
-
778
- if not preset and template_id != "none":
779
- raise HTTPException(status_code=404, detail="Template not found")
780
-
781
- return {
782
- "id": template_id,
783
- "config": preset
784
- }
785
-
786
-
787
  # ============================================
788
  # TRAINING JOB ENDPOINTS
789
  # ============================================
790
 
791
  @router.post("/start", response_model=TrainingJobResponse)
792
  async def start_training(
793
- request: TrainingRequest,
794
  db: AsyncSession = Depends(get_db)
795
  ):
796
  """Start a new training job."""
@@ -801,37 +257,50 @@ async def start_training(
801
  training_job = TrainingJob(
802
  job_id=job_id,
803
  name=request.name,
804
- description=request.description,
805
  task_type=request.task_type,
806
  base_model=request.base_model,
807
- output_model_name=request.output_config.hub_model_id,
808
- dataset_source=request.dataset.source,
809
  dataset_name=request.dataset.name,
810
- dataset_config=request.dataset.config,
811
  dataset_split=request.dataset.train_split,
812
- training_args=request.training_args.dict(),
 
 
 
 
 
813
  peft_config=request.peft_config.dict() if request.peft_config else None,
814
  status=JobStatus.PENDING.value,
815
  total_epochs=request.training_args.epochs,
816
- tags=request.tags
817
  )
818
 
819
  db.add(training_job)
820
  await db.commit()
821
 
822
- # Build full config
823
  config = {
824
  "job_id": job_id,
825
  "task_type": request.task_type,
826
  "base_model": request.base_model,
827
- "dataset": request.dataset.dict(),
828
- "training_args": request.training_args.dict(),
 
 
 
 
 
 
 
 
 
 
 
 
 
829
  "peft_config": request.peft_config.dict() if request.peft_config else None,
830
- "output_config": request.output_config.dict()
831
  }
832
 
833
  # Submit to queue
834
- priority = JobPriority(request.priority)
835
  await queue.submit(config, priority=priority)
836
 
837
  # Update status
@@ -854,9 +323,6 @@ async def get_job_status(
854
  db: AsyncSession = Depends(get_db)
855
  ):
856
  """Get status of a training job."""
857
- queue = get_queue()
858
- queue_status = await queue.get_status(job_id)
859
-
860
  result = await db.execute(
861
  select(TrainingJob).where(TrainingJob.job_id == job_id)
862
  )
@@ -865,20 +331,15 @@ async def get_job_status(
865
  if not job:
866
  raise HTTPException(status_code=404, detail="Job not found")
867
 
868
- if queue_status:
869
- job.status = queue_status.get("status", job.status)
870
- if queue_status.get("progress"):
871
- job.progress = queue_status["progress"]
872
-
873
  return JobStatusResponse(
874
  job_id=job.job_id,
875
  name=job.name,
876
  status=job.status,
877
- progress=job.progress,
878
- current_epoch=job.current_epoch,
879
- total_epochs=job.total_epochs,
880
- current_step=job.current_step,
881
- total_steps=job.total_steps,
882
  train_loss=job.train_loss,
883
  eval_loss=job.eval_loss,
884
  metrics=job.metrics or {},
@@ -891,7 +352,7 @@ async def get_job_status(
891
  )
892
 
893
 
894
- @router.get("/jobs", response_model=List[JobStatusResponse])
895
  async def list_jobs(
896
  status: Optional[str] = None,
897
  limit: int = 50,
@@ -909,28 +370,30 @@ async def list_jobs(
909
  result = await db.execute(query)
910
  jobs = result.scalars().all()
911
 
912
- return [
913
- JobStatusResponse(
914
- job_id=job.job_id,
915
- name=job.name,
916
- status=job.status,
917
- progress=job.progress,
918
- current_epoch=job.current_epoch,
919
- total_epochs=job.total_epochs,
920
- current_step=job.current_step,
921
- total_steps=job.total_steps,
922
- train_loss=job.train_loss,
923
- eval_loss=job.eval_loss,
924
- metrics=job.metrics or {},
925
- error_message=job.error_message,
926
- created_at=job.created_at,
927
- started_at=job.started_at,
928
- completed_at=job.completed_at,
929
- output_path=job.output_path,
930
- hub_model_id=job.hub_model_id
931
- )
932
- for job in jobs
933
- ]
 
 
934
 
935
 
936
  @router.post("/cancel/{job_id}")
@@ -942,9 +405,6 @@ async def cancel_job(
942
  queue = get_queue()
943
  cancelled = await queue.cancel_job(job_id)
944
 
945
- if not cancelled:
946
- raise HTTPException(status_code=400, detail="Cannot cancel job")
947
-
948
  result = await db.execute(
949
  select(TrainingJob).where(TrainingJob.job_id == job_id)
950
  )
@@ -957,12 +417,6 @@ async def cancel_job(
957
  return {"message": f"Job {job_id} cancelled", "success": True}
958
 
959
 
960
- @router.get("/templates")
961
- async def get_training_templates():
962
- """Get available training configuration templates."""
963
- return TRAINING_TEMPLATES
964
-
965
-
966
  @router.get("/queue/status")
967
  async def get_queue_status():
968
  """Get current queue status."""
@@ -974,6 +428,12 @@ async def get_queue_status():
974
  }
975
 
976
 
 
 
 
 
 
 
977
  @router.get("/metrics/{job_id}")
978
  async def get_job_metrics(
979
  job_id: str,
@@ -992,8 +452,7 @@ async def get_job_metrics(
992
  "job_id": job_id,
993
  "train_loss": job.train_loss,
994
  "eval_loss": job.eval_loss,
995
- "metrics": job.metrics,
996
- "learning_rate": job.learning_rate
997
  }
998
 
999
 
@@ -1017,4 +476,25 @@ async def delete_job(
1017
  await db.delete(job)
1018
  await db.commit()
1019
 
1020
- return {"message": f"Job {job_id} deleted", "success": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  # ============================================
26
+ # SIMPLIFIED REQUEST MODELS (matching dashboard form)
27
  # ============================================
28
 
29
class DatasetConfigSimple(BaseModel):
    """Dataset section of the dashboard training form.

    Carries the dataset identifier, the split names, the role-to-column
    mapping, and the maximum sequence length submitted by the form.
    """

    # Required dataset identifier.
    name: str = Field(..., description="HuggingFace dataset name")
    # Split names; the validation split may be None.
    train_split: str = "train"
    validation_split: Optional[str] = "validation"
    # Role -> dataset column name; empty when the form sends no mapping.
    column_mapping: Dict[str, str] = Field(
        default_factory=dict,
        description="Maps roles to column names: {text: 'col1', input: 'col2'}",
    )
    max_length: int = 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
class TrainingArgsSimple(BaseModel):
    """Hyperparameter section of the dashboard training form.

    Every field has a default, so an empty object is a valid request body.
    """

    epochs: int = 3
    batch_size: int = 1
    learning_rate: float = 5e-5
    warmup_steps: int = 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
class PEFTConfigSimple(BaseModel):
    """PEFT section of the dashboard training form.

    Defaults describe an enabled "lora" adapter with r=16, alpha=32 and
    dropout=0.05.
    """

    enabled: bool = True
    method: str = "lora"
    r: int = 16
    alpha: int = 32
    dropout: float = 0.05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
class PromptTemplateSimple(BaseModel):
    """Prompt-template section of the dashboard training form.

    Holds a preset identifier and an optional free-form dict of custom
    template sections.
    """

    preset: str = Field(
        default="none",
        description="Template preset: none, alpaca, chatml, llama3, mistral, vicuna, phi3, reasoning",
    )
    custom: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Custom template sections",
    )
 
 
 
59
 
60
 
61
class TrainingRequestSimple(BaseModel):
    """Top-level request body for starting a training job from the dashboard.

    Only ``base_model`` and ``dataset`` are required; every other field
    falls back to a default.
    """

    name: str = "training-job"
    task_type: str = "causal-lm"
    base_model: str = Field(..., description="HuggingFace model ID")
    dataset: DatasetConfigSimple
    # A fresh TrainingArgsSimple (all defaults) is built when omitted.
    training_args: TrainingArgsSimple = Field(default_factory=TrainingArgsSimple)
    # None when the form sends no PEFT section.
    peft_config: Optional[PEFTConfigSimple] = None
    prompt_template: Optional[PromptTemplateSimple] = Field(
        default=None,
        description="Prompt template configuration",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  class TrainingJobResponse(BaseModel):
 
103
  dataset_name: str
104
  config: Optional[str]
105
  splits: List[str]
106
+ columns: List[str]
107
  sample_data: List[Dict[str, Any]]
108
  total_rows: Optional[int]
109
  detected_task_types: List[str]
110
+ suggested_column_mapping: Dict[str, str]
111
 
112
 
113
  # Global queue instance
 
146
  # Get column info
147
  columns = []
148
  for col_name, col_type in ds.features.items():
149
+ columns.append(col_name)
 
 
 
 
 
 
 
 
 
150
 
151
  # Get sample data
152
  sample_data = []
153
  for i in range(min(rows, len(ds))):
154
+ sample_data.append({k: str(v)[:100] if v else None for k, v in ds[i].items()})
155
 
156
  # Detect task type and suggest column mapping
157
+ detected_tasks, suggested_mapping = detect_task_and_mapping(ds)
158
+
159
+ # Get all splits
160
+ try:
161
+ from datasets import load_dataset_builder
162
+ builder = load_dataset_builder(dataset_name, trust_remote_code=True)
163
+ splits = list(builder.info.splits.keys())
164
+ except:
165
+ splits = [split]
166
 
167
  return DatasetPreviewResponse(
168
  dataset_name=dataset_name,
169
  config=config,
170
+ splits=splits,
171
  columns=columns,
172
  sample_data=sample_data,
173
  total_rows=len(ds),
 
176
  )
177
 
178
  except Exception as e:
179
+ logger.error(f"Error loading dataset: {e}")
180
  raise HTTPException(status_code=400, detail=f"Error loading dataset: {str(e)}")
181
 
182
 
183
def detect_task_and_mapping(dataset) -> tuple:
    """Detect suitable task types and suggest column mappings.

    Column names are matched case-insensitively against a set of well-known
    role names; the suggested mapping always stores the column's original
    spelling.

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer
            indexing that yields a dict-like row (presumably a HuggingFace
            ``datasets.Dataset`` split; verify against the caller).

    Returns:
        A ``(detected_tasks, mapping)`` tuple where ``detected_tasks`` is a
        list of task-type strings and ``mapping`` maps role names (``text``,
        ``label``, ``question``, ``answer``, ``context``, ``instruction``,
        ``input``, ``output``, ``reasoning``) to dataset column names.
    """
    col_names_original = list(dataset.column_names)
    # Lower-cased name -> original spelling; also serves as the membership set.
    col_map = {c.lower(): c for c in col_names_original}
    detected_tasks = []
    mapping = {}

    # Text classification
    if 'label' in col_map and 'text' in col_map:
        detected_tasks.append("text-classification")
        mapping['label'] = col_map['label']
        mapping['text'] = col_map['text']

    # QA: accept either an 'answer' or an 'answers' column ('answer' wins
    # when both exist). Previously the 'answers' fallback was unreachable
    # because detection itself required an 'answer' column.
    answer_key = next((k for k in ('answer', 'answers') if k in col_map), None)
    if 'question' in col_map and answer_key is not None:
        detected_tasks.append("question-answering")
        mapping['question'] = col_map['question']
        mapping['answer'] = col_map[answer_key]
        if 'context' in col_map:
            mapping['context'] = col_map['context']

    # Instruction-output
    if 'instruction' in col_map:
        detected_tasks.append("causal-lm")
        mapping['instruction'] = col_map['instruction']
        for role in ('input', 'output'):
            if role in col_map:
                mapping[role] = col_map[role]

    # Input-output
    if 'input' in col_map and 'output' in col_map:
        if 'causal-lm' not in detected_tasks:
            detected_tasks.append("causal-lm")
        mapping['input'] = col_map['input']
        mapping['output'] = col_map['output']

    # Reasoning: a 'thinking' column also counts, and is mapped to the
    # 'reasoning' role so the detected task never comes without a column.
    reasoning_key = next(
        (k for k in ('reasoning', 'thinking') if k in col_map), None
    )
    if reasoning_key is not None:
        detected_tasks.append("reasoning")
        mapping['reasoning'] = col_map[reasoning_key]

    # Default: plain causal LM on the first column whose first value is a str.
    if not detected_tasks:
        detected_tasks.append("causal-lm")
        # Fetch the first row once instead of once per column.
        first_row = dataset[0] if len(dataset) > 0 else None
        if first_row is not None:
            for col in col_names_original:
                if isinstance(first_row.get(col), str):
                    mapping['text'] = col
                    break

    return detected_tasks, mapping
241
 
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # ============================================
244
  # TRAINING JOB ENDPOINTS
245
  # ============================================
246
 
247
  @router.post("/start", response_model=TrainingJobResponse)
248
  async def start_training(
249
+ request: TrainingRequestSimple,
250
  db: AsyncSession = Depends(get_db)
251
  ):
252
  """Start a new training job."""
 
257
  training_job = TrainingJob(
258
  job_id=job_id,
259
  name=request.name,
 
260
  task_type=request.task_type,
261
  base_model=request.base_model,
 
 
262
  dataset_name=request.dataset.name,
 
263
  dataset_split=request.dataset.train_split,
264
+ training_args={
265
+ "epochs": request.training_args.epochs,
266
+ "batch_size": request.training_args.batch_size,
267
+ "learning_rate": request.training_args.learning_rate,
268
+ "warmup_steps": request.training_args.warmup_steps,
269
+ },
270
  peft_config=request.peft_config.dict() if request.peft_config else None,
271
  status=JobStatus.PENDING.value,
272
  total_epochs=request.training_args.epochs,
 
273
  )
274
 
275
  db.add(training_job)
276
  await db.commit()
277
 
278
+ # Build full config for training service
279
  config = {
280
  "job_id": job_id,
281
  "task_type": request.task_type,
282
  "base_model": request.base_model,
283
+ "model_name": request.base_model,
284
+ "dataset_name": request.dataset.name,
285
+ "dataset": {
286
+ "name": request.dataset.name,
287
+ "train_split": request.dataset.train_split,
288
+ "validation_split": request.dataset.validation_split,
289
+ "column_mapping": request.dataset.column_mapping,
290
+ "max_length": request.dataset.max_length,
291
+ },
292
+ "training_args": {
293
+ "epochs": request.training_args.epochs,
294
+ "batch_size": request.training_args.batch_size,
295
+ "learning_rate": request.training_args.learning_rate,
296
+ "warmup_steps": request.training_args.warmup_steps,
297
+ },
298
  "peft_config": request.peft_config.dict() if request.peft_config else None,
299
+ "prompt_template": request.prompt_template.dict() if request.prompt_template else {"preset": "none"},
300
  }
301
 
302
  # Submit to queue
303
+ priority = JobPriority.NORMAL
304
  await queue.submit(config, priority=priority)
305
 
306
  # Update status
 
323
  db: AsyncSession = Depends(get_db)
324
  ):
325
  """Get status of a training job."""
 
 
 
326
  result = await db.execute(
327
  select(TrainingJob).where(TrainingJob.job_id == job_id)
328
  )
 
331
  if not job:
332
  raise HTTPException(status_code=404, detail="Job not found")
333
 
 
 
 
 
 
334
  return JobStatusResponse(
335
  job_id=job.job_id,
336
  name=job.name,
337
  status=job.status,
338
+ progress=job.progress or 0.0,
339
+ current_epoch=job.current_epoch or 0,
340
+ total_epochs=job.total_epochs or 0,
341
+ current_step=job.current_step or 0,
342
+ total_steps=job.total_steps or 0,
343
  train_loss=job.train_loss,
344
  eval_loss=job.eval_loss,
345
  metrics=job.metrics or {},
 
352
  )
353
 
354
 
355
+ @router.get("/jobs")
356
  async def list_jobs(
357
  status: Optional[str] = None,
358
  limit: int = 50,
 
370
  result = await db.execute(query)
371
  jobs = result.scalars().all()
372
 
373
+ return {
374
+ "jobs": [
375
+ {
376
+ "job_id": job.job_id,
377
+ "name": job.name,
378
+ "status": job.status,
379
+ "progress": job.progress or 0.0,
380
+ "current_epoch": job.current_epoch or 0,
381
+ "total_epochs": job.total_epochs or 0,
382
+ "current_step": job.current_step or 0,
383
+ "total_steps": job.total_steps or 0,
384
+ "train_loss": job.train_loss,
385
+ "eval_loss": job.eval_loss,
386
+ "metrics": job.metrics or {},
387
+ "error_message": job.error_message,
388
+ "created_at": job.created_at.isoformat() if job.created_at else None,
389
+ "started_at": job.started_at.isoformat() if job.started_at else None,
390
+ "completed_at": job.completed_at.isoformat() if job.completed_at else None,
391
+ "model_name": job.base_model,
392
+ "dataset_name": job.dataset_name,
393
+ }
394
+ for job in jobs
395
+ ]
396
+ }
397
 
398
 
399
  @router.post("/cancel/{job_id}")
 
405
  queue = get_queue()
406
  cancelled = await queue.cancel_job(job_id)
407
 
 
 
 
408
  result = await db.execute(
409
  select(TrainingJob).where(TrainingJob.job_id == job_id)
410
  )
 
417
  return {"message": f"Job {job_id} cancelled", "success": True}
418
 
419
 
 
 
 
 
 
 
420
  @router.get("/queue/status")
421
  async def get_queue_status():
422
  """Get current queue status."""
 
428
  }
429
 
430
 
431
@router.get("/templates")
async def get_training_templates():
    """Return the module-level ``TRAINING_TEMPLATES`` object unchanged."""
    templates = TRAINING_TEMPLATES
    return templates
435
+
436
+
437
  @router.get("/metrics/{job_id}")
438
  async def get_job_metrics(
439
  job_id: str,
 
452
  "job_id": job_id,
453
  "train_loss": job.train_loss,
454
  "eval_loss": job.eval_loss,
455
+ "metrics": job.metrics
 
456
  }
457
 
458
 
 
476
  await db.delete(job)
477
  await db.commit()
478
 
479
+ return {"message": f"Job {job_id} deleted", "success": True}
480
+
481
+
482
+ # ============================================
483
+ # PROMPT TEMPLATE ENDPOINTS
484
+ # ============================================
485
+
486
@router.get("/prompt-templates")
async def get_prompt_templates():
    """List the prompt template presets offered to the dashboard.

    Returns:
        A dict with a single ``"presets"`` key holding one
        ``{"id", "name", "description"}`` entry per preset, in display order.
    """
    # (id, display name, short description) for each preset.
    preset_rows = (
        ("none", "None (Raw Text)", "Use dataset text directly"),
        ("alpaca", "Alpaca Format", "Instruction-Input-Output"),
        ("chatml", "ChatML", "ChatML format"),
        ("llama3", "Llama 3", "Llama 3 instruction format"),
        ("mistral", "Mistral", "Mistral instruction format"),
        ("vicuna", "Vicuna", "Vicuna chat format"),
        ("phi3", "Phi-3", "Microsoft Phi-3 format"),
        ("reasoning", "Reasoning/CoT", "Chain-of-thought"),
    )
    presets = [
        {"id": preset_id, "name": label, "description": blurb}
        for preset_id, label, blurb in preset_rows
    ]
    return {"presets": presets}