|
|
|
import gradio as gr |
|
|
|
import pandas as pd |
|
import numpy as np |
|
import lightgbm as lgb |
|
from sklearn.model_selection import train_test_split |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
# Heading and blurb shown at the top of the Gradio interface.
title = "RegMix: Data Mixture as Regression for Language Model Pre-training"
description = (
    "We propose a regression-based method to find high-performing data "
    "mixture for language model pre-training."
)
|
|
|
def _fit_regressor(df_train, df_val, X_columns, y_column):
    """Fit a LightGBM regressor on the training split.

    The validation split is passed as the evaluation set and training stops
    early once the evaluation metric has not improved for 3 rounds.
    """
    hyper_params = {
        'task': 'train',
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': ['l1', 'l2'],
        "num_iterations": 1000,
        'seed': 42,
        'learning_rate': 1e-2,
    }

    np.random.seed(42)
    gbm = lgb.LGBMRegressor(**hyper_params)
    return gbm.fit(
        df_train[X_columns].values,
        df_train[y_column],
        eval_set=[(df_val[X_columns].values, df_val[y_column])],
        eval_metric='l2',
        callbacks=[lgb.early_stopping(stopping_rounds=3)],
    )


def _predict_mixture_surface(reg, stride=0.025):
    """Evaluate ``reg`` on a grid of (x, y, 1 - x - y) mixtures.

    Returns the meshgrid arrays ``X`` and ``Y`` and a matching ``Z`` array of
    predictions. Grid points with ``x + y > 1`` are not predicted and are
    filled with ``np.inf``.

    NOTE(review): this assumes the model was trained on exactly three
    feature columns -- confirm against the dataset schema.
    """
    grid = np.arange(0, 1 + stride, stride)
    X, Y = np.meshgrid(grid, grid)
    flat_x = X.reshape(-1)
    flat_y = Y.reshape(-1)

    feasible = ~((flat_x + flat_y) > 1)
    features = np.column_stack([
        flat_x[feasible],
        flat_y[feasible],
        1 - flat_x[feasible] - flat_y[feasible],
    ])

    # One batched predict call instead of one call per grid point.
    flat_z = np.full(flat_x.shape, np.inf)
    flat_z[feasible] = reg.predict(features)
    return X, Y, flat_z.reshape(X.shape)


def _add_point(ax, x, y, z, fc=None, ec=None, radius=0.005):
    """Mark the 3D position (x, y, z) with one circle per axis plane."""
    from matplotlib.patches import Circle
    from mpl_toolkits.mplot3d import art3d

    placements = (
        ('z', (x, y, z)),
        ('y', (x, z, y)),
        ('x', (y, z, x)),
    )
    for zdir, (x0, y0, z0) in placements:
        patch = Circle((x0, y0), radius, lw=1.5, fc=fc, ec=ec)
        ax.add_patch(patch)
        art3d.pathpatch_2d_to_3d(patch, z=z0, zdir=zdir)


def _render_surface_image(X, Y, Z):
    """Draw the prediction surface and return it as a loaded PIL image.

    The figure is rendered into an in-memory PNG rather than a shared file on
    disk, and it is always closed afterwards so repeated calls do not
    accumulate open pyplot figures.
    """
    import io

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.ticker import LinearLocator

    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams.update({'font.size': 24})
    plt.rcParams.update({'axes.labelpad': 20})

    # Z holds np.inf where x + y > 1; only finite entries define the range.
    finite_z = Z[np.isfinite(Z)]
    z_min = finite_z.min()
    z_max = finite_z.max()
    # The contour projection is drawn on a plane below the surface.
    floor = z_min - 0.35

    best = np.argmin(Z)
    best_x = X.reshape(-1)[best]
    best_y = Y.reshape(-1)[best]

    fig, ax = plt.subplots(
        figsize=(12, 12),
        layout='compressed',
        subplot_kw={"projection": "3d"},
    )
    try:
        surf = ax.plot_surface(
            X, Y, Z,
            edgecolor='white',
            lw=0.5, rstride=2, cstride=2,
            alpha=0.85,
            cmap='coolwarm',
            vmin=z_min,
            vmax=z_max,
            antialiased=False,
        )

        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter('{x:.02f}')
        ax.view_init(elev=25, azim=45, roll=0)

        ax.contourf(X, Y, Z, zdir='z', offset=floor, cmap=cm.coolwarm)

        # Highlight the lowest prediction on the surface and on the floor plane.
        _add_point(ax, best_x, best_y, z_min, fc='Red', ec='Red', radius=0.015)
        _add_point(ax, best_x, best_y, floor, fc='Red', ec='Red', radius=0.015)

        ax.set_xlabel('Github (%)', fontdict={'size': 24})
        ax.set_ylabel('Hacker News (%)', fontdict={'size': 24})

        # Tick positions are fractions; the labels show them as percentages.
        percent_labels = [str(np.round(num, 1)) for num in np.arange(0, 100, 20)]
        ax.set_xticks(np.arange(0, 1, 0.2), percent_labels)
        ax.set_yticks(np.arange(0, 1, 0.2), percent_labels)

        z_ticks = np.arange(z_min, z_max, 0.2)
        ax.set_zticks(z_ticks, [str(np.round(num, 1)) for num in z_ticks])

        ax.zaxis.labelpad = 1

        ax.set_zlim(floor, z_max + 0.01)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_box_aspect(aspect=None, zoom=0.775)

        # Private matplotlib attribute controlling where the z axis is drawn.
        ax.zaxis._axinfo['juggled'] = (1, 2, 2)

        cbar = fig.colorbar(surf, shrink=0.5, aspect=25, pad=0.01)
        cbar.ax.set_ylabel('Prediction', fontdict={'size': 32})

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0.1)
    finally:
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy; decode now so the image owns its pixel data.
    image.load()
    return image


