"""Gradio demo that suggests commit messages for a git patch using a T5 model."""
import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

MODEL_NAME = "mamiksik/CommitPredictorT5PL"
MODEL_REVISION = "fb08d01"

# Loaded once at import time so every request reuses the same weights.
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)

# Matches a hunk marker such as "@@ -1,4 +1,6 @@"; any text after it is kept.
_HUNK_HEADER_RE = re.compile(r"@@[^@]*@@")


def _header_path(line: str) -> str:
    """Return the path from a ``--- ``/``+++ `` header line.

    The first character of the path is dropped (``a/x.py`` -> ``/x.py``).
    NOTE(review): this slice is kept exactly as the original code had it;
    presumably the model expects paths in this form -- confirm before changing.
    """
    return line.split(" ", 1)[1][1:]


def parse_files(accumulator: list[str], patch: str) -> list[str]:
    """Flatten a unified diff into the line list that is fed to the model.

    Args:
        accumulator: List that the parsed lines are appended to (mutated in
            place).
        patch: Text of the patch, e.g. the output of ``git diff``.

    Returns:
        The same ``accumulator`` object, for convenience.

    Lines starting with ``index`` or ``diff`` are skipped. Each ``+++ `` header
    contributes the new file name, followed by the old one when it differs.
    Hunk markers are removed, and the leading ``+``/``-`` of changed lines is
    stripped; all other lines are kept verbatim.
    """
    filename_before = None
    for line in patch.splitlines():
        if line.startswith(("index", "diff")):
            continue

        # Require the trailing space: a removed "--" or added "++" line is
        # rendered as "---"/"+++" and must not be mistaken for a file header.
        if line.startswith("--- "):
            filename_before = _header_path(line)
            continue

        if line.startswith("+++ "):
            filename_after = _header_path(line)
            accumulator.append(filename_after)
            # Only mention the old name when one was seen and it differs.
            if filename_before is not None and filename_before != filename_after:
                accumulator.append(filename_before)
            continue

        line = _HUNK_HEADER_RE.sub("", line)
        if not line:
            continue

        if line[0] in "+-":
            line = line[1:]
        accumulator.append(line)

    return accumulator


def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit message candidates for ``patch``.

    Args:
        patch: Patch text as produced by ``git diff``.
        max_length: Upper bound passed to ``model.generate`` as ``max_length``.
        min_length: Lower bound passed as ``min_length``; capped at
            ``max_length``.
        num_beams: Number of beams for beam search.
        prediction_count: Number of sequences to return; capped at
            ``num_beams`` because beam search cannot return more sequences
            than it has beams.

    Returns:
        A tuple ``(token_count, parsed_patch, predictions)`` where
        ``token_count`` is the token count of the untruncated input,
        ``parsed_patch`` is the text given to the tokenizer, and
        ``predictions`` maps each distinct decoded candidate to ``0``.
    """
    # UI sliders may deliver floats; generate() needs integers.
    max_length = int(max_length)
    min_length = min(int(min_length), max_length)
    num_beams = int(num_beams)
    prediction_count = min(int(prediction_count), num_beams)

    input_text = "\n".join(parse_files([], patch))

    with torch.no_grad():
        # Counted without truncation so the user sees the real input size.
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]

        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # No scores are computed, so every candidate gets the placeholder value 0.
    return token_count, input_text, {message: 0 for message in result}


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Patch (as generated by git diff)"),
        gr.Slider(1, 128, value=20, step=1, label="Max message length"),
        gr.Slider(1, 128, value=5, step=1, label="Min message length"),
        gr.Slider(1, 10, value=7, step=1, label="Number of beams"),
        gr.Slider(1, 15, value=5, step=1, label="Number of predictions"),
    ],
    outputs=[
        gr.Textbox(label="Token count"),
        gr.Textbox(label="Parsed patch"),
        gr.Label(label="Predictions"),
    ],
)

if __name__ == "__main__":
    iface.launch()