# model_wrapper.py - Enhanced version with better debugging
import logging
import sys
import traceback

import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig,
    pipeline,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CodeDebuggerWrapper:
def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
self.model_name = model_name
self.model = None
self.tokenizer = None
self.model_type = None
self.pipeline = None
self._ensure_model()
    def _log_system_info(self):
        """Log system information for debugging."""
        logger.info(f"Python version: {sys.version}")
        logger.info(f"PyTorch version: {torch.__version__}")
        try:
            import transformers
            logger.info(f"Transformers version: {transformers.__version__}")
        except Exception:
            logger.warning("Could not get transformers version")
def _ensure_model(self):
"""Load model and tokenizer with comprehensive fallback strategies."""
logger.info(f"Starting model loading process for {self.model_name}")
self._log_system_info()
try:
# First, let's inspect the model configuration
logger.info("Step 1: Inspecting model configuration...")
            config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=True)
logger.info(f"Model architecture: {config.architectures}")
logger.info(f"Model type: {config.model_type}")
logger.info(f"Config class: {type(config).__name__}")
# Load tokenizer
logger.info("Step 2: Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
use_fast=False # Sometimes fast tokenizers cause issues
)
# Add special tokens if missing
if self.tokenizer.pad_token is None:
if self.tokenizer.eos_token is not None:
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
logger.info("βœ… Tokenizer loaded successfully")
logger.info(f"Vocab size: {len(self.tokenizer)}")
logger.info(f"Special tokens: pad={self.tokenizer.pad_token}, eos={self.tokenizer.eos_token}")
# Try loading with pipeline first (often more robust)
logger.info("Step 3: Attempting pipeline loading...")
pipeline_strategies = [
("text2text-generation", lambda: pipeline(
"text2text-generation",
model=self.model_name,
tokenizer=self.tokenizer,
trust_remote_code=True,
device=-1 # CPU
)),
("text-generation", lambda: pipeline(
"text-generation",
model=self.model_name,
tokenizer=self.tokenizer,
trust_remote_code=True,
device=-1
)),
]
for pipe_type, pipe_func in pipeline_strategies:
try:
logger.info(f"Trying {pipe_type} pipeline...")
self.pipeline = pipe_func()
logger.info(f"βœ… Successfully loaded {pipe_type} pipeline")
self.model_type = f"{pipe_type}_pipeline"
return # Success!
except Exception as e:
logger.warning(f"❌ {pipe_type} pipeline failed: {str(e)[:200]}...")
# If pipeline fails, try direct model loading
logger.info("Step 4: Attempting direct model loading...")
loading_strategies = [
# Strategy 1: Based on config type, try the most appropriate loader
("Config-based AutoModel", lambda: self._load_based_on_config(config)),
# Strategy 2: Force different model types with trust_remote_code
("AutoModel + trust_remote_code", lambda: AutoModel.from_pretrained(
self.model_name,
trust_remote_code=True,
torch_dtype=torch.float32,
device_map="cpu"
)),
("AutoModelForCausalLM + trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
self.model_name,
trust_remote_code=True,
torch_dtype=torch.float32,
device_map="cpu"
)),
("AutoModelForSeq2SeqLM + trust_remote_code + ignore_mismatched", lambda: AutoModelForSeq2SeqLM.from_pretrained(
self.model_name,
trust_remote_code=True,
torch_dtype=torch.float32,
ignore_mismatched_sizes=True,
device_map="cpu"
)),
# Strategy 3: Try without trust_remote_code but with other options
("AutoModel + low_cpu_mem", lambda: AutoModel.from_pretrained(
self.model_name,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
device_map="cpu"
)),
]
for strategy_name, strategy_func in loading_strategies:
try:
logger.info(f"Trying: {strategy_name}")
self.model = strategy_func()
self.model_type = type(self.model).__name__
logger.info(f"βœ… Successfully loaded model with {strategy_name}")
logger.info(f"Model type: {self.model_type}")
# Set to eval mode
if hasattr(self.model, 'eval'):
self.model.eval()
return # Success!
except Exception as e:
logger.warning(f"❌ {strategy_name} failed: {str(e)[:200]}...")
logger.debug(f"Full error: {traceback.format_exc()}")
# If we get here, all strategies failed
raise RuntimeError("❌ All model loading strategies failed")
except Exception as e:
logger.error(f"❌ Critical error in model loading: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
def _load_based_on_config(self, config):
"""Try to load model based on its configuration type."""
config_type = type(config).__name__
if "T5" in config_type or "Seq2Seq" in config_type:
return AutoModelForSeq2SeqLM.from_pretrained(
self.model_name,
trust_remote_code=True,
config=config
)
elif "GPT" in config_type or "Causal" in config_type:
return AutoModelForCausalLM.from_pretrained(
self.model_name,
trust_remote_code=True,
config=config
)
else:
return AutoModel.from_pretrained(
self.model_name,
trust_remote_code=True,
config=config
)
def debug(self, code: str) -> str:
"""Debug the provided code using the loaded model."""
if not code or not code.strip():
return "❌ Please provide some code to debug."
try:
# Use pipeline if available (more robust)
if self.pipeline is not None:
return self._debug_with_pipeline(code)
# Use direct model if pipeline not available
if self.model is not None:
return self._debug_with_model(code)
# Fallback: provide manual debugging suggestions
return self._manual_debug_suggestions(code)
except Exception as e:
logger.error(f"Error during debugging: {e}")
return f"❌ Error during debugging: {str(e)}\n\n" + self._manual_debug_suggestions(code)
    def _debug_with_pipeline(self, code: str) -> str:
        """Debug using the loaded pipeline."""
        try:
            prompt = f"Fix this Python code:\n\n{code}\n\nFixed code:"
            if "text2text" in self.model_type:
                result = self.pipeline(prompt, max_length=512, num_beams=3, early_stopping=True)
                return result[0]['generated_text'] if result else self._manual_debug_suggestions(code)
            elif "text-generation" in self.model_type:
                result = self.pipeline(prompt, max_new_tokens=256, num_return_sequences=1, do_sample=True, temperature=0.7)
                generated = result[0]['generated_text'] if result else ""
                # Clean up the response by stripping the echoed prompt
                if prompt in generated:
                    generated = generated.replace(prompt, "").strip()
                return generated if generated else self._manual_debug_suggestions(code)
            # Unknown pipeline type: fall back to rule-based suggestions
            return self._manual_debug_suggestions(code)
        except Exception as e:
            logger.error(f"Pipeline debugging failed: {e}")
            return self._manual_debug_suggestions(code)
def _debug_with_model(self, code: str) -> str:
"""Debug using direct model."""
try:
prompt = f"Debug and fix this Python code:\n\n{code}\n\nFixed code:"
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True
)
with torch.no_grad():
if hasattr(self.model, 'generate'):
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
num_beams=3,
early_stopping=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=getattr(self.tokenizer, 'eos_token_id', None),
do_sample=True,
temperature=0.7
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Clean response
if prompt in response:
response = response.replace(prompt, "").strip()
return response if response else self._manual_debug_suggestions(code)
else:
return f"⚠️ Model type '{self.model_type}' doesn't support generation.\n\n" + self._manual_debug_suggestions(code)
except Exception as e:
logger.error(f"Direct model debugging failed: {e}")
return self._manual_debug_suggestions(code)
def _manual_debug_suggestions(self, code: str) -> str:
"""Provide manual debugging suggestions when AI model fails."""
suggestions = []
# Check for common Python syntax errors
lines = code.split('\n')
for i, line in enumerate(lines, 1):
line_stripped = line.strip()
if not line_stripped or line_stripped.startswith('#'):
continue
# Check for missing colons
if any(keyword in line_stripped for keyword in ['if ', 'for ', 'while ', 'def ', 'class ', 'try:', 'except', 'else', 'elif']):
if not line_stripped.endswith(':') and not line_stripped.endswith(':\\'):
suggestions.append(f"Line {i}: Missing colon (:) at end of statement")
# Check for obvious indentation issues
if i > 1 and line_stripped and not line.startswith(' ') and not line.startswith('\t'):
prev_line = lines[i-2].strip() if i > 1 else ""
if prev_line.endswith(':'):
suggestions.append(f"Line {i}: Possible indentation error - code after ':' should be indented")
# Check for common runtime errors
        if 'len(' in code and '[]' in code:
            suggestions.append("⚠️ Potential empty-list issue: check that lists are non-empty before indexing or dividing by len()")
        if '/0' in code or '/ 0' in code:
            suggestions.append("⚠️ Division by zero detected")
# Create response
result = f"πŸ”§ **Manual Debug Analysis for:**\n```python\n{code}\n```\n\n"
if suggestions:
result += "**Issues Found:**\n"
for suggestion in suggestions:
result += f"β€’ {suggestion}\n"
else:
result += "**No obvious syntax errors detected.**\n"
result += "\n**General Tips:**\n"
result += "β€’ Check for missing colons (:) after if/for/def statements\n"
result += "β€’ Verify proper indentation (4 spaces per level)\n"
result += "β€’ Ensure all parentheses, brackets, and quotes are balanced\n"
result += "β€’ Check for typos in variable and function names\n"
result += "β€’ Make sure all required imports are included\n"
return result
# Alternative lightweight debugger if the main one fails completely
class FallbackDebugger:
def __init__(self):
self.model = None
self.tokenizer = None
logger.info("Using fallback debugger - AI model unavailable")
def debug(self, code: str) -> str:
"""Simple rule-based debugging."""
if not code or not code.strip():
return "❌ Please provide some code to debug."
issues = []
lines = code.split('\n')
# Basic syntax checking
for i, line in enumerate(lines, 1):
stripped = line.strip()
if not stripped or stripped.startswith('#'):
continue
# Missing colons
control_words = ['if ', 'elif ', 'else', 'for ', 'while ', 'def ', 'class ', 'try', 'except', 'finally']
if any(word in stripped for word in control_words):
if not stripped.endswith(':'):
issues.append(f"Line {i}: Missing colon (:)")
# Indentation after colon
if i < len(lines) and stripped.endswith(':'):
next_line = lines[i] if i < len(lines) else ""
if next_line.strip() and not next_line.startswith((' ', '\t')):
issues.append(f"Line {i+1}: Should be indented after ':'")
# Generate response
result = f"πŸ”§ **Code Analysis** (AI Model Unavailable)\n\n"
result += f"```python\n{code}\n```\n\n"
if issues:
result += "**Potential Issues:**\n"
for issue in issues:
result += f"β€’ {issue}\n"
else:
result += "**No obvious syntax errors found.**\n"
result += "\n**Common Debugging Steps:**\n"
result += "1. Run the code to see specific error messages\n"
result += "2. Check syntax: colons, indentation, parentheses\n"
result += "3. Verify variable names and imports\n"
result += "4. Use print() statements to debug logic\n"
return result
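

# Example usage (illustrative sketch, not part of the original module): shows how a caller
# might construct CodeDebuggerWrapper and drop back to FallbackDebugger when model loading
# fails. The snippet passed to debug() below is a made-up buggy input for demonstration.
if __name__ == "__main__":
    try:
        debugger = CodeDebuggerWrapper()
    except Exception as e:
        logger.warning(f"Falling back to rule-based debugger: {e}")
        debugger = FallbackDebugger()

    sample_code = "def add(a, b)\n    return a + b"  # hypothetical input with a missing colon
    print(debugger.debug(sample_code))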