File size: 2,563 Bytes
5f2fe16
8e98aef
4ef486c
8e98aef
5f2fe16
4e5f040
a36c7d4
 
8e98aef
5f2fe16
 
4ef486c
 
 
 
 
 
 
 
 
 
 
 
 
e9cd936
a36c7d4
 
4ef486c
 
 
 
 
 
a36c7d4
 
4ef486c
 
 
 
8e98aef
a36c7d4
 
4ef486c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load pretrained model and tokenizer once at import time.
# NOTE(review): torch is imported but not referenced directly in this file;
# presumably kept so transformers uses the torch backend — confirm.
# NOTE(review): DistilRoBERTa is an encoder-only architecture; loading it via
# AutoModelForSeq2SeqLM looks suspect — verify this checkpoint actually ships
# a seq2seq head, otherwise from_pretrained will raise at startup.
model_name = "zonghaoyang/DistilRoBERTa-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define function to analyze input code
def analyze_code(input_code):
    """Split an input-code string into variables, functions, and logic.

    The text is whitespace-normalized, cut into "."-separated segments,
    and each non-empty segment is classified: a segment containing "="
    contributes the text left of the first "=" as a variable name, a
    segment containing "(" contributes the text left of the first "(" as
    a function name, and anything else is kept verbatim as logic.

    Returns a dict with keys "variables", "functions", "logic".
    """
    normalized = " ".join(input_code.split())
    segments = [seg.strip() for seg in normalized.split(".") if seg.strip()]
    analysis = {"variables": [], "functions": [], "logic": []}
    for segment in segments:
        if "=" in segment:
            analysis["variables"].append(segment.split("=")[0].strip())
        elif "(" in segment:
            analysis["functions"].append(segment.split("(")[0].strip())
        else:
            analysis["logic"].append(segment)
    return analysis

# Define function to generate prompt from analyzed code
def generate_prompt(code_analysis):
    """Build a text-generation prompt from an analyze_code() result dict.

    Expects the keys "variables", "functions" (joined with ", ") and
    "logic" (joined with " "); returns a single formatted prompt string.
    """
    variables = ", ".join(code_analysis["variables"])
    functions = ", ".join(code_analysis["functions"])
    logic = " ".join(code_analysis["logic"])
    return (
        "Generate code with the following: \n\n"
        f"Variables: {variables} \n\n"
        f"Functions: {functions} \n\n"
        f"Logic: {logic}"
    )
       
# Generate code from model and prompt
def generate_code(prompt):
    """Run the module-level seq2seq model on *prompt* and decode the result.

    Uses the module-level `tokenizer` and `model` loaded at import time.
    Beam search (num_beams=5) with early stopping, capped at 100 tokens;
    special tokens are stripped from the decoded output.
    """
    # Encode the prompt to a (1, seq_len) tensor of input ids.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_ids = model.generate(input_ids, max_length=100, num_beams=5, early_stopping=True)
    # generate() returns a batch; decode the single (first) sequence.
    generated_code = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_code

# Suggest improvements to code
def suggest_improvements(code):
    """Return a list of generic improvement suggestions.

    NOTE(review): the *code* argument is currently ignored — the same
    static suggestions are returned for any input.
    """
    return [
        "Use more descriptive variable names",
        "Add comments to explain complex logic",
        "Refactor duplicated code into functions",
    ]

# Main function to integrate the other functions and generate_code
def main_function(input_code):
    """Full pipeline: analyze the input, build a prompt, generate code.

    Returns a (generated_code, improvements) tuple; the improvement
    suggestions are derived from the raw input, not the generated code.
    """
    prompt = generate_prompt(analyze_code(input_code))
    return generate_code(prompt), suggest_improvements(input_code)

# Create Gradio interface
# Fix: the gr.inputs / gr.outputs namespaces were removed in Gradio 3.x+;
# constructing components via the top-level gr.Textbox class works on both
# recent 3.x and 4.x releases.
iface = gr.Interface(
    fn=main_function,
    inputs=gr.Textbox(lines=5, label="Input Code"),
    outputs=[
        gr.Textbox(lines=5, label="Generated Code"),
        gr.Textbox(lines=5, label="Suggested Improvements"),
    ],
)

# Launch Gradio interface (blocking; starts the local web server)
iface.launch()