# Author: rahulhans
# Commit message: "Upload 15 files"
# Commit: a26e606 (verified)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import textwrap
class CodeGenerator:
    """Generate Python code from a natural-language prompt with a causal LM.

    Loads the ``microsoft/CodeGPT-small-py-adaptedGPT2`` checkpoint through
    ``transformers`` and post-processes raw completions into tidier source.
    """

    def __init__(self):
        """Load the tokenizer and model, on CUDA when available, else CPU."""
        print("Initializing Code Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load model and tokenizer
        self.model_name = "microsoft/CodeGPT-small-py-adaptedGPT2"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # float16 weights only on CUDA; CPU keeps float32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    @staticmethod
    def _build_prompt(prompt):
        """Wrap ``prompt`` in the comment header used to steer the model.

        Every line of a multi-line prompt is prefixed with ``# `` so it cannot
        be mistaken for code the model should continue.

        Args:
            prompt (str): The task description.

        Returns:
            str: The full prompt text, ending with ``"# Solution:\\n"``.
        """
        task = "\n# ".join(str(prompt).splitlines() or [""])
        return f"# Python\n# Task: {task}\n# Solution:\n"

    def generate_code(self, prompt, max_length=150, temperature=0.7, top_p=0.95):
        """Generate code based on the given prompt.

        Args:
            prompt (str): The prompt describing the code to generate.
            max_length (int): Maximum number of new tokens to generate (the
                prompt is not counted). Clamped to the model's context window
                when the model config exposes one.
            temperature (float): Sampling temperature; values <= 0 switch to
                greedy decoding instead of failing.
            top_p (float): Nucleus-sampling mass; only used when sampling.

        Returns:
            str: The formatted generated code, or a message starting with
            ``"Error generating code: "`` if anything went wrong.
        """
        try:
            print(f"Generating code on {self.device}...")
            formatted_prompt = self._build_prompt(prompt)
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)
            prompt_len = inputs["input_ids"].shape[1]

            # Never request more tokens than fit in the model's context window.
            max_new_tokens = max_length
            context_size = getattr(self.model.config, "max_position_embeddings", None)
            if context_size:
                max_new_tokens = min(max_new_tokens, max(1, context_size - prompt_len))

            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "num_return_sequences": 1,
                "pad_token_id": self.tokenizer.eos_token_id,
                "repetition_penalty": 1.1,
                "no_repeat_ngram_size": 3,
            }
            if temperature > 0:
                generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
            else:
                # Sampling requires a strictly positive temperature; decode greedily.
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated tokens. Slicing the decoded text by
            # len(formatted_prompt) breaks when the prompt was truncated or does
            # not decode back to exactly the same characters.
            new_tokens = outputs[0][prompt_len:]
            generated_code = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return self._format_code(generated_code)
        except Exception as e:
            # Best-effort API: failures are reported to the caller as text.
            return f"Error generating code: {str(e)}"

    def _format_code(self, code):
        """Format the generated code for better readability.

        Steps: expand tabs and remove indentation shared by every line; drop
        lines that exactly repeat the line before them; indent a line that
        follows a block opener but was emitted too shallow; add placeholder
        docstrings when the snippet has none; collapse runs of blank lines;
        drop repeated top-level def/class blocks.

        Args:
            code (str): The code to format.

        Returns:
            str: Formatted code.
        """
        code = textwrap.dedent(code.expandtabs(4)).strip()

        # Remove only *consecutive* repeats (indentation included). The same
        # statement legitimately recurs in different places (``return None``,
        # ``else:``, closing brackets), so global de-duplication corrupts code.
        lines = []
        for raw_line in code.split("\n"):
            line = raw_line.rstrip()
            if line and lines and line == lines[-1]:
                continue
            lines.append(line)

        formatted_code = "\n".join(self._fix_indentation(lines))

        # Add placeholder docstrings only when the snippet has none at all.
        if "def " in formatted_code and '"""' not in formatted_code and "'''" not in formatted_code:
            formatted_code = self._add_docstrings(formatted_code)

        # Collapse runs of blank lines to a single blank line.
        formatted_code = re.sub(r"\n{3,}", "\n\n", formatted_code)

        # Remove any duplicate code blocks.
        return self._remove_duplicate_blocks(formatted_code)

    @staticmethod
    def _fix_indentation(lines):
        """Indent a line that follows a block opener but is not deeper than it.

        Every line keeps its own indentation, except that a line after a
        non-comment line ending in ``:`` is moved to one level (4 spaces)
        deeper than that opener when it is not already deeper. This is the
        only position where shallower indentation is always a syntax error.

        Args:
            lines (list[str]): Source lines without trailing whitespace.

        Returns:
            list[str]: The adjusted lines; blank lines become ``""``.
        """
        fixed = []
        opener_indent = None  # indentation of the last non-blank line if it ended with ':'
        for line in lines:
            body = line.lstrip()
            if not body:
                fixed.append("")
                continue
            indent = len(line) - len(body)
            if opener_indent is not None and indent <= opener_indent:
                indent = opener_indent + 4
            fixed.append(" " * indent + body)
            is_opener = body.endswith(":") and not body.startswith("#")
            opener_indent = indent if is_opener else None
        return fixed

    def _remove_duplicate_blocks(self, code):
        """Remove top-level def/class blocks that repeat an earlier one.

        A block starts at a column-0 ``def``/``async def``/``class`` line or at
        the decorator(s) directly above it, and runs to the next such line.
        Blocks are compared with whitespace runs collapsed. Only top-level
        blocks are compared, so an identical method in two different classes
        is never removed.

        Args:
            code (str): The code to clean.

        Returns:
            str: Code with duplicates removed, stripped of outer whitespace.
        """
        block_starts = ("def ", "async def ", "class ", "@")
        blocks = []
        current = []
        for line in code.split("\n"):
            # Start a new block unless the previous line is a decorator that
            # belongs to this definition.
            if line.startswith(block_starts) and current and not current[-1].startswith("@"):
                blocks.append(current)
                current = []
            current.append(line)
        blocks.append(current)

        unique_blocks = []
        seen_blocks = set()
        for block in blocks:
            # Normalize block by collapsing all whitespace.
            normalized = " ".join(" ".join(block).split())
            if normalized and normalized not in seen_blocks:
                seen_blocks.add(normalized)
                unique_blocks.append("\n".join(block))
        return "\n".join(unique_blocks).strip()

    def _add_docstrings(self, code):
        """Add a placeholder docstring under each ``def`` header that lacks one.

        Only headers that fit on one line and end with ``:`` are handled.
        Inserting text after the first line of a multi-line signature, or
        after a one-liner such as ``def f(): return 1``, would break the
        syntax. The docstring is indented like the first body line when that
        line is deeper than the header, otherwise 4 spaces deeper than it.

        Args:
            code (str): The code to add docstrings to.

        Returns:
            str: Code with docstrings.
        """
        quotes = ('"""', "'''", 'r"""', "r'''")
        lines = code.split("\n")
        result = []
        for index, line in enumerate(lines):
            result.append(line)
            header = line.strip()
            if not (header.startswith(("def ", "async def ")) and header.endswith(":")):
                continue
            def_indent = len(line) - len(line.lstrip())
            body_indent = def_indent + 4
            following = next((candidate for candidate in lines[index + 1:] if candidate.strip()), None)
            if following is not None:
                if following.strip().startswith(quotes):
                    continue  # already documented
                following_indent = len(following) - len(following.lstrip())
                # Align with a deeper body so the insert can't cause an IndentationError.
                if following_indent > def_indent:
                    body_indent = following_indent
            pad = " " * body_indent
            result.extend((f'{pad}"""', f"{pad}Docstring", f'{pad}"""'))
        return "\n".join(result)