# model_wrapper.py - Enhanced version with better debugging
import logging
import os
import sys
import threading
import traceback

import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig,
    pipeline,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CodeDebuggerWrapper:
    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.model_type = None
        self.pipeline = None
        self._ensure_model()

    def _log_system_info(self):
        """Log system information for debugging."""
        logger.info(f"Python version: {sys.version}")
        logger.info(f"PyTorch version: {torch.__version__}")
        try:
            import transformers
            logger.info(f"Transformers version: {transformers.__version__}")
        except Exception:
            logger.warning("Could not get transformers version")

    def _ensure_model(self):
        """Load model and tokenizer with comprehensive fallback strategies."""
        logger.info(f"Starting model loading process for {self.model_name}")
        self._log_system_info()
        try:
            # First, let's inspect the model configuration
            logger.info("Step 1: Inspecting model configuration...")
            config = AutoConfig.from_pretrained(self.model_name)
            logger.info(f"Model architecture: {config.architectures}")
            logger.info(f"Model type: {config.model_type}")
            logger.info(f"Config class: {type(config).__name__}")

            # Load tokenizer
            logger.info("Step 2: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=False  # Sometimes fast tokenizers cause issues
            )

            # Add special tokens if missing
            if self.tokenizer.pad_token is None:
                if self.tokenizer.eos_token is not None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                else:
                    self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

            logger.info("✅ Tokenizer loaded successfully")
            logger.info(f"Vocab size: {len(self.tokenizer)}")
            logger.info(f"Special tokens: pad={self.tokenizer.pad_token}, eos={self.tokenizer.eos_token}")

            # Try loading with pipeline first (often more robust)
            logger.info("Step 3: Attempting pipeline loading...")
            pipeline_strategies = [
                ("text2text-generation", lambda: pipeline(
                    "text2text-generation",
                    model=self.model_name,
                    tokenizer=self.tokenizer,
                    trust_remote_code=True,
                    device=-1  # CPU
                )),
                ("text-generation", lambda: pipeline(
                    "text-generation",
                    model=self.model_name,
                    tokenizer=self.tokenizer,
                    trust_remote_code=True,
                    device=-1
                )),
            ]

            for pipe_type, pipe_func in pipeline_strategies:
                try:
                    logger.info(f"Trying {pipe_type} pipeline...")
                    self.pipeline = pipe_func()
                    logger.info(f"✅ Successfully loaded {pipe_type} pipeline")
                    self.model_type = f"{pipe_type}_pipeline"
                    return  # Success!
                except Exception as e:
                    logger.warning(f"❌ {pipe_type} pipeline failed: {str(e)[:200]}...")

            # If pipeline fails, try direct model loading
            logger.info("Step 4: Attempting direct model loading...")
            loading_strategies = [
                # Strategy 1: Based on config type, try the most appropriate loader
                ("Config-based AutoModel", lambda: self._load_based_on_config(config)),
                # Strategy 2: Force different model types with trust_remote_code
                ("AutoModel + trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    device_map="cpu"
                )),
                ("AutoModelForCausalLM + trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    device_map="cpu"
                )),
                ("AutoModelForSeq2SeqLM + trust_remote_code + ignore_mismatched", lambda: AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    ignore_mismatched_sizes=True,
                    device_map="cpu"
                )),
                # Strategy 3: Try without trust_remote_code but with other options
                ("AutoModel + low_cpu_mem", lambda: AutoModel.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                    device_map="cpu"
                )),
            ]

            for strategy_name, strategy_func in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy_name}")
                    self.model = strategy_func()
                    self.model_type = type(self.model).__name__
                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
                    logger.info(f"Model type: {self.model_type}")
                    # Set to eval mode
                    if hasattr(self.model, 'eval'):
                        self.model.eval()
                    return  # Success!
                except Exception as e:
                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:200]}...")
                    logger.debug(f"Full error: {traceback.format_exc()}")

            # If we get here, all strategies failed
            raise RuntimeError("❌ All model loading strategies failed")

        except Exception as e:
            logger.error(f"❌ Critical error in model loading: {e}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise

    def _load_based_on_config(self, config):
        """Try to load model based on its configuration type."""
        config_type = type(config).__name__
        if "T5" in config_type or "Seq2Seq" in config_type:
            return AutoModelForSeq2SeqLM.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                config=config
            )
        elif "GPT" in config_type or "Causal" in config_type:
            return AutoModelForCausalLM.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                config=config
            )
        else:
            return AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                config=config
            )

    def debug(self, code: str) -> str:
        """Debug the provided code using the loaded model."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."
        try:
            # Use pipeline if available (more robust)
            if self.pipeline is not None:
                return self._debug_with_pipeline(code)
            # Use direct model if pipeline not available
            if self.model is not None:
                return self._debug_with_model(code)
            # Fallback: provide manual debugging suggestions
            return self._manual_debug_suggestions(code)
        except Exception as e:
            logger.error(f"Error during debugging: {e}")
            return f"❌ Error during debugging: {str(e)}\n\n" + self._manual_debug_suggestions(code)

    def _debug_with_pipeline(self, code: str) -> str:
        """Debug using pipeline."""
        try:
            prompt = f"Fix this Python code:\n\n{code}\n\nFixed code:"
            if "text2text" in self.model_type:
                result = self.pipeline(prompt, max_length=512, num_beams=3, early_stopping=True)
                return result[0]['generated_text'] if result else self._manual_debug_suggestions(code)
            elif "text-generation" in self.model_type:
                result = self.pipeline(prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.7)
                generated = result[0]['generated_text'] if result else ""
                # Clean up the response
                if prompt in generated:
                    generated = generated.replace(prompt, "").strip()
                return generated if generated else self._manual_debug_suggestions(code)
            # Unknown pipeline type: fall back to rule-based suggestions
            return self._manual_debug_suggestions(code)
        except Exception as e:
            logger.error(f"Pipeline debugging failed: {e}")
            return self._manual_debug_suggestions(code)

    def _debug_with_model(self, code: str) -> str:
        """Debug using direct model."""
        try:
            prompt = f"Debug and fix this Python code:\n\n{code}\n\nFixed code:"
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            with torch.no_grad():
                if hasattr(self.model, 'generate'):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        num_beams=3,
                        early_stopping=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=getattr(self.tokenizer, 'eos_token_id', None),
                        do_sample=True,
                        temperature=0.7
                    )
                    response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                    # Clean response
                    if prompt in response:
                        response = response.replace(prompt, "").strip()
                    return response if response else self._manual_debug_suggestions(code)
                else:
                    return (
                        f"⚠️ Model type '{self.model_type}' doesn't support generation.\n\n"
                        + self._manual_debug_suggestions(code)
                    )
        except Exception as e:
            logger.error(f"Direct model debugging failed: {e}")
            return self._manual_debug_suggestions(code)

    def _manual_debug_suggestions(self, code: str) -> str:
        """Provide manual debugging suggestions when AI model fails."""
        suggestions = []
        # Check for common Python syntax errors
        lines = code.split('\n')
        for i, line in enumerate(lines, 1):
            line_stripped = line.strip()
            if not line_stripped or line_stripped.startswith('#'):
                continue
            # Check for missing colons
            if any(keyword in line_stripped for keyword in ['if ', 'for ', 'while ', 'def ', 'class ', 'try:', 'except', 'else', 'elif']):
                if not line_stripped.endswith(':') and not line_stripped.endswith(':\\'):
                    suggestions.append(f"Line {i}: Missing colon (:) at end of statement")
            # Check for obvious indentation issues
            if i > 1 and line_stripped and not line.startswith(' ') and not line.startswith('\t'):
                prev_line = lines[i-2].strip() if i > 1 else ""
                if prev_line.endswith(':'):
                    suggestions.append(f"Line {i}: Possible indentation error - code after ':' should be indented")

        # Check for common runtime errors
        if 'len(' in code and '[]' in code:
            suggestions.append("⚠️ Potential empty-list issue: check that lists are non-empty before indexing or dividing by len()")
        if '/0' in code or '/ 0' in code:
            suggestions.append("⚠️ Division by zero detected")

        # Create response
        result = f"🔧 **Manual Debug Analysis for:**\n```python\n{code}\n```\n\n"
        if suggestions:
            result += "**Issues Found:**\n"
            for suggestion in suggestions:
                result += f"• {suggestion}\n"
        else:
            result += "**No obvious syntax errors detected.**\n"
        result += "\n**General Tips:**\n"
        result += "• Check for missing colons (:) after if/for/def statements\n"
        result += "• Verify proper indentation (4 spaces per level)\n"
        result += "• Ensure all parentheses, brackets, and quotes are balanced\n"
        result += "• Check for typos in variable and function names\n"
        result += "• Make sure all required imports are included\n"
        return result


# Alternative lightweight debugger if the main one fails completely
class FallbackDebugger:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        logger.info("Using fallback debugger - AI model unavailable")

    def debug(self, code: str) -> str:
        """Simple rule-based debugging."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."
        issues = []
        lines = code.split('\n')
        # Basic syntax checking
        for i, line in enumerate(lines, 1):
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            # Missing colons
            control_words = ['if ', 'elif ', 'else', 'for ', 'while ', 'def ', 'class ', 'try', 'except', 'finally']
            if any(word in stripped for word in control_words):
                if not stripped.endswith(':'):
                    issues.append(f"Line {i}: Missing colon (:)")
            # Indentation after colon
            if i < len(lines) and stripped.endswith(':'):
                next_line = lines[i] if i < len(lines) else ""
                if next_line.strip() and not next_line.startswith((' ', '\t')):
                    issues.append(f"Line {i+1}: Should be indented after ':'")

        # Generate response
        result = "🔧 **Code Analysis** (AI Model Unavailable)\n\n"
        result += f"```python\n{code}\n```\n\n"
        if issues:
            result += "**Potential Issues:**\n"
            for issue in issues:
                result += f"• {issue}\n"
        else:
            result += "**No obvious syntax errors found.**\n"
        result += "\n**Common Debugging Steps:**\n"
        result += "1. Run the code to see specific error messages\n"
        result += "2. Check syntax: colons, indentation, parentheses\n"
        result += "3. Verify variable names and imports\n"
        result += "4. Use print() statements to debug logic\n"
        return result
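

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how these classes could be wired together: try the
# AI-backed CodeDebuggerWrapper first and fall back to the rule-based
# FallbackDebugger if model loading fails. The sample snippet below is a
# hypothetical, deliberately broken piece of code used only for demonstration.
if __name__ == "__main__":
    try:
        debugger = CodeDebuggerWrapper()
    except Exception:
        debugger = FallbackDebugger()

    sample_code = "def add(a, b)\nreturn a + b"  # missing colon and indentation
    print(debugger.debug(sample_code))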