sklearn-train-basic / api_example.py
sgbaird's picture
Update num_pred value in api_example.py
0fae5f7
raw
history blame
958 Bytes
from gradio_client import Client
from sklearn.datasets import load_linnerud
import pandas as pd
import numpy as np
from time import time
# Example: send randomly generated feature rows to the hosted Gradio app's
# "/predict" endpoint, report the round-trip time, and print the predictions.
from time import perf_counter

# NOTE(review): X and y are not used anywhere below; the dataset is loaded only
# for reference. The request columns are spelled out in `feature_scales`.
X, y = load_linnerud(return_X_y=True, as_frame=True)

# Create a dataframe with `num_pred` randomly generated rows for predicting.
# Each feature is drawn uniformly from [0, scale) with a seeded generator, so
# the same payload is produced on every run. Features are drawn in the order
# listed (Chins, Situps, Jumps).
rng = np.random.default_rng(42)
num_pred = 10
feature_scales = {"Chins": 50, "Situps": 80, "Jumps": 20}
X_pred = pd.DataFrame(
    {name: scale * rng.random(num_pred) for name, scale in feature_scales.items()}
)

client = Client("AccelerationConsortium/sklearn-train-basic")

# Time only the remote prediction call. perf_counter() is monotonic and
# high-resolution, unlike time(), so it is the appropriate clock for durations.
t0 = perf_counter()
result = client.predict(
    {
        "headers": X_pred.columns.tolist(),
        "data": X_pred.values.tolist(),
    },  # Dict(headers: List[str], data: List[List[Any]], metadata: Dict(str, List[Any] | None) | None) in 'X' Dataframe component
    api_name="/predict",
)
print(f"Time taken: {perf_counter() - t0:.2f}s")

# The response is expected to be a dict with "headers" and "data" keys, the
# same layout as the request payload above.
result_df = pd.DataFrame(result["data"], columns=result["headers"])
print(result_df)