import os
import json

import torch
import numpy as np
import pandas as pd

from pathlib import Path
from typing import Any, Dict, List, Optional

# FTTransformer comes from the optional rtdl package; keep the import guarded
# so the module can still be imported in environments where rtdl is missing.
try:
    from rtdl import FTTransformer
except ImportError:
    print("RTDL not available")
    FTTransformer = None


class GohanInference:
    def __init__(self, model_path: Optional[str] = None):
        """Initialize the Gohan inference model"""
        self.model_path = model_path or "epoch_030_p30_0.7736.pt"
        self.config_path = "configs/config.json"

        # The configuration is loaded first; it points at the encoder files
        # and the product master CSV used by the loaders below.
        self.config = self._load_config()

        self.model = self._load_model()
        self.encoders = self._load_encoders()
        self.product_master = self._load_product_master()

    def _load_config(self) -> Dict[str, Any]:
        """Load model configuration"""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_model(self):
        """Load the PyTorch model"""
        if FTTransformer is None:
            raise ImportError("RTDL is required for model inference")

        # The checkpoint is expected to contain the full pickled model (not a
        # state dict), so rtdl must be importable for torch.load to
        # reconstruct the FTTransformer instance.
        model = torch.load(self.model_path, map_location='cpu')
        model.eval()
        return model

    def _load_encoders(self) -> Dict[str, Any]:
        """Load JSON encoders"""
        encoders = {}
        encoder_config = self.config['encoders']

        # 'product_master' points at a CSV rather than a JSON encoder, so it
        # is skipped here and handled by _load_product_master().
        for key, file_path in encoder_config.items():
            if key != 'product_master':
                with open(file_path, 'r', encoding='utf-8') as f:
                    encoders[key] = json.load(f)

        return encoders

    def _load_product_master(self) -> pd.DataFrame:
        """Load product master data"""
        # The CSV path is stored under the 'encoders' mapping (it is skipped
        # in _load_encoders above); fall back to a top-level key if present.
        path = (self.config['encoders'].get('product_master')
                or self.config.get('product_master'))
        return pd.read_csv(path, encoding='utf-8-sig')

    def predict(self, input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Make predictions"""
        # Not implemented yet; raise instead of silently returning None so
        # the declared return type is not violated.
        raise NotImplementedError("GohanInference.predict is not implemented yet")
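        # A possible pipeline (an assumption, not the original implementation):
        # encode the categorical fields of input_data with self.encoders,
        # stack the numerical fields into tensors, run self.model inside
        # torch.no_grad(), then join the top-scoring items back to
        # self.product_master and return them as a list of dicts.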


def load_model(model_path: Optional[str] = None) -> GohanInference:
    """Load function for Hugging Face"""
    return GohanInference(model_path)


def predict(model: GohanInference, inputs: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Prediction function for Hugging Face"""
    return model.predict(inputs)
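

# Minimal local usage sketch. The input payload below is hypothetical (the
# real feature names are defined by the config/encoders), and the call is
# wrapped because GohanInference.predict is still a stub.
if __name__ == "__main__":
    gohan = load_model()
    sample_input = {"example_field": "example_value"}  # hypothetical input
    try:
        print(predict(gohan, sample_input))
    except NotImplementedError as exc:
        print(f"Prediction not available yet: {exc}")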