Spaces:
Running
Running
import transformers | |
from transformers import pipeline | |
class CodeGenerator:
    """Generate source code from a natural-language idea with a Hugging Face model.

    Works with both decoder-only (causal) and encoder-decoder (seq2seq)
    checkpoints; the matching ``AutoModel`` class is chosen from the model's
    config at construction time.
    """

    def __init__(self, model_name="bigscience/T0_3B"):
        """
        Initializes the CodeGenerator with a specified model.

        Args:
            model_name (str): Hub id or local path of the model used for code
                generation.

        Raises:
            Propagates whatever ``transformers`` raises when the config, model
            or tokenizer cannot be loaded.
        """
        config = transformers.AutoConfig.from_pretrained(model_name)
        # The default T0_3B is a T5-style encoder-decoder; AutoModelForCausalLM
        # rejects such configs, so choose the auto class from the config.
        if getattr(config, "is_encoder_decoder", False):
            model_cls = transformers.AutoModelForSeq2SeqLM
        else:
            model_cls = transformers.AutoModelForCausalLM
        self.model = model_cls.from_pretrained(model_name, config=config)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    def generate_code(self, idea):
        """
        Generates code based on a given idea using the loaded model.

        Args:
            idea (str): The idea for the code to be generated.

        Returns:
            str: The generated code, stripped of surrounding whitespace.
        """
        input_text = self._format_input(idea)
        # Calling the tokenizer (instead of ``encode``) also yields an
        # attention mask, so ``generate`` does not have to infer one.
        encoded = self.tokenizer(input_text, return_tensors="pt")
        output_sequences = self._generate_output(
            encoded["input_ids"], attention_mask=encoded.get("attention_mask")
        )
        return self._extract_code(output_sequences)

    def _format_input(self, idea):
        """
        Formats the prompt text for the model.

        Args:
            idea (str): The idea for the code to be generated.

        Returns:
            str: Prompt of the form ``"# Idea: <idea>\\n# Code:\\n"``.
        """
        return f"# Idea: {idea}\n# Code:\n"

    def _generate_output(self, input_ids, attention_mask=None, **generate_kwargs):
        """
        Generates output sequences from the model.

        Args:
            input_ids (tensor): Token ids of the prompt.
            attention_mask (tensor, optional): Mask matching ``input_ids``;
                omitted from the ``generate`` call when None.
            **generate_kwargs: Overrides for the default generation settings
                below (e.g. ``max_length=256``).

        Returns:
            Whatever ``self.model.generate`` returns (the generated sequences).
        """
        settings = {
            "max_length": 1024,
            "num_return_sequences": 1,
            # temperature/top_k only apply when sampling; without this flag
            # generate() decodes greedily and ignores both.
            "do_sample": True,
            "temperature": 0.7,
            "top_k": 50,
        }
        # Deliberately not set:
        #   - early_stopping: only meaningful for beam search (num_beams > 1).
        #   - no_repeat_ngram_size: code legitimately repeats token bigrams
        #     (e.g. "(n" in a recursive call), so banning repeats corrupts it.
        settings.update(generate_kwargs)
        if attention_mask is not None:
            settings["attention_mask"] = attention_mask
        return self.model.generate(input_ids=input_ids, **settings)

    def _extract_code(self, output_sequences):
        """
        Extracts the generated code from the output sequences.

        Args:
            output_sequences (tensor): Generated sequences; only the first one
                is decoded.

        Returns:
            str: Everything after the first ``"\\n# Code:"`` marker when the
            decoded text contains it (causal models echo the prompt), otherwise
            the whole decoded text (encoder-decoder models return only new
            tokens). Surrounding whitespace is stripped.
        """
        generated_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        # partition never raises, unlike split(...)[1], when the marker is missing.
        _, marker, code = generated_text.partition("\n# Code:")
        return (code if marker else generated_text).strip()
# Example usage: turn a one-line idea into code and print the result.
if __name__ == "__main__":
    generator = CodeGenerator()
    factorial_idea = "Write a Python function to calculate the factorial of a number"
    print(generator.generate_code(factorial_idea))