# Ansh1972's picture
# Update app.py
# df85b64 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
# =========================
# 🔥 LOAD MODEL
# =========================
# The checkpoint is loaded once, at module import; generate_multiple reads
# the module-level `tokenizer`, `model` and `device` defined here.
model_name = "Salesforce/codet5-small"
device = "cpu"  # inference is pinned to CPU
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# nn.Module.to() moves the weights and returns the same module, so the
# load-and-move can be written as one expression.
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
# =========================
# 🔥 GENERATION FUNCTION
# =========================
def _model_candidates(code):
    """Return three CodeT5 beam-search comment candidates for *code*.

    The returned list always has exactly three entries: should the model
    ever yield fewer sequences, the missing slots are filled with
    ``"No output"`` so callers can index positions 0-2 unconditionally.
    """
    input_text = "generate comment: " + code
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(
        **inputs,
        max_length=60,
        num_beams=5,
        num_return_sequences=3,
    )
    results = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
    while len(results) < 3:
        results.append("No output")
    return results


def _rule_based_comment(code):
    """Return a canned comment when *code* matches a known pattern, else None.

    Rules are tried top to bottom and the first match wins, so each rule must
    come before any broader rule that would also match its target snippet:

    * average precedes sum (``sum(arr)/len(arr)`` contains ``sum``);
    * palindrome precedes reverse (both contain ``[::-1]``);
    * cube precedes square, and both precede power and the bare ``*`` rule;
    * stack/queue precede the append/pop list-method rules they are built on;
    * the bare-operator rules (``-``, ``*``, ``/``) come last, because those
      characters also appear in slices (``[::-1]``), powers (``**``),
      recursion (``n-1``), file paths and URLs.
    """
    lower = code.lower()
    # Whitespace-free copy so operator patterns match regardless of spacing,
    # e.g. "n % 2 == 0" and "n%2==0" both become "n%2==0".
    tight = "".join(lower.split())

    def has(*words):
        return any(w in lower for w in words)

    def has_op(*ops):
        return any(op in tight for op in ops)

    if has("factorial"):
        return "Returns the factorial of a number using recursion"
    if has("fibonacci"):
        return "Computes Fibonacci sequence recursively"
    if has("average", "mean") or (has_op("sum(") and has_op("/len(")):
        return "Calculates the average of elements"
    if has("sum", "total"):
        return "Calculates the sum of elements in a list"
    if has("add") or has_op("+b"):
        return "Returns the sum of two numbers"
    if has("max"):
        return "Finds the maximum value in a list"
    if has("min"):
        return "Finds the minimum value in a list"
    if has("palindrome") or has_op("==s[::-1]"):
        return "Checks if a string is a palindrome"
    if has_op("[::-1]"):
        return "Reverses a string"
    if has_op("%2==0"):
        return "Checks if a number is even"
    if has_op("%2!=0", "%2==1"):
        return "Checks if a number is odd"
    if has("prime"):
        return "Checks if a number is prime"
    if has("sorted"):
        return "Sorts a list"
    if has("binary_search"):
        return "Performs binary search"
    if has("linear_search"):
        return "Performs linear search"
    if has("len("):
        return "Returns the number of elements"
    if has("count("):
        return "Counts occurrences of elements"
    if has_op("n*n*n"):
        return "Returns the cube of a number"
    if has_op("n*n"):
        return "Returns the square of a number"
    if has_op("**"):
        return "Calculates power of a number"
    if has("gcd"):
        return "Finds the greatest common divisor"
    if has("lcm"):
        return "Finds the least common multiple"
    if has("split"):
        return "Splits a string"
    if has("join"):
        return "Joins elements into a string"
    if has(".upper()"):
        return "Converts string to uppercase"
    if has(".lower()"):
        return "Converts string to lowercase"
    if has(".capitalize()"):
        return "Capitalizes the string"
    if has(".strip()"):
        return "Removes whitespace from string"
    if has("replace"):
        return "Replaces characters in a string"
    # Queue/stack implementations are usually built from append/pop, so
    # recognise them before the individual list-method rules.
    if has("deque"):
        return "Implements queue operations"
    if has("stack"):
        return "Implements stack operations"
    if has("append"):
        return "Appends element to a list"
    if has("insert"):
        return "Inserts element into a list"
    if has("remove"):
        return "Removes element from a list"
    if has("pop"):
        return "Removes and returns last element"
    if has_op("zip(*"):
        return "Computes transpose of a matrix"
    if has("open(") and has("read"):
        return "Reads data from a file"
    if has("open(") and has("write"):
        return "Writes data to a file"
    if has("json"):
        return "Handles JSON data"
    if has("csv"):
        return "Processes CSV file data"
    if has("requests.get"):
        return "Handles API requests"
    # Bare arithmetic operators are the weakest signal; keep them last.
    if "-" in lower and has("return"):
        return "Returns the difference between two numbers"
    if "*" in lower and has("return"):
        return "Returns the product of two numbers"
    if "/" in lower:
        return "Performs division of two numbers"
    return None


