| | """ |
| | Helion-OSC Inference Script |
| | DeepXR/Helion-OSC - Mathematical Coding Language Model |
| | |
| | This module provides comprehensive inference capabilities for the Helion-OSC model, |
| | including specialized methods for different programming and mathematical tasks. |
| | """ |
| |
|

import logging
from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class GenerationParameters:
    """Parameters for text generation.

    Note: max_length counts prompt plus generated tokens, matching the
    semantics of transformers' `generate(max_length=...)`.
    """
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    length_penalty: float = 1.0
    do_sample: bool = True
    num_return_sequences: int = 1
    early_stopping: bool = False
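
# A minimal sketch of overriding the defaults for deterministic, short-form
# decoding, given a HelionOSCInference instance `helion` (defined below).
# The values are illustrative, not tuned for Helion-OSC:
#
#     params = GenerationParameters(max_length=512, temperature=0.2, do_sample=False)
#     helion.generate("Simplify (x^2 - 1)/(x - 1).", custom_params=params)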


class CodeStoppingCriteria(StoppingCriteria):
    """Stop generation once any stop sequence appears in the generated text.

    Only tokens produced after the prompt are checked, so stop sequences
    that happen to occur in the prompt itself do not end generation early.
    """

    def __init__(self, stop_sequences: List[str], tokenizer, prompt_length: int = 0):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Note: only the first sequence in the batch is inspected.
        decoded = self.tokenizer.decode(
            input_ids[0, self.prompt_length:], skip_special_tokens=True
        )
        return any(seq in decoded for seq in self.stop_sequences)
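
# A sketch of the wiring HelionOSCInference.generate performs internally when
# stop_sequences is supplied (prompt_length keeps prompt text from triggering stops):
#
#     criteria = StoppingCriteriaList([
#         CodeStoppingCriteria(["```"], tokenizer, prompt_length=inputs["input_ids"].shape[1])
#     ])
#     model.generate(**inputs, stopping_criteria=criteria)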


class HelionOSCInference:
    """
    Inference wrapper for the Helion-OSC model.

    Supports multiple generation modes:
    - Code generation
    - Mathematical reasoning
    - Algorithm design
    - Code debugging
    - Code completion
    """

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: Optional[str] = None,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        use_flash_attention: bool = True,
        trust_remote_code: bool = True
    ):
        """
        Initialize the Helion-OSC model.

        Args:
            model_name: HuggingFace model identifier
            device: Device to load the model on (cuda/cpu/mps); auto-detected if None
            load_in_8bit: Load the model in 8-bit precision (requires bitsandbytes)
            load_in_4bit: Load the model in 4-bit precision (requires bitsandbytes)
            use_flash_attention: Use FlashAttention-2 for faster inference (CUDA only)
            trust_remote_code: Trust remote code from the model repository
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit

        logger.info(f"Initializing Helion-OSC on {self.device}...")

        self.tokenizer = self._load_tokenizer(trust_remote_code)
        self.model = self._load_model(
            use_flash_attention=use_flash_attention,
            trust_remote_code=trust_remote_code
        )
        self.generation_configs = self._load_generation_configs()

        logger.info("Model loaded successfully!")
        self._print_model_info()

    def _get_device(self, device: Optional[str]) -> str:
        """Determine the best available device."""
        if device:
            return device
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _load_tokenizer(self, trust_remote_code: bool):
        """Load and configure the tokenizer."""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=trust_remote_code,
            padding_side="left"  # left padding so batched continuations align
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def _load_model(self, use_flash_attention: bool, trust_remote_code: bool):
        """Load and configure the model, with optional bitsandbytes quantization."""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True
        }

        if self.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            model_kwargs["device_map"] = "auto"
            logger.info("Loading in 8-bit precision")
        elif self.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model_kwargs["device_map"] = "auto"
            logger.info("Loading in 4-bit precision")
        elif self.device == "cuda":
            model_kwargs["torch_dtype"] = torch.bfloat16
            model_kwargs["device_map"] = "auto"
        else:
            model_kwargs["torch_dtype"] = torch.float32

        if use_flash_attention and self.device == "cuda":
            # FlashAttention-2 needs a compatible GPU and the flash-attn package.
            model_kwargs["attn_implementation"] = "flash_attention_2"

        try:
            model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)
        except Exception:
            # Fall back to the default attention implementation if flash-attn
            # is missing or unsupported on this hardware.
            if model_kwargs.pop("attn_implementation", None) is None:
                raise
            model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)

        # device_map="auto" (and quantized loading) place weights themselves;
        # only plain CPU/MPS loads need an explicit move.
        if "device_map" not in model_kwargs:
            model = model.to(self.device)

        model.eval()
        return model

    def _load_generation_configs(self) -> Dict[str, GenerationParameters]:
        """Load task-specific generation configurations."""
        return {
            "code_generation": GenerationParameters(
                max_length=4096,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.05,
                do_sample=True
            ),
            "mathematical_reasoning": GenerationParameters(
                max_length=2048,
                temperature=0.3,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.0,
                do_sample=False
            ),
            "code_completion": GenerationParameters(
                max_length=1024,
                temperature=0.6,
                top_p=0.92,
                top_k=45,
                repetition_penalty=1.03,
                do_sample=True
            ),
            "algorithm_design": GenerationParameters(
                max_length=3072,
                temperature=0.5,
                top_p=0.93,
                top_k=50,
                repetition_penalty=1.08,
                do_sample=True
            ),
            "debugging": GenerationParameters(
                max_length=2048,
                temperature=0.4,
                top_p=0.88,
                repetition_penalty=1.0,
                do_sample=False
            )
        }
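
    # Presets above are plain dataclass instances, so callers can register
    # their own at runtime (a sketch; "proof_writing" is an illustrative name):
    #
    #     helion.generation_configs["proof_writing"] = GenerationParameters(
    #         max_length=3072, temperature=0.2, do_sample=False
    #     )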

    def _print_model_info(self):
        """Log basic model information."""
        try:
            num_params = sum(p.numel() for p in self.model.parameters())
            logger.info(f"Model parameters: {num_params:,}")
            logger.info(f"Model dtype: {next(self.model.parameters()).dtype}")
            logger.info(f"Device: {self.device}")
        except Exception as e:
            logger.warning(f"Could not get model info: {e}")

    def generate(
        self,
        prompt: Union[str, List[str]],
        task_type: str = "code_generation",
        custom_params: Optional[GenerationParameters] = None,
        stop_sequences: Optional[List[str]] = None,
        return_full_text: bool = False,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text from a prompt or list of prompts.

        Args:
            prompt: Input prompt or list of prompts
            task_type: Type of task (code_generation, mathematical_reasoning, etc.)
            custom_params: Custom generation parameters (overrides the task preset)
            stop_sequences: Sequences that terminate generation when produced
            return_full_text: Whether to return the full text including the prompt
            **kwargs: Overrides for individual GenerationParameters fields

        Returns:
            Generated text, or a list of texts for batched prompts or
            num_return_sequences > 1
        """
        if custom_params:
            params = custom_params
        elif task_type in self.generation_configs:
            params = self.generation_configs[task_type]
        else:
            logger.warning(f"Unknown task type '{task_type}', using default parameters")
            params = GenerationParameters()

        # Work on a copy so per-call overrides do not mutate the shared presets.
        params = replace(params)
        for key, value in kwargs.items():
            if hasattr(params, key):
                setattr(params, key, value)

        is_batch = isinstance(prompt, list)
        max_input_length = getattr(self.model.config, "max_position_embeddings", 4096)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_input_length
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[1]

        stopping_criteria = None
        if stop_sequences:
            stopping_criteria = StoppingCriteriaList([
                CodeStoppingCriteria(stop_sequences, self.tokenizer, prompt_length)
            ])

        generation_kwargs = {
            "max_length": params.max_length,
            "repetition_penalty": params.repetition_penalty,
            "length_penalty": params.length_penalty,
            "do_sample": params.do_sample,
            "num_return_sequences": params.num_return_sequences,
            "early_stopping": params.early_stopping,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "stopping_criteria": stopping_criteria,
        }
        # Sampling knobs only apply when do_sample=True; passing them alongside
        # greedy decoding triggers warnings in recent transformers versions.
        if params.do_sample:
            generation_kwargs.update(
                temperature=params.temperature,
                top_p=params.top_p,
                top_k=params.top_k,
            )

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # With left padding, every continuation starts at prompt_length, so
        # slicing token ids strips the prompt reliably, batched or not.
        generated_texts = []
        for output in outputs:
            if return_full_text:
                text = self.tokenizer.decode(output, skip_special_tokens=True)
            else:
                text = self.tokenizer.decode(
                    output[prompt_length:], skip_special_tokens=True
                ).strip()
            generated_texts.append(text)

        if is_batch or params.num_return_sequences > 1:
            return generated_texts
        return generated_texts[0]
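
    # Example call (a sketch): a one-off stop sequence plus a temperature
    # override; any kwarg matching a GenerationParameters field is applied
    # on top of the task preset for this call only.
    #
    #     text = helion.generate(
    #         "Implement quicksort in Python.",
    #         task_type="code_generation",
    #         stop_sequences=["\nif __name__"],
    #         temperature=0.5,
    #     )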

    def code_generation(
        self,
        prompt: str,
        language: Optional[str] = None,
        max_length: int = 4096,
        **kwargs
    ) -> str:
        """
        Generate code for a given prompt.

        Args:
            prompt: Code generation prompt
            language: Programming language (optional)
            max_length: Maximum total length (prompt plus generated code)
            **kwargs: Additional generation parameters

        Returns:
            Generated code
        """
        if language:
            prompt = f"Language: {language}\n{prompt}"

        return self.generate(
            prompt,
            task_type="code_generation",
            max_length=max_length,
            **kwargs
        )

    def mathematical_reasoning(
        self,
        prompt: str,
        max_length: int = 2048,
        **kwargs
    ) -> str:
        """
        Solve mathematical problems with step-by-step reasoning.

        Args:
            prompt: Mathematical problem
            max_length: Maximum total length (prompt plus solution)
            **kwargs: Additional generation parameters

        Returns:
            Mathematical solution with reasoning
        """
        return self.generate(
            prompt,
            task_type="mathematical_reasoning",
            max_length=max_length,
            **kwargs
        )

    def algorithm_design(
        self,
        prompt: str,
        include_complexity: bool = True,
        max_length: int = 3072,
        **kwargs
    ) -> str:
        """
        Design algorithms, optionally with complexity analysis.

        Args:
            prompt: Algorithm design prompt
            include_complexity: Whether to request complexity analysis
            max_length: Maximum total length of the output
            **kwargs: Additional generation parameters

        Returns:
            Algorithm design with analysis
        """
        if include_complexity:
            prompt += "\n\nPlease include time and space complexity analysis."

        return self.generate(
            prompt,
            task_type="algorithm_design",
            max_length=max_length,
            **kwargs
        )

    def debug_code(
        self,
        code: str,
        error_message: Optional[str] = None,
        max_length: int = 2048,
        **kwargs
    ) -> str:
        """
        Debug code and suggest fixes.

        Args:
            code: Code to debug
            error_message: Optional error message observed when running the code
            max_length: Maximum total length of the output
            **kwargs: Additional generation parameters

        Returns:
            Debugging analysis and fixed code
        """
        prompt = f"Debug the following code:\n\n```\n{code}\n```"
        if error_message:
            prompt += f"\n\nError message: {error_message}"
        prompt += "\n\nProvide a detailed explanation and fixed code."

        return self.generate(
            prompt,
            task_type="debugging",
            max_length=max_length,
            **kwargs
        )

    def complete_code(
        self,
        code_context: str,
        max_length: int = 1024,
        **kwargs
    ) -> str:
        """
        Complete partial code.

        Args:
            code_context: Partial code to complete
            max_length: Maximum total length (context plus completion)
            **kwargs: Additional generation parameters

        Returns:
            Code completion
        """
        return self.generate(
            code_context,
            task_type="code_completion",
            max_length=max_length,
            stop_sequences=["\n\n", "```", "###"],
            **kwargs
        )

    def batch_generate(
        self,
        prompts: List[str],
        task_type: str = "code_generation",
        batch_size: int = 4,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in batches.

        Args:
            prompts: List of prompts
            task_type: Type of task
            batch_size: Number of prompts processed per forward pass
            **kwargs: Additional generation parameters

        Returns:
            List of generated responses, in the same order as the prompts
        """
        results = []
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            batch_results = self.generate(batch, task_type=task_type, **kwargs)
            if isinstance(batch_results, str):
                batch_results = [batch_results]
            results.extend(batch_results)
        return results

def main():
    """Example usage and demonstrations."""
    print("=" * 80)
    print("Helion-OSC Inference Examples")
    print("=" * 80)

    helion = HelionOSCInference(
        load_in_8bit=False,
        load_in_4bit=False
    )

    print("\n" + "=" * 80)
    print("Example 1: Code Generation")
    print("=" * 80)
    code_prompt = """Write a Python function to implement a binary search tree with the following methods:
- insert(value): Insert a new value
- search(value): Search for a value
- delete(value): Delete a value
- inorder_traversal(): Return inorder traversal

Include proper documentation and type hints."""

    print(f"\nPrompt:\n{code_prompt}")
    print("\nGenerating...")
    result = helion.code_generation(code_prompt, language="python")
    print(f"\nGenerated Code:\n{result}")

    print("\n" + "=" * 80)
    print("Example 2: Mathematical Reasoning")
    print("=" * 80)
    math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction."""

    print(f"\nPrompt:\n{math_prompt}")
    print("\nGenerating...")
    result = helion.mathematical_reasoning(math_prompt)
    print(f"\nSolution:\n{result}")

    print("\n" + "=" * 80)
    print("Example 3: Algorithm Design")
    print("=" * 80)
    algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string."""

    print(f"\nPrompt:\n{algo_prompt}")
    print("\nGenerating...")
    result = helion.algorithm_design(algo_prompt, include_complexity=True)
    print(f"\nAlgorithm:\n{result}")

    print("\n" + "=" * 80)
    print("Example 4: Code Debugging")
    print("=" * 80)
    buggy_code = """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

# This is too slow for large n
result = fibonacci(100)
"""

    print(f"\nBuggy Code:\n{buggy_code}")
    print("\nGenerating debugging analysis...")
    result = helion.debug_code(buggy_code, error_message="Takes too long to compute")
    print(f"\nDebug Analysis:\n{result}")

    print("\n" + "=" * 80)
    print("Example 5: Batch Code Generation")
    print("=" * 80)
    batch_prompts = [
        "Write a Python function to reverse a linked list",
        "Write a JavaScript function to debounce API calls",
        "Write a Rust function to parse JSON safely"
    ]

    print("\nProcessing batch prompts...")
    results = helion.batch_generate(batch_prompts, batch_size=2)
    for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1):
        print(f"\nPrompt {i}: {prompt}")
        print(f"Result {i}:\n{result}\n")

    print("=" * 80)
    print("Examples completed!")
    print("=" * 80)


if __name__ == "__main__":
    main()