Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
# In[1]: | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import init, MarginRankingLoss | |
from torch.optim import Adam | |
from distutils.version import LooseVersion | |
from torch.utils.data import Dataset, DataLoader | |
from torch.autograd import Variable | |
import math | |
from transformers import AutoConfig, AutoModel, AutoTokenizer | |
import nltk | |
import re | |
import torch.optim as optim | |
from transformers import AutoModelForMaskedLM | |
import torch.nn.functional as F | |
import random | |
# In[2]: | |
# eng_dict = [] | |
# with open('eng_dict.txt', 'r') as file: | |
# # Read each line from the file and append it to the list | |
# for line in file: | |
# # Remove leading and trailing whitespace (e.g., newline characters) | |
# cleaned_line = line.strip() | |
# eng_dict.append(cleaned_line) | |
# In[14]: | |
# Base checkpoint on the Hugging Face hub and the local file holding the
# fine-tuned weights that are loaded on top of it.
_BASE_MODEL_NAME = "microsoft/graphcodebert-base"
_CHECKPOINT_PATH = "model_26_2"

# Each model input window is 512 tokens: one leading id, up to 510 content ids
# and one trailing id, right-padded to the full length.
_WINDOW_LEN = 512
_CONTENT_LEN = _WINDOW_LEN - 2

# NOTE(review): 101 / 102 / 0 are the BERT [CLS] / [SEP] / [PAD] ids, whereas
# GraphCodeBERT is RoBERTa-based (its tokenizer exposes cls_token_id,
# sep_token_id and pad_token_id with different values). They are kept as-is on
# the assumption that checkpoint "model_26_2" was fine-tuned with these same
# ids -- confirm against the training code before switching.
_START_ID = 101
_END_ID = 102
_PAD_ID = 0

# Cumulative probabilities for the number of sub-tokens in a generated name,
# used when the caller asks for a random length (ny == 0). Entries are scanned
# in insertion order, so the values must stay ascending; the last bound is just
# above 1.0 so any draw from random.random() (always < 1.0) matches an entry.
_SUB_TOKEN_CDF = {
    2: 0.4363429005892416,
    1: 0.6672580202327398,
    4: 0.7476060740459144,
    3: 0.9618703668504087,
    6: 0.9701028532809564,
    7: 0.9729244545819342,
    8: 0.9739508754144756,
    5: 0.9994508859743607,
    9: 0.9997507867114407,
    10: 0.9999112969650892,
    11: 0.9999788802297832,
    0: 0.9999831041838266,
    12: 0.9999873281378701,
    22: 0.9999957760459568,
    14: 1.0000000000000002,
}

# Filled lazily by _load_model() so the tokenizer and weights are read from
# disk once per process instead of once per request.
_MODEL_CACHE = {}


def _sample_num_sub_tokens():
    """Draw a sub-token count from ``_SUB_TOKEN_CDF`` by inverse-CDF sampling."""
    draw = random.random()
    for count, cumulative in _SUB_TOKEN_CDF.items():
        if draw < cumulative:
            return count
    return list(_SUB_TOKEN_CDF)[-1]  # unreachable: the last bound exceeds 1.0


def _parse_num_sub_tokens(ny):
    """Return the requested sub-token count; 0 or a blank value samples one."""
    if isinstance(ny, str):
        # The textbox placeholder advertises 0 for "random"; treat blank alike.
        ny = ny.strip() or "0"
    try:
        requested = int(ny)
    except (TypeError, ValueError):
        raise gr.Error("Number of tokens must be a whole number (0 for random).") from None
    if requested < 0:
        raise gr.Error("Number of tokens must not be negative (0 for random).")
    return _sample_num_sub_tokens() if requested == 0 else requested


def _load_model():
    """Load (once) and return the tokenizer and the fine-tuned masked-LM model."""
    if not _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_NAME)
        model = AutoModelForMaskedLM.from_pretrained(_BASE_MODEL_NAME)
        # Inputs are built as CPU tensors, so the weights must be on the CPU
        # too, even if the checkpoint was saved from a GPU.
        model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
        model.eval()
        _MODEL_CACHE["tokenizer"] = tokenizer
        _MODEL_CACHE["model"] = model
    return _MODEL_CACHE["tokenizer"], _MODEL_CACHE["model"]


def _build_windows(input_ids, attention_mask):
    """Cut flat 1-D id / attention tensors into (num_windows, 512) batches.

    Each window is ``_START_ID`` + up to 510 content ids + ``_END_ID``,
    right-padded with ``_PAD_ID``; the boundary ids get attention 1 and the
    padding gets attention 0.
    """
    id_windows, att_windows = [], []
    for ids, att in zip(input_ids.split(_CONTENT_LEN), attention_mask.split(_CONTENT_LEN)):
        pad_len = _CONTENT_LEN - ids.shape[0]
        id_windows.append(torch.cat([
            torch.tensor([_START_ID], dtype=ids.dtype),
            ids,
            torch.tensor([_END_ID], dtype=ids.dtype),
            torch.full((pad_len,), _PAD_ID, dtype=ids.dtype),
        ]))
        att_windows.append(torch.cat([
            torch.tensor([1], dtype=att.dtype),
            att,
            torch.tensor([1], dtype=att.dtype),
            torch.full((pad_len,), 0, dtype=att.dtype),
        ]))
    return torch.stack(id_windows), torch.stack(att_windows)


def _mask_run_starts(input_ids, mask_id):
    """Return flat indices where a run of consecutive ``mask_id`` tokens begins."""
    is_mask = (input_ids == mask_id).tolist()
    return [i for i, m in enumerate(is_mask) if m and (i == 0 or not is_mask[i - 1])]


