# Spaces: Paused / Paused — hosting-platform status banner captured during export; not part of the source.
| import os | |
| import logging | |
| from typing import List, Dict, Optional | |
| from pathlib import Path | |
| import json | |
| from datetime import datetime | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from label_studio_ml.model import LabelStudioMLBase, ModelResponse | |
| from peft import get_peft_model, LoraConfig, PeftModel | |
| import time | |
| logger = logging.getLogger(__name__) | |
| class T5Model(LabelStudioMLBase): | |
| # Class-level configuration | |
| model_name = os.getenv('MODEL_NAME', 'google/flan-t5-base') | |
| max_length = int(os.getenv('MAX_LENGTH', '512')) | |
| generation_max_length = int(os.getenv('GENERATION_MAX_LENGTH', '128')) | |
| num_return_sequences = int(os.getenv('NUM_RETURN_SEQUENCES', '1')) | |
| # Model components (initialized as None) | |
| tokenizer = None | |
| model = None | |
| device = None # Will be set during setup | |
| def setup(self): | |
| """Initialize the T5 model and parse configuration""" | |
| try: | |
| # Parse label config first | |
| text_config, choices_config = self.parse_config(self.label_config) | |
| self.from_name = choices_config.get('name') | |
| self.to_name = text_config.get('name') | |
| # Load tokenizer and model | |
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
| self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) | |
| # Set device after model loading | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if self.device == "cuda": | |
| self.model = self.model.cuda() | |
| # After initializing the base model, try to load the latest fine-tuned version | |
| latest_model_path = self.get_latest_model_path() | |
| if latest_model_path is not None: | |
| try: | |
| logger.info(f"Loading latest model from {latest_model_path}") | |
| self.model = PeftModel.from_pretrained(self.model, latest_model_path) | |
| logger.info("Successfully loaded latest model") | |
| except Exception as e: | |
| logger.error(f"Failed to load latest model: {str(e)}") | |
| # Continue with base model if loading fails | |
| self.model.eval() | |
| logger.info(f"Using device: {self.device}") | |
| logger.info(f"Initialized with from_name={self.from_name}, to_name={self.to_name}") | |
| # Set initial model version | |
| self.set("model_version", "1.0.0") | |
| except Exception as e: | |
| logger.error(f"Error in model setup: {str(e)}") | |
| raise | |
| def parse_config(self, label_config): | |
| """Parse the label config to find nested elements""" | |
| import xml.etree.ElementTree as ET | |
| root = ET.fromstring(label_config) | |
| # Find Text and Choices tags anywhere in the tree | |
| text_tag = root.find('.//Text') | |
| choices_tag = root.find('.//Choices') | |
| text_config = text_tag.attrib if text_tag is not None else {} | |
| choices_config = choices_tag.attrib if choices_tag is not None else {} | |
| return text_config, choices_config | |
| def get_valid_choices(self, label_config): | |
| """Extract valid choice values from label config""" | |
| import xml.etree.ElementTree as ET | |
| root = ET.fromstring(label_config) | |
| choices = root.findall('.//Choice') | |
| return [choice.get('value') for choice in choices] | |
| def get_categories_with_hints(self, label_config): | |
| """Extract categories and their hints from label config""" | |
| import xml.etree.ElementTree as ET | |
| root = ET.fromstring(label_config) | |
| choices = root.findall('.//Choice') | |
| categories = [] | |
| for choice in choices: | |
| categories.append({ | |
| 'value': choice.get('value'), | |
| 'hint': choice.get('hint') | |
| }) | |
| return categories | |
| def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse: | |
| """Generate predictions using T5 model""" | |
| logger.info("Received prediction request") | |
| logger.info(f"Tasks: {json.dumps(tasks, indent=2)}") | |
| predictions = [] | |
| # Get categories with their descriptions | |
| try: | |
| categories = self.get_categories_with_hints(self.label_config) | |
| valid_choices = [cat['value'] for cat in categories] | |
| category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories] | |
| logger.info(f"Valid choices: {valid_choices}") | |
| except Exception as e: | |
| logger.error(f"Error parsing choices: {str(e)}") | |
| # TODO: remove this from all places once we have a valid choices | |
| valid_choices = ["other"] | |
| category_descriptions = ["other: Default category when no others apply"] | |
| try: | |
| for task in tasks: | |
| input_text = task['data'].get(self.to_name) | |
| if not input_text: | |
| logger.warning(f"No input text found using {self.to_name}") | |
| continue | |
| # Format prompt with input text and category descriptions | |
| prompt = f"""Classify the following text into exactly one category. | |
| Available categories with descriptions: | |
| {chr(10).join(f"- {desc}" for desc in category_descriptions)} | |
| Text to classify: {input_text} | |
| Instructions: | |
| 1. Consider the text carefully | |
| 2. Choose the most appropriate category from the list | |
| 3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.) | |
| 4. Do not add any explanations or additional text | |
| Category:""" | |
| logger.info(f"Generated prompt: {prompt}") | |
| # Generate prediction with prompt | |
| inputs = self.tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| max_length=self.max_length, | |
| truncation=True, | |
| padding=True | |
| ).to(self.device) | |
| logger.info("Generating prediction...") | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_length=self.generation_max_length, | |
| num_return_sequences=self.num_return_sequences, | |
| do_sample=True, | |
| temperature=0.7 | |
| ) | |
| predicted_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| logger.info(f"Generated prediction: {predicted_text}") | |
| # Find best matching choice | |
| best_choice = "other" # default fallback | |
| if predicted_text in valid_choices: | |
| best_choice = predicted_text | |
| # Format prediction with valid choice | |
| prediction = { | |
| "result": [{ | |
| "from_name": self.from_name, | |
| "to_name": self.to_name, | |
| "type": "choices", | |
| "value": { | |
| "choices": [best_choice] | |
| } | |
| }], | |
| "model_version": "1.0.0" | |
| } | |
| logger.info(f"Formatted prediction: {json.dumps(prediction, indent=2)}") | |
| predictions.append(prediction) | |
| except Exception as e: | |
| logger.error(f"Error in prediction: {str(e)}", exc_info=True) | |
| raise | |
| logger.info(f"Returning {len(predictions)} predictions") | |
| return predictions | |
| def fit(self, event, data, **kwargs): | |
| """Handle annotation events from Label Studio""" | |
| start_time = time.time() | |
| logger.info("Starting training session...") | |
| valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'} | |
| if event not in valid_events: | |
| logger.warning(f"Skip training: event {event} is not supported") | |
| return | |
| try: | |
| # Extract text and label | |
| # LS sends two webhooks when training is initiated: | |
| # 1. contains all project data | |
| # 2. contains only the task data | |
| # We need to check which one is present and use the appropriate data | |
| if 'task' in data: | |
| text = data['task']['data']['text'] | |
| label = data['annotation']['result'][0]['value']['choices'][0] | |
| else: | |
| logger.info("Skipping initial project setup webhook") | |
| return | |
| # Configure LoRA | |
| lora_config = LoraConfig( | |
| r=int(os.getenv('LORA_R', '4')), | |
| lora_alpha=int(os.getenv('LORA_ALPHA', '8')), | |
| target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','), | |
| lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')), | |
| bias="none", | |
| task_type="SEQ_2_SEQ_LM" | |
| ) | |
| logger.info("Preparing model for training...") | |
| model = get_peft_model(self.model, lora_config) | |
| model.print_trainable_parameters() | |
| # Tokenize inputs first | |
| inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device) | |
| labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device) | |
| # Training loop | |
| logger.info("Starting training loop...") | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-5'))) | |
| num_epochs = int(os.getenv('NUM_EPOCHS', '6')) | |
| # Add LoRA settings logging here | |
| logger.info("Current LoRA Configuration:") | |
| logger.info(f" - Rank (r): {lora_config.r}") | |
| logger.info(f" - Alpha: {lora_config.lora_alpha}") | |
| logger.info(f" - Target Modules: {lora_config.target_modules}") | |
| logger.info(f" - Dropout: {lora_config.lora_dropout}") | |
| logger.info(f" - Learning Rate: {float(os.getenv('LEARNING_RATE', '1e-4'))}") | |
| logger.info(f" - Number of Epochs: {num_epochs}") | |
| logger.info(f" - Input text length: {len(inputs['input_ids'][0])} tokens") | |
| logger.info(f" - Label length: {len(labels['input_ids'][0])} tokens") | |
| for epoch in range(num_epochs): | |
| logger.info(f"Starting epoch {epoch+1}/{num_epochs}") | |
| model.train() | |
| optimizer.zero_grad() | |
| outputs = model(**inputs, labels=labels["input_ids"]) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}") | |
| # Save the model | |
| try: | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| model_dir = Path(os.getenv('MODEL_DIR', '/data/models')) | |
| model_dir.mkdir(parents=True, exist_ok=True) | |
| save_path = model_dir / f"model_{timestamp}" | |
| logger.info(f"Saving model to {save_path}") | |
| # Save the full model state | |
| model.save_pretrained( | |
| save_path, | |
| save_function=torch.save, | |
| safe_serialization=True, | |
| save_state_dict=True | |
| ) | |
| logger.info(f"Model successfully saved to {save_path}") | |
| except Exception as e: | |
| logger.error(f"Failed to save model: {str(e)}") | |
| raise | |
| # Save the tokenizer | |
| try: | |
| logger.info(f"Saving tokenizer to {save_path}") | |
| self.tokenizer.save_pretrained(save_path) | |
| logger.info("Tokenizer successfully saved") | |
| except Exception as e: | |
| logger.error(f"Failed to save tokenizer: {str(e)}") | |
| raise | |
| # Switch to eval mode | |
| model.eval() | |
| training_time = time.time() - start_time | |
| logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'") | |
| except Exception as e: | |
| training_time = time.time() - start_time | |
| logger.error(f"Training failed after {training_time:.2f} seconds") | |
| logger.error(f"Error during training: {str(e)}") | |
| raise | |
| def get_latest_model_path(self) -> Path: | |
| """Get the path to the most recently saved model""" | |
| model_dir = Path(os.getenv('MODEL_DIR', '/data/models')) | |
| if not model_dir.exists(): | |
| logger.warning(f"Model directory {model_dir} does not exist") | |
| return None | |
| # Find all model directories (they start with 'model_') | |
| model_paths = list(model_dir.glob("model_*")) | |
| if not model_paths: | |
| logger.warning("No saved models found") | |
| return None | |
| # Sort by creation time and get the most recent | |
| latest_model = max(model_paths, key=lambda x: x.stat().st_mtime) | |
| logger.info(f"Found latest model: {latest_model}") | |
| return latest_model | |