regmix / app.py
xszheng2020's picture
Update app.py
52cdfd2 verified
# import sklearn
import gradio as gr
# import joblib
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from PIL import Image
# import datasets
# pipe = joblib.load("./model.pkl")
# Heading text handed to gr.Interface(title=...) further down in this file.
title = (
    "RegMix: Data Mixture as Regression "
    "for Language Model Pre-training"
)
# One-sentence summary handed to gr.Interface(description=...).
description = (
    "We propose a regression-based method to find high-performing "
    "data mixture for language model pre-training."
)
def _predict_mixture_surface(reg, stride=0.025):
    """Evaluate ``reg`` on a regular grid of three-component mixtures.

    The grid covers ``x, y`` in ``[0, 1]`` with spacing ``stride``; the third
    feature is ``1 - x - y``.  Grid points with ``x + y > 1`` are not
    evaluated and are filled with ``np.inf``.

    Args:
        reg: Fitted regressor exposing ``predict`` on an ``(n, 3)`` array.
        stride: Grid spacing along both axes.

    Returns:
        Tuple ``(X, Y, Z)`` of equally shaped 2-D arrays: the meshgrid
        coordinates and the predicted value (or ``np.inf``) at each point.
    """
    ticks = np.arange(0, 1 + stride, stride)
    X, Y = np.meshgrid(ticks, ticks)
    # Same acceptance rule as a per-point ``if (x + y) > 1`` test.
    feasible = ~((X + Y) > 1)
    Z = np.full(X.shape, np.inf)
    features = np.column_stack(
        [X[feasible], Y[feasible], (1 - X - Y)[feasible]]
    )
    # One batched call instead of one predict() call per grid point.
    Z[feasible] = reg.predict(features)
    return X, Y, Z


def _render_surface_image(X, Y, Z):
    """Draw the 3-D prediction surface and return it as a PIL image.

    The surface is drawn with a filled contour projected below it, and the
    grid minimum is marked in red on both the surface and the projection.
    The figure is rendered to an in-memory PNG and always closed, so
    repeated calls neither leak figures nor share a file on disk.

    Args:
        X, Y: Meshgrid coordinate arrays.
        Z: Predicted values on the grid; ``np.inf`` marks excluded points.

    Returns:
        A fully loaded ``PIL.Image.Image`` of the rendered figure.
    """
    import io

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.patches import Circle
    from matplotlib.ticker import LinearLocator
    from mpl_toolkits.mplot3d import art3d

    # NOTE: these rcParams are process-wide matplotlib settings.
    plt.rcParams["font.family"] = "Times New Roman"  # !!!!
    plt.rcParams.update({'font.size': 24})
    plt.rcParams.update({'axes.labelpad': 20})

    finite_z = Z[Z != np.inf]
    z_low = np.min(Z)
    z_high = np.max(finite_z)
    # Height of the plane on which the contour projection is drawn.
    floor = z_low - 0.35
    best = np.argmin(Z)
    best_x = X.reshape(-1)[best]
    best_y = Y.reshape(-1)[best]

    fig, ax = plt.subplots(figsize=(12, 12), layout='compressed',
                           subplot_kw={"projection": "3d"})

    def add_point(x, y, z, fc=None, ec=None, radius=0.005):
        """Mark (x, y, z) with one circle in each of the three axis planes."""
        planes = (
            ('z', (x, y), z),
            ('y', (x, z), y),
            ('x', (y, z), x),
        )
        for zdir, centre, depth in planes:
            patch = Circle(centre, radius, lw=1.5, fc=fc, ec=ec)
            ax.add_patch(patch)
            art3d.pathpatch_2d_to_3d(patch, z=depth, zdir=zdir)

    try:
        # Plot the surface.
        surf = ax.plot_surface(X, Y, Z,
                               edgecolor='white',
                               lw=0.5, rstride=2, cstride=2,
                               alpha=0.85,
                               cmap='coolwarm',
                               vmin=np.min(finite_z),
                               vmax=z_high,
                               antialiased=False)
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter('{x:.02f}')
        ax.view_init(elev=25, azim=45, roll=0)
        ax.contourf(X, Y, Z, zdir='z', offset=floor, cmap=cm.coolwarm)

        # Highlight the minimum on the surface and on the projection plane.
        add_point(best_x, best_y, z_low, fc='Red', ec='Red', radius=0.015)
        add_point(best_x, best_y, floor, fc='Red', ec='Red', radius=0.015)

        ax.set_xlabel('Github (%)', fontdict={'size': 24})
        ax.set_ylabel('Hacker News (%)', fontdict={'size': 24})
        # Tick positions are fractions; labels show them as percentages.
        percent_labels = [str(np.round(num, 1)) for num in np.arange(0, 100, 20)]
        ax.set_xticks(np.arange(0, 1, 0.2), percent_labels)
        ax.set_yticks(np.arange(0, 1, 0.2), percent_labels)
        z_ticks = np.arange(z_low, z_high, 0.2)
        ax.set_zticks(z_ticks, [str(np.round(num, 1)) for num in z_ticks])
        ax.zaxis.labelpad = 1
        ax.set_zlim(floor, z_high + 0.01)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_box_aspect(aspect=None, zoom=0.775)
        # NOTE(review): private matplotlib attribute; presumably repositions
        # the z axis -- confirm it still works after matplotlib upgrades.
        ax.zaxis._axinfo['juggled'] = (1, 2, 2)

        # Add a color bar which maps values to colors.
        cbar = fig.colorbar(surf, shrink=0.5, aspect=25, pad=0.01)
        cbar.ax.set_ylabel('Prediction', fontdict={'size': 32})

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure; pyplot otherwise keeps it alive.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    # Force decoding now so the image does not depend on the buffer later.
    image.load()
    return image