def greet(X, ny):
    """Predict a name for the variable written as "[MASK]" in a Java snippet.

    Every "[MASK]" is expanded into ``n`` consecutive mask tokens, one per
    sub-token of the name. For each sub-token slot, the model's logits are
    averaged over all occurrences of the variable. The highest-ranked purely
    alphabetic piece among the top 5 is then kept.

    Args:
        X: Java source in which each occurrence of the target variable is
            replaced by the literal text "[MASK]".
        ny: Desired number of sub-tokens (string from the textbox, or int).
            0 or blank samples a length from ``_SUB_TOKEN_CDF``.

    Returns:
        The concatenated sub-token pieces. A slot whose top-5 candidates
        contain no alphabetic piece contributes nothing. Returns "" when the
        sampled length is 0.

    Raises:
        gradio.Error: if X contains no "[MASK]" or ny is not a non-negative
            integer.
    """
    if "[MASK]" not in X:
        raise gr.Error('Replace a variable name with "[MASK]" first.')
    num_sub_tokens = _parse_num_sub_tokens(ny)
    if num_sub_tokens == 0:
        return ""
    tokenizer, model = _load_model()

    # Surround the run with spaces so it is tokenized apart from neighbouring code.
    mask_run = " ".join([tokenizer.mask_token] * num_sub_tokens)
    text = X.replace("[MASK]", " " + mask_run + " ")
    tokens = tokenizer.encode_plus(text, add_special_tokens=False, return_tensors="pt")
    flat_ids = tokens["input_ids"][0]
    input_ids, att_mask = _build_windows(flat_ids, tokens["attention_mask"][0])

    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(input_ids, attention_mask=att_mask)[0]  # (windows, 512, vocab)

    # slot_logits[t] gathers the logits for sub-token t of every occurrence.
    slot_logits = [[] for _ in range(num_sub_tokens)]
    for start in _mask_run_starts(flat_ids, tokenizer.mask_token_id):
        for t in range(num_sub_tokens):
            flat_pos = start + t
            if flat_pos >= flat_ids.shape[0]:
                break  # only possible if the user typed a literal mask token
            # Map the flat content index to its window; +1 skips the window's
            # leading id. This also handles runs that cross a window boundary.
            window, offset = divmod(flat_pos, _CONTENT_LEN)
            slot_logits[t].append(logits[window, offset + 1])

    name_parts = []
    for per_occurrence in slot_logits:
        avg_logits = torch.stack(per_occurrence).mean(dim=0)
        top_ids = torch.topk(avg_logits, k=5, dim=0)[1]
        for token_id in top_ids:
            piece = tokenizer.decode(token_id.item()).strip()
            # Keep the best-ranked purely alphabetic piece (identifier-friendly).
            if all(ch.isalpha() for ch in piece):
                name_parts.append(piece)
                break
    return "".join(name_parts)
# Static text shown at the top of the Gradio page.
title = "Rename a variable in a Java class"
description = """This model is a fine-tuned GraphCodeBERT model fine-tuned to output higher-quality variable names for Java classes. Long classes are handled by the
model. Replace any variable name with a "[MASK]" to get an identifier renaming.
"""

# Clickable examples: each Java snippet hides one variable behind "[MASK]" and
# is paired with the value for the "Number of tokens in name" box.
_EXAMPLE_FILE_READER = """import java.io.*;
public class x {
public static void main(String[] args) {
String f = "file.txt";
BufferedReader [MASK] = null;
String l;
try {
[MASK] = new BufferedReader(new FileReader(f));
while ((l = [MASK].readLine()) != null) {
System.out.println(l);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
if ([MASK] != null) [MASK].close();
} catch (IOException ex) {
ex.printStackTrace();
}
}
}
}"""

_EXAMPLE_ECHO_SERVER = """import java.net.*;
import java.io.*;
public class s {
public static void main(String[] args) throws IOException {
ServerSocket [MASK] = new ServerSocket(8000);
try {
Socket s = [MASK].accept();
PrintWriter pw = new PrintWriter(s.getOutputStream(), true);
BufferedReader br = new BufferedReader(new InputStreamReader(s.getInputStream()));
String i;
while ((i = br.readLine()) != null) {
pw.println(i);
}
} finally {
if ([MASK] != null) [MASK].close();
}
}
}"""

_EXAMPLE_CSV_READER = """import java.io.*;
import java.util.*;
public class y {
public static void main(String[] args) {
String [MASK] = "data.csv";
String l = "";
String cvsSplitBy = ",";
try (BufferedReader br = new BufferedReader(new FileReader([MASK]))) {
while ((l = br.readLine()) != null) {
String[] z = l.split(cvsSplitBy);
System.out.println("Values [field-1= " + z[0] + " , field-2=" + z[1] + "]");
}
} catch (IOException e) {
e.printStackTrace();
}
}
}"""

ex = [
    [_EXAMPLE_FILE_READER, "0"],
    [_EXAMPLE_ECHO_SERVER, "2"],
    [_EXAMPLE_CSV_READER, "2"],
]

# Input widgets: the code snippet and the requested sub-token count.
textbox = gr.Textbox(
    label="Type Java code snippet:",
    placeholder="replace variable with [MASK]",
    lines=10,
)
textbox1 = gr.Textbox(
    label="Number of tokens in name:",
    placeholder="0 for randomly sampled number of tokens",
    lines=1,
)

# Wire greet() to the widgets and start the web server.
demo = gr.Interface(
    fn=greet,
    inputs=[textbox, textbox1],
    outputs="text",
    title=title,
    description=description,
    examples=ex,
)
demo.launch()
# In[ ]: | |