# diamond-price-predictor / inference.py
# Author: pgurazada1 · commit 136959f ("inference minor variations") · 1.27 kB
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from gradio_client import Client
# Client for the hosted Gradio app "pgurazada1/diamond-price-predictor".
client = Client("pgurazada1/diamond-price-predictor")

# Download the diamond-prices dataset from OpenML as pandas objects.
dataset = fetch_openml(data_id=43355, as_frame=True, parser='auto')

# Use `frame` rather than `data`: when an OpenML dataset declares a default
# target, `data` omits that column, so selecting/dropping 'price' below would
# raise KeyError. `frame` always holds the features plus any target column
# (and is identical to `data` when no default target is declared).
diamond_prices = dataset.frame

target = ['price']
# Feature groups by kind. They are not referenced elsewhere in this file
# (presumably kept in sync with the training script -- confirm).
numeric_features = ['carat']
categorical_features = ['shape', 'cut', 'color', 'clarity', 'report', 'type']

X = diamond_prices.drop(columns=target)
# Indexing with a list keeps y as a one-column DataFrame rather than a Series.
y = diamond_prices[target]

# Hold out 20% of rows with a fixed seed so the split is reproducible.
# NOTE(review): the split is not used by anything else in this file.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
)
# One example diamond. The values are passed positionally, in the order the
# remote /predict endpoint takes its inputs:
#   Carat (float), Shape, Cut, Color, Clarity, Report, Type.
# Allowed dropdown values:
#   Shape:   Round, Princess, Emerald, Asscher, Cushion, Radiant, Oval, Pear, Marquise
#   Cut:     Ideal, Premium, Very Good, Good, Fair
#   Color:   D, E, F, G, H, I, J
#   Clarity: IF, VVS1, VVS2, VS1, VS2, SI1, SI2, I1
#   Report:  GIA, IGI, HRD, AGS
#   Type:    Natural, Lab Grown
sample_diamond = (
    3,          # Carat ('Carat' Number component)
    "Round",    # Shape
    "Ideal",    # Cut
    "D",        # Color
    "IF",       # Clarity
    "GIA",      # Report
    "Natural",  # Type
)

# Submit the request as a job, then wait for and print the app's prediction.
job = client.submit(*sample_diamond, api_name="/predict")
print(job.result())