# Spaces: Running  (page-header residue from the scraped listing; not part of the program)
import re
import textwrap

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class CodeGenerator:
    """Generate Python code from a task description with a causal language model.

    The model and tokenizer are loaded once in ``__init__``; ``generate_code``
    samples a completion and passes it through a small text clean-up pipeline
    (``_format_code``) before returning it.
    """

    # Zero-width boundary just before the newline that precedes a line opening
    # a function or class.  ``[ \t]*`` (rather than ``\s*``) keeps the match on
    # a single line so runs of blank lines are not split into fragments.
    _BLOCK_BOUNDARY = re.compile(r"(?=\n[ \t]*(?:async[ \t]+def|def|class)\s)")

    # Number of spaces emitted per nesting level by the formatter.
    _INDENT_WIDTH = 4

    def __init__(self):
        """Load the tokenizer and model and move the model to the chosen device."""
        print("Initializing Code Generator...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load model and tokenizer
        self.model_name = "microsoft/CodeGPT-small-py-adaptedGPT2"
        print(f"Loading model {self.model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # Half precision only on CUDA; float32 is used on CPU.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        print(f"Model loaded and moved to {self.device}")

    def generate_code(self, prompt, max_length=150, temperature=0.7, top_p=0.95):
        """Generate code for the given task description.

        Args:
            prompt (str): The prompt describing the code to generate.
            max_length (int): Maximum number of tokens to generate on top of
                the (possibly truncated) prompt.
            temperature (float): Sampling temperature.
            top_p (float): Nucleus-sampling probability mass.

        Returns:
            str: The formatted generated code, or a message starting with
            ``"Error generating code: "`` if anything in the pipeline raised.
        """
        try:
            print(f"Generating code on {self.device}...")

            # Format prompt for better code generation
            formatted_prompt = f"# Python\n# Task: {prompt}\n# Solution:\n"
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)
            prompt_token_count = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length + prompt_token_count,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    repetition_penalty=1.1,
                    no_repeat_ngram_size=3,
                )

            # Drop the prompt by token count, not by character count: the
            # prompt may have been truncated to 512 tokens and decoding is not
            # guaranteed to reproduce it byte-for-byte, so slicing the decoded
            # text by len(formatted_prompt) could cut into the generated code.
            completion_ids = outputs[0][prompt_token_count:]
            generated_code = self.tokenizer.decode(
                completion_ids, skip_special_tokens=True
            )
            return self._format_code(generated_code)
        except Exception as e:
            # Boundary handler: callers receive the failure as a string.
            return f"Error generating code: {e}"

    def _format_code(self, code):
        """Clean up generated code for readability.

        Steps: strip trailing whitespace per line, drop immediately repeated
        lines, normalise indentation to multiples of four spaces while keeping
        the nesting structure, add placeholder docstrings when the code has
        functions but no triple-double-quoted string at all, remove repeated
        function/class blocks, and collapse runs of blank lines to one.

        Args:
            code (str): The code to format.

        Returns:
            str: Formatted code with no leading or trailing blank lines.
        """
        lines = [line.rstrip() for line in code.split("\n")]
        lines = self._drop_consecutive_duplicates(lines)
        formatted_code = "\n".join(self._normalize_indentation(lines))

        # Add docstrings if missing
        if "def " in formatted_code and '"""' not in formatted_code:
            formatted_code = self._add_docstrings(formatted_code)

        # Remove any duplicate code blocks, then tidy the blank lines that the
        # removal may have left next to each other.
        formatted_code = self._remove_duplicate_blocks(formatted_code)
        return re.sub(r"\n{3,}", "\n\n", formatted_code).strip()

    @staticmethod
    def _drop_consecutive_duplicates(lines):
        """Return ``lines`` without non-blank lines that repeat the previous line.

        Only exact, adjacent repeats are removed, so a statement that
        legitimately occurs more than once elsewhere in the code is kept.
        """
        kept = []
        for line in lines:
            if line.strip() and kept and kept[-1] == line:
                continue
            kept.append(line)
        return kept

    @classmethod
    def _normalize_indentation(cls, lines):
        """Re-indent ``lines`` to ``_INDENT_WIDTH`` spaces per nesting level.

        A stack of the indentation widths seen so far determines each line's
        depth: a deeper line opens a level, a shallower line closes levels
        until its width fits.  Blank lines are emitted as empty strings.
        """
        widths = []
        result = []
        for line in lines:
            body = line.lstrip()
            if not body:
                result.append("")
                continue

            # Expand tabs in the leading whitespace only, so tab characters
            # inside string literals are left untouched.
            width = len(line[: len(line) - len(body)].expandtabs())

            while len(widths) > 1 and width < widths[-1]:
                widths.pop()
            if not widths or width > widths[-1]:
                widths.append(width)
            elif width < widths[-1]:
                # Shallower than the outermost level seen so far: this width
                # becomes the new baseline.
                widths[-1] = width

            depth = len(widths) - 1
            result.append(" " * (cls._INDENT_WIDTH * depth) + body)
        return result

    def _remove_duplicate_blocks(self, code):
        """Remove repeated function/class blocks, keeping the first occurrence.

        The code is split before every line that opens a ``def``,
        ``async def`` or ``class``; two pieces are duplicates when they are
        equal after collapsing all whitespace runs to a single space.  The
        comparison ignores nesting, so identical methods in different classes
        are also treated as duplicates.

        Args:
            code (str): The code to clean.

        Returns:
            str: Code with duplicate blocks removed, stripped of surrounding
            whitespace.
        """
        unique_blocks = []
        seen_blocks = set()
        for block in self._BLOCK_BOUNDARY.split(code):
            # Normalize block by collapsing whitespace
            normalized = re.sub(r"\s+", " ", block.strip())
            if normalized and normalized not in seen_blocks:
                seen_blocks.add(normalized)
                unique_blocks.append(block)
        return "".join(unique_blocks).strip()

    def _add_docstrings(self, code):
        """Insert a placeholder docstring after function headers that lack one.

        A docstring is inserted only after a ``def``/``async def`` line that
        ends with a colon, so multi-line signatures and one-line functions are
        left alone.  The placeholder is indented one level deeper than the
        header.  A header on the last line also receives one.

        Args:
            code (str): The code to add docstrings to.

        Returns:
            str: Code with docstrings.
        """
        lines = code.split("\n")
        formatted_lines = []
        for index, line in enumerate(lines):
            formatted_lines.append(line)

            header = line.strip()
            if not header.startswith(("def ", "async def ")):
                continue
            if not header.endswith(":"):
                continue

            following = lines[index + 1] if index + 1 < len(lines) else ""
            if '"""' in following or "'''" in following:
                continue

            header_indent = len(line) - len(line.lstrip())
            pad = " " * (header_indent + self._INDENT_WIDTH)
            formatted_lines.append(f'{pad}"""\n{pad}Docstring\n{pad}"""')
        return "\n".join(formatted_lines)