OffWorldTensor commited on
Commit
59cce6a
·
1 Parent(s): 3dc8b9b

docs: Finalize project files and fix Gradio SDK version

Browse files
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import joblib
4
+ import pandas as pd
5
+ import os
6
+ import json
7
+ import re
8
+ from safetensors.torch import load_file
9
+ from typing import List, Tuple
10
+ from network import PricePredictor
11
+
12
+ MODEL_DIR = "model"
13
+ DATA_DIR = "data"
14
+ SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
15
+ DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
16
+
17
+
18
+ def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
19
+ config_path = os.path.join(model_dir, "config.json")
20
+ with open(config_path, "r") as f:
21
+ model_config = json.load(f)
22
+
23
+ model = PricePredictor(input_size=model_config["input_size"])
24
+ weights_path = os.path.join(model_dir, "model.safetensors")
25
+ model.load_state_dict(load_file(weights_path))
26
+ model.eval()
27
+ return model, model_config["feature_columns"]
28
+
29
+
30
+ def perform_prediction(model: torch.nn.Module, scaler, input_features: pd.Series) -> Tuple[bool, float]:
31
+ features_np = input_features.to_numpy(dtype="float32").reshape(1, -1)
32
+ features_scaled = scaler.transform(features_np)
33
+ features_tensor = torch.tensor(features_scaled, dtype=torch.float32)
34
+
35
+ with torch.no_grad():
36
+ logit = model(features_tensor)
37
+ probability = torch.sigmoid(logit).item()
38
+ predicted_class = bool(round(probability))
39
+
40
+ return predicted_class, probability
41
+
42
+ try:
43
+ model, feature_columns = load_model_and_config(MODEL_DIR)
44
+ scaler = joblib.load(SCALER_PATH)
45
+ full_data = pd.read_csv(DATA_PATH)
46
+
47
+ full_data['display_name'] = full_data.apply(
48
+ lambda row: f"{row['name']} (ID: {row['tcgplayer_id']})", axis=1
49
+ )
50
+ card_choices = sorted(full_data['display_name'].unique().tolist())
51
+ ASSETS_LOADED = True
52
+ except FileNotFoundError as e:
53
+ print(f"Error loading necessary files: {e}")
54
+ print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
55
+ card_choices = ["Error: Model or data files not found. Check logs."]
56
+ ASSETS_LOADED = False
57
+
58
+
59
+ def predict_price_trend(card_display_name: str) -> str:
60
+ if not ASSETS_LOADED:
61
+ return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."
62
+
63
+ try:
64
+ tcgplayer_id = int(re.search(r'\(ID: (\d+)\)', card_display_name).group(1))
65
+ except (AttributeError, ValueError):
66
+ return f"## Input Error\nCould not parse ID from '{card_display_name}'. Please select a valid card from the dropdown."
67
+
68
+ card_data = full_data[full_data['tcgplayer_id'] == tcgplayer_id]
69
+ if card_data.empty:
70
+ return f"## Internal Error\nCould not find data for ID {tcgplayer_id}. Please restart the Space or select another card."
71
+
72
+ card_sample = card_data.iloc[0]
73
+ sample_features = card_sample[feature_columns]
74
+
75
+ predicted_class, probability = perform_prediction(model, scaler, sample_features)
76
+
77
+ prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
78
+ confidence = probability if predicted_class else 1 - probability
79
+
80
+ target_col = 'price_will_rise_30_in_6m' # NOTE: Assumed target column name. Change if yours is different.
81
+ true_label_text = ""
82
+ if target_col in card_sample and pd.notna(card_sample[target_col]):
83
+ true_label = bool(card_sample[target_col])
84
+ true_label_text = f"\n- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**."
85
+
86
+ output = f"""
87
+ ## 🔮 Prediction Report for {card_sample['name']}
88
+ - **Prediction:** The model predicts the card's price will {prediction_text} by 30% in the next 6 months.
89
+ - **Confidence:** {confidence:.2%}
90
+ {true_label_text}
91
+ """
92
+ return output
93
+
94
+
95
+ iface = gr.Interface(
96
+ fn=predict_price_trend,
97
+ inputs=gr.Dropdown(
98
+ choices=card_choices,
99
+ label="Select a Pokémon Card",
100
+ info="Choose a card from the dataset to predict its price trend."
101
+ ),
102
+ outputs=gr.Markdown(),
103
+ title="PricePoke: Pokémon Card Price Trend Predictor",
104
+ description="""
105
+ Select a Pokémon card to predict whether its market price will increase by 30% or more over the next 6 months.
106
+ This model was trained on historical TCGPlayer market data.
107
+ """,
108
+ examples=[[card_choices[0]] if card_choices and ASSETS_LOADED else []],
109
+ allow_flagging="never"
110
+ )
111
+
112
+ if __name__ == "__main__":
113
+ iface.launch()
data/pokemon_final_with_labels.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e338ce22e2d28270c7b5eaa18fdad4b8465b1e1bf83dc085372c542aa11f092e
3
+ size 8231628
data/scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57bae1c7e9c16028c4f21def0302ba1514e7a3d8be131937702da75007ccd866
3
+ size 2151
model/config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_size": 64,
3
+ "model_class": "PricePredictor",
4
+ "feature_columns": [
5
+ "rawPrice",
6
+ "gradedPriceTen",
7
+ "gradedPriceNine",
8
+ "first_raw",
9
+ "price_ratio_to_first",
10
+ "log_raw",
11
+ "log_g10",
12
+ "log_g9",
13
+ "price_vs_rolling_avg",
14
+ "rawPrice_missing",
15
+ "gradedPriceTen_missing",
16
+ "gradedPriceNine_missing",
17
+ "rarity_ACE SPEC Rare",
18
+ "rarity_Amazing Rare",
19
+ "rarity_Black White Rare",
20
+ "rarity_Classic Collection",
21
+ "rarity_Code Card",
22
+ "rarity_Common",
23
+ "rarity_Double Rare",
24
+ "rarity_Holo Rare",
25
+ "rarity_Hyper Rare",
26
+ "rarity_Illustration Rare",
27
+ "rarity_Prism Rare",
28
+ "rarity_Promo",
29
+ "rarity_Radiant Rare",
30
+ "rarity_Rare",
31
+ "rarity_Rare Ace",
32
+ "rarity_Rare BREAK",
33
+ "rarity_Secret Rare",
34
+ "rarity_Shiny Holo Rare",
35
+ "rarity_Shiny Rare",
36
+ "rarity_Shiny Ultra Rare",
37
+ "rarity_Special Illustration Rare",
38
+ "rarity_Ultra Rare",
39
+ "rarity_Uncommon",
40
+ "energyType_Colorless",
41
+ "energyType_Darkness",
42
+ "energyType_Dragon",
43
+ "energyType_Energy",
44
+ "energyType_Fairy",
45
+ "energyType_Fighting",
46
+ "energyType_Fire",
47
+ "energyType_Grass",
48
+ "energyType_Lightning",
49
+ "energyType_Metal",
50
+ "energyType_Psychic",
51
+ "energyType_Water",
52
+ "energyType_nan",
53
+ "cardType_Energy",
54
+ "cardType_Item",
55
+ "cardType_Pokemon",
56
+ "cardType_Stadium",
57
+ "cardType_Supporter",
58
+ "cardType_Tool",
59
+ "cardType_Trainer",
60
+ "cardType_nan",
61
+ "variant_1st Edition",
62
+ "variant_1st Edition Holofoil",
63
+ "variant_Holofoil",
64
+ "variant_Normal",
65
+ "variant_Reverse Holofoil",
66
+ "variant_Unlimited",
67
+ "variant_Unlimited Holofoil",
68
+ "variant_nan"
69
+ ]
70
+ }
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38b217807a8bf227beba2a74448010f2234742071f415f52ea9429915d37cd54
3
+ size 199132
network.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ """
5
+ Neural Network Classifier Architecture
6
+ """
7
+
8
+ class PricePredictor(nn.Module):
9
+
10
+ def __init__(self, input_size: int):
11
+ super(PricePredictor, self).__init__()
12
+ self.model = nn.Sequential(
13
+ nn.Linear(input_size, 256),
14
+ nn.ReLU(),
15
+ nn.Dropout(0.4),
16
+ nn.Linear(256, 128),
17
+ nn.ReLU(),
18
+ nn.Dropout(0.4),
19
+ nn.Linear(128, 1),
20
+ )
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.model(x)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ pandas
3
+ numpy
4
+ scikit-learn
5
+ safetensors
6
+ gradio