abhiswain's picture
Update gradio_app.py
ffc1c91
import gradio as gr
import torch
import json
import src.config as CFG
from src.model import HNet
def classify(model, image, mapping):
image = torch.from_numpy(image).float()
image = image.permute(2, 0, 1).unsqueeze(0)
outputs = model(image)
_, preds = torch.max(outputs, 1)
return f"The predicted character is: {mapping[str(preds[0].item())]}"
def upload_and_clasify(image, option):
mapping, model = None, None
if option == "Digit":
if CFG.BEST_MODEL_DIGIT.exists():
model = HNet(num_classes=10)
model.load_state_dict(
torch.load(CFG.BEST_MODEL_DIGIT, map_location=CFG.DEVICE)
)
with open(CFG.INDEX_DIGIT, "r") as f:
mapping = json.load(f)
return classify(model, image, mapping)
else:
if CFG.BEST_MODEL_VYANJAN.exists():
model = HNet(num_classes=36)
model.load_state_dict(
torch.load(CFG.BEST_MODEL_VYANJAN, map_location=CFG.DEVICE)
)
with open(CFG.INDEX_VYNAJAN, "r") as f:
mapping = json.load(f)
return classify(model, image, mapping)
demo = gr.Interface(
fn=upload_and_clasify,
inputs=["image", gr.Dropdown(["Digit", "Vyanjan"])],
outputs="text",
)
demo.launch(debug=True, enable_queue=True)