Not-Grim-Refer committed on
Commit
4ef486c
1 Parent(s): 5f2fe16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -59
app.py CHANGED
@@ -1,12 +1,6 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import RobertaConfig, RobertaModel, AutoModelForSeq2SeqLM, AutoTokenizer
4
-
5
- # Create a configuration object
6
- config = RobertaConfig.from_pretrained('roberta-base')
7
-
8
- # Create the Roberta model
9
- model = RobertaModel.from_pretrained('roberta-base', config=config)
10
 
11
  # Load pretrained model and tokenizer
12
  model_name = "zonghaoyang/DistilRoBERTa-base"
@@ -15,63 +9,54 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
 
16
  # Define function to analyze input code
17
  def analyze_code(input_code):
18
- # Format code into strings and sentences for NLP
19
- code_str = " ".join(input_code.split())
20
- sentences = [s.strip() for s in code_str.split(".") if s.strip()]
21
- #Extract relevant info and intent from code
22
- variables = []
23
- functions = []
24
- logic = []
25
- for sentence in sentences:
26
- if "=" in sentence:
27
- variables.append(sentence.split("=")[0].strip())
28
- elif "(" in sentence:
29
- functions.append(sentence.split("(")[0].strip())
30
- else:
31
- logic.append(sentence)
32
- #Return info and intent in dictionary
33
- return {"variables": variables, "functions": functions, "logic": logic}
34
 
35
  # Define function to generate prompt from analyzed code
36
  def generate_prompt(code_analysis):
37
- prompt = f"Generate code with the following: \n\n"
38
- prompt += f"Variables: {', '.join(code_analysis['variables'])} \n\n"
39
- prompt += f"Functions: {', '.join(code_analysis['functions'])} \n\n"
40
- prompt += f"Logic: {' '.join(code_analysis['logic'])}"
41
- return prompt
42
-
43
  # Generate code from model and prompt
44
  def generate_code(prompt):
45
- generated_code = model.generate(prompt, max_length=100, num_beams=5, early_stopping=True)
46
- return generated_code
 
 
47
 
48
  # Suggest improvements to code
49
  def suggest_improvements(code):
50
- suggestions = ["Use more descriptive variable names", "Add comments to explain complex logic", "Refactor duplicated code into functions"]
51
- return suggestions
52
-
53
- # Define Gradio interface
54
- interface = gr.Interface(fn=generate_code, inputs=["textbox"], outputs=["textbox"])
55
-
56
- # Have a conversation about the code
57
- input_code = """x = 10
58
- y = 5
59
- def add(a, b):
60
- return a + b
61
- result = add(x, y)"""
62
- code_analysis = analyze_code(input_code)
63
- prompt = generate_prompt(code_analysis)
64
- reply = f"{prompt}\n\n{generate_code(prompt)}\n\nSuggested improvements: {', '.join(suggest_improvements(input_code))}"
65
- print(reply)
66
-
67
- while True:
68
- change = input("Would you like to make any changes to the code? (Y/N) ")
69
- if change == "Y":
70
- new_code = input("Enter the updated code: ")
71
- code_analysis = analyze_code(new_code)
72
- prompt = generate_prompt(code_analysis)
73
- reply = f"{prompt}\n\n{generate_code(prompt)}\n\nSuggested improvements: {', '.join(suggest_improvements(new_code))}"
74
- print(reply)
75
- elif change == "N":
76
- print("OK, conversation ended.")
77
- break
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
 
 
 
4
 
5
# Load pretrained model and tokenizer
# NOTE(review): the actual AutoModelForSeq2SeqLM / AutoTokenizer instantiation
# lines are hidden by the diff context here — confirm they load `model_name`.
model_name = "zonghaoyang/DistilRoBERTa-base"
 
9
 
10
# Define function to analyze input code
def analyze_code(input_code):
    """Extract variables, functions, and remaining logic from a code snippet.

    The snippet is whitespace-normalized, then split on "." into rough
    "sentences"; each sentence is bucketed by a simple heuristic:
    contains "=" -> variable name, contains "(" -> function name, else logic.
    """
    normalized = " ".join(input_code.split())
    fragments = [frag.strip() for frag in normalized.split(".") if frag.strip()]
    analysis = {"variables": [], "functions": [], "logic": []}
    for fragment in fragments:
        if "=" in fragment:
            analysis["variables"].append(fragment.split("=")[0].strip())
        elif "(" in fragment:
            analysis["functions"].append(fragment.split("(")[0].strip())
        else:
            analysis["logic"].append(fragment)
    return analysis
 
 
 
25
 
26
# Define function to generate prompt from analyzed code
def generate_prompt(code_analysis):
    """Build a text-generation prompt from an analyze_code() result dict."""
    sections = [
        "Generate code with the following: \n",
        f"Variables: {', '.join(code_analysis['variables'])} \n",
        f"Functions: {', '.join(code_analysis['functions'])} \n",
        f"Logic: {' '.join(code_analysis['logic'])}",
    ]
    return "\n".join(sections)
33
+
34
# Generate code from model and prompt
def generate_code(prompt):
    """Run the seq2seq model on *prompt* and return the decoded text.

    Relies on the module-level `model` and `tokenizer` loaded above;
    beam search (num_beams=5) with early stopping, capped at 100 tokens.
    """
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(token_ids, max_length=100, num_beams=5, early_stopping=True)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
 
41
# Suggest improvements to code
def suggest_improvements(code):
    """Return a fixed list of generic improvement suggestions.

    NOTE: *code* is currently unused — the suggestions are static.
    """
    return [
        "Use more descriptive variable names",
        "Add comments to explain complex logic",
        "Refactor duplicated code into functions",
    ]
45
+
46
# Main function to integrate the other functions and generate_code
def main_function(input_code):
    """Pipeline entry point: analyze the code, build a prompt, generate, suggest.

    Returns a (generated_code, improvements) tuple consumed by the Gradio UI.
    """
    analysis = analyze_code(input_code)
    generation_prompt = generate_prompt(analysis)
    return generate_code(generation_prompt), suggest_improvements(input_code)
53
+
54
# Create Gradio interface.
# FIX: gr.inputs.Textbox / gr.outputs.Textbox were deprecated and removed in
# Gradio 3.x — component classes now live directly on the `gr` namespace.
iface = gr.Interface(
    fn=main_function,
    inputs=gr.Textbox(lines=5, label="Input Code"),
    outputs=[
        gr.Textbox(lines=5, label="Generated Code"),
        gr.Textbox(lines=5, label="Suggested Improvements"),
    ],
)

# Launch Gradio interface (blocks until the server is stopped).
iface.launch()