def infer(inputs, additional_inputs):
    """Fit a LightGBM regressor on the table and build the demo outputs.

    The last column of the table is the regression target and all preceding
    columns are features.  The rows are split into train/validation sets
    (12.5% validation, fixed seed), the model is fitted with early stopping
    on the validation set, and the fitted model is evaluated over a grid of
    three-component mixtures for the surface plot.

    Args:
        inputs: Tabular data convertible by ``pd.DataFrame`` with the
            module-level ``headers`` as columns.
        additional_inputs: Unused; present because the Interface passes the
            upload button's value as an extra argument.

    Returns:
        A list ``[scatter_plot, surface_image, results]``: a
        ``gr.ScatterPlot`` of prediction vs. target on the validation rows,
        a ``gr.Image`` of the 3-D surface, and a DataFrame with the
        ``Target`` and ``Prediction`` columns of the validation rows.
    """
    df = pd.DataFrame(inputs, columns=headers)
    X_columns = df.columns[0:-1]
    y_column = df.columns[-1]

    df_train, df_val = train_test_split(df, test_size=0.125, random_state=42)
    # Work on an owned copy: a column is added below, and writing into the
    # slice returned by the split can raise SettingWithCopyWarning.
    df_val = df_val.copy()

    hyper_params = {
        'task': 'train',
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': ['l1', 'l2'],
        "num_iterations": 1000,
        'seed': 42,
        'learning_rate': 1e-2,
    }

    val_features = df_val[X_columns].values
    val_target = df_val[y_column]

    np.random.seed(42)
    gbm = lgb.LGBMRegressor(**hyper_params)
    reg = gbm.fit(
        df_train[X_columns].values,
        df_train[y_column],
        eval_set=[(val_features, val_target)],
        eval_metric='l2',
        callbacks=[
            lgb.early_stopping(stopping_rounds=3),
        ],
    )

    predictions = reg.predict(val_features)
    df_val['Prediction'] = predictions

    X, Y, Z = _predict_mixture_surface(reg)
    surface_image = _render_surface_image(X, Y, Z)

    # Shared, padded range so both scatter axes use the same scale.
    lower = min(min(predictions), min(val_target)) - 0.25
    upper = max(max(predictions), max(val_target)) + 0.25

    # NOTE(review): the outputs refer to a column literally named "Target";
    # this assumes the last header equals "Target" -- confirm against the CSV.
    return [
        gr.ScatterPlot(
            value=df_val,
            x="Prediction",
            y="Target",
            title="Scatter",
            tooltip=["Prediction", "Target"],
            x_lim=[lower, upper],
            y_lim=[lower, upper],
        ),
        gr.Image(surface_image),
        df_val[['Target', 'Prediction']],
    ]
def upload_csv(file):
    """Read an uploaded CSV file into a DataFrame.

    Args:
        file: Either a filesystem path (``str`` or ``os.PathLike``) or an
            object whose ``name`` attribute holds the path of the file.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    import os

    # Paths are used directly; anything else is expected to carry the path
    # in ``.name``.  (A pathlib.Path also has ``.name``, but that is only the
    # final component, so paths must be checked for first.)
    if isinstance(file, (str, os.PathLike)):
        csv_path = file
    else:
        csv_path = file.name
    return pd.read_csv(csv_path)
# Load the example dataset; its column names become the module-level
# `headers` that infer() uses as the column names of the submitted table.
df = pd.read_csv('data.csv')
headers = df.columns.tolist()
# Input: a numeric table with the CSV's headers and a fixed count of 4 columns.
# NOTE(review): `interactive=1` is a truthy int, presumably meant as True.
inputs = [gr.Dataframe(headers=headers, row_count = (8, "dynamic"), datatype='number', col_count=(4,"fixed"), label="Dataset", interactive=1)]
# Outputs, in the order infer() returns them: scatter plot, surface image,
# and a two-column "Target"/"Prediction" table.
outputs = [gr.ScatterPlot(), gr.Image(), gr.Dataframe(row_count = (2, "dynamic"), col_count=(2, "fixed"), datatype='number', label="Results", headers=["Target", "Prediction"])]
with gr.Blocks() as demo:
####
# Created with render=False; the button is handed to gr.Interface below
# through `additional_inputs`.
upload_button = gr.UploadButton(label="Upload", file_types = ['.csv'],
# live=True,
file_count = "single", render=False)
# On upload, the DataFrame returned by upload_csv is sent to the input table.
upload_button.upload(fn=upload_csv, inputs=upload_button, outputs=inputs, api_name="upload_csv")
####
# Main interface: infer() receives the table plus the upload button's value.
gr.Interface(infer, inputs=inputs, outputs=outputs, title=title,
additional_inputs = [upload_button],
additional_inputs_accordion='Upload CSV',
description = description,
# NOTE(review): the second example is an empty list -- confirm it is intended.
examples=[[df], []],
cache_examples=False, allow_flagging='never')
# Start the Gradio server.
demo.launch(debug=False)