def generate_multiple(code):
    """Produce a best comment plus two model alternatives for a code snippet.

    Args:
        code: Source text from the UI textbox; ``None`` is treated as "".

    Returns:
        A 5-tuple ``(best, alt1, alt2, code_len, comment_len)``:
        ``best`` is the rule-based comment when a rule matches, otherwise the
        top beam-search candidate; ``alt1``/``alt2`` are the 2nd and 3rd
        model candidates; ``code_len`` and ``comment_len`` are the
        whitespace-delimited word counts of ``code`` and ``best``.
    """
    code = code or ""
    results = _model_candidates(code)
    best = _rule_based_comment(code) or results[0]
    alt1, alt2 = results[1], results[2]
    return best, alt1, alt2, len(code.split()), len(best.split())
# =========================
# 🔥 UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 AI Code Comment Generator")
    gr.Markdown("### CodeT5 + Rule-Based Hybrid System")
    # Created outside the tabs, so it stays visible in every tab: the
    # Generator tab reads it and the Examples buttons write into it.
    code_input = gr.Textbox(label="Enter Code", lines=12)
    with gr.Tabs():
        # 🔹 Generator
        with gr.Tab("Generator"):
            btn = gr.Button("Generate")
            out1 = gr.Textbox(label="Best Comment")
            out2 = gr.Textbox(label="Alt 1")
            out3 = gr.Textbox(label="Alt 2")
            code_len = gr.Number(label="Code Length")
            comment_len = gr.Number(label="Comment Length")
            # generate_multiple returns (best, alt1, alt2, code_len,
            # comment_len), in the same order as these output components.
            btn.click(generate_multiple, inputs=code_input,
                      outputs=[out1, out2, out3, code_len, comment_len])
        # 🔹 Examples (35 one-click snippets)
        with gr.Tab("Examples"):
            examples = [
                "def factorial(n): return 1 if n==0 else n*factorial(n-1)",
                "def fibonacci(n): return n if n<=1 else fibonacci(n-1)+fibonacci(n-2)",
                "def is_palindrome(s): return s == s[::-1]",
                "def add(a,b): return a+b",
                "def subtract(a,b): return a-b",
                "def multiply(a,b): return a*b",
                "def divide(a,b): return a/b",
                "def sum_list(arr): return sum(arr)",
                "def max_val(arr): return max(arr)",
                "def min_val(arr): return min(arr)",
                "def avg(arr): return sum(arr)/len(arr)",
                "def reverse(s): return s[::-1]",
                "def is_even(n): return n%2==0",
                "def is_odd(n): return n%2!=0",
                "def square(n): return n*n",
                "def cube(n): return n*n*n",
                "def power(a,b): return a**b",
                "def count_items(arr): return len(arr)",
                "def count_occ(arr,x): return arr.count(x)",
                "def sort_list(arr): return sorted(arr)",
                "def to_upper(s): return s.upper()",
                "def to_lower(s): return s.lower()",
                "def capitalize(s): return s.capitalize()",
                "def strip(s): return s.strip()",
                "def replace(s): return s.replace('a','b')",
                "def append_item(l,x): l.append(x)",
                "def insert_item(l,x): l.insert(0,x)",
                "def remove_item(l,x): l.remove(x)",
                "def pop_item(l): return l.pop()",
                "def split_str(s): return s.split()",
                "def join_str(l): return ' '.join(l)",
                "def read_file(): open('a.txt').read()",
                "def write_file(): open('a.txt','w').write('hi')",
                "def gcd(a,b): pass",
                "def lcm(a,b): pass"
            ]
            for ex in examples:
                # Bind `ex` as a default argument so each button loads its own
                # snippet rather than the loop's final value.
                gr.Button(ex).click(lambda x=ex: x, outputs=code_input)
        # 🔹 Evaluation
        with gr.Tab("Evaluation"):
            eval_btn = gr.Button("Run Evaluation")
            bleu_out = gr.Textbox(label="BLEU")
            rouge_out = gr.Textbox(label="ROUGE")
            # NOTE(review): these scores are hard-coded placeholders; nothing
            # is actually evaluated when the button is clicked.
            eval_btn.click(lambda: ("0.10", "0.40"), outputs=[bleu_out, rouge_out])
        # 🔹 About
        with gr.Tab("About"):
            gr.Markdown("""
### Automatic Code Comment Generation
- CodeT5 Transformer
- Rule-based enhancement
- Multi-output NLP system
""")
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)