def infer(inputs, additional_inputs):
    """Fit a regression model on the submitted table and visualise it.

    Args:
        inputs: Tabular data convertible to a DataFrame with the module-level
            ``headers`` as columns. The last column is the regression target
            and all preceding columns are features.
        additional_inputs: Value of the extra upload component; not used here.

    Returns:
        A list with three items: a ``gr.ScatterPlot`` of prediction versus
        target on the validation split, a ``gr.Image`` of the predicted
        surface, and a DataFrame with the target and ``Prediction`` columns
        of the validation split.
    """
    df = pd.DataFrame(inputs, columns=headers)

    X_columns = df.columns[0:-1]
    y_column = df.columns[-1]

    df_train, df_val = train_test_split(df, test_size=0.125, random_state=42)
    # Work on a copy so adding the prediction column never writes into a slice.
    df_val = df_val.copy()

    reg = _fit_regressor(df_train, df_val, X_columns, y_column)

    predictions = reg.predict(df_val[X_columns].values)
    df_val['Prediction'] = predictions

    X, Y, Z = _predict_mixture_surface(reg)
    surface_image = _render_surface_image(X, Y, Z)

    # Both scatter axes share one range padded by 0.25 on each side.
    lower = float(min(predictions.min(), df_val[y_column].min())) - 0.25
    upper = float(max(predictions.max(), df_val[y_column].max())) + 0.25

    scatter = gr.ScatterPlot(
        value=df_val,
        x="Prediction",
        y=y_column,
        title="Scatter",
        tooltip=["Prediction", y_column],
        x_lim=[lower, upper],
        y_lim=[lower, upper],
    )

    return [
        scatter,
        gr.Image(surface_image),
        df_val[[y_column, 'Prediction']],
    ]
|
|
|
def upload_csv(file):
    """Read an uploaded CSV into a DataFrame.

    Args:
        file: Either a filesystem path (``str`` or path-like), an object
            exposing the path through a ``name`` attribute, or a readable
            buffer. NOTE(review): which of these Gradio passes may depend on
            the installed version -- confirm.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    import os

    # Paths are used as-is: pathlib.Path.name would only give the basename.
    if isinstance(file, (str, os.PathLike)):
        source = file
    else:
        source = getattr(file, "name", file)
    return pd.read_csv(source)
|
|
|
# Example dataset loaded at import time; its column names become the table
# headers and are also read by infer() through the module-level `headers`.
df = pd.read_csv('data.csv')
headers = df.columns.tolist()

# Components are created up front so the upload handler and the interface
# below can share the same input table.
inputs = [
    gr.Dataframe(
        headers=headers,
        row_count=(8, "dynamic"),
        datatype='number',
        col_count=(4, "fixed"),
        label="Dataset",
        interactive=True,
    ),
]
outputs = [
    gr.ScatterPlot(),
    gr.Image(),
    gr.Dataframe(
        row_count=(2, "dynamic"),
        col_count=(2, "fixed"),
        datatype='number',
        label="Results",
        headers=["Target", "Prediction"],
    ),
]

with gr.Blocks() as demo:
    # Created with render=False; it is handed to the interface below as an
    # additional input.
    upload_button = gr.UploadButton(
        label="Upload",
        file_types=['.csv'],
        file_count="single",
        render=False,
    )
    # An uploaded CSV replaces the contents of the input table.
    upload_button.upload(
        fn=upload_csv,
        inputs=upload_button,
        outputs=inputs,
        api_name="upload_csv",
    )

    # NOTE(review): the second example row is an empty list -- confirm that
    # this is intentional.
    gr.Interface(
        infer,
        inputs=inputs,
        outputs=outputs,
        title=title,
        additional_inputs=[upload_button],
        additional_inputs_accordion='Upload CSV',
        description=description,
        examples=[[df], []],
        cache_examples=False,
        allow_flagging='never',
    )

demo.launch(debug=False)