Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torchvision import datasets | |
from torchvision.transforms import ToTensor | |
# Define model
class NeuralNetwork(nn.Module):
    """Fully connected MLP classifier over flattened inputs (MNIST by default).

    Architecture:
        Flatten -> Linear(in_features, hidden_features) -> ReLU
        -> Linear(hidden_features, hidden_features) -> ReLU
        -> Linear(hidden_features, num_classes)

    The defaults match the layer sizes of the bundled MNIST checkpoint, so
    ``NeuralNetwork()`` builds exactly the same modules (same ``state_dict``
    keys and shapes, same creation/initialisation order) as before.

    Args:
        in_features: Number of values per sample after flattening
            (28*28 for a 28x28 grayscale image).
        hidden_features: Width of each of the two hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        # nn.Flatten keeps dim 0 (batch) and flattens the rest:
        # (N, d1, d2, ...) -> (N, d1*d2*...).
        self.flatten = nn.Flatten()
        # Attribute name and layer order define the checkpoint format:
        # state_dict keys are linear_relu_stack.{0,2,4}.{weight,bias}.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits of shape (N, num_classes) for batch ``x``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
# Build the network and load the trained weights from the app's checkpoint file.
model = NeuralNetwork()
# map_location="cpu": the checkpoint may have been saved from a GPU session;
# without remapping, torch.load fails on CPU-only hosts. The model constructed
# above is on CPU, so the weights are loaded there as well.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("model_mnist_mlp.pth", map_location="cpu"))
# Switch to evaluation mode for inference.
model.eval()
import gradio as gr | |
from torchvision import transforms | |
def predict(image):
    """Classify a digit image and return per-class probabilities.

    Args:
        image: PIL image as delivered by the Gradio ``Image`` inputs of this
            app (``type="pil"``, ``image_mode="L"``, ``shape=(28, 28)``), or
            ``None`` when no image was supplied.

    Returns:
        dict mapping each class index (int, 0-9 for the MNIST model) to its
        softmax probability (float). Returns an empty dict when ``image`` is
        None, which leaves the ``gr.Label`` output empty instead of raising.
    """
    if image is None:
        # Nothing drawn/uploaded: ToTensor would raise on None.
        return {}
    # ToTensor: grayscale PIL (H, W) -> float tensor (1, H, W) scaled to [0, 1].
    # unsqueeze(0) adds an explicit batch axis -> (1, 1, H, W); the model's
    # Flatten then produces (1, H*W).
    tsr_image = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(tsr_image)
    # Softmax over the class logits of the single sample in the batch.
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    return dict(enumerate(probs.tolist()))
with gr.Blocks(
    css=".gradio-container {background:lightyellow;color:red;}",
    title="テスト",
) as demo:
    # Page heading. NOTE(review): yellow text on the lightyellow container
    # background set in css above has very low contrast -- consider a darker colour.
    gr.HTML(
        '<div style="font-size:12pt; text-align:center; color:yellow;">'
        "MNIST 分類器</div>"
    )
    with gr.Row():
        # Two alternative input sources, one per tab; both call predict().
        with gr.Tab("キャンバス"):
            # Hand-drawn digit, converted to a 28x28 grayscale PIL image.
            input_image1 = gr.Image(
                label="画像入力",
                source="canvas",
                type="pil",
                image_mode="L",
                shape=(28, 28),
                invert_colors=True,
            )
            send_btn1 = gr.Button("予測する")
        with gr.Tab("画像ファイル"):
            # Uploaded image file, converted the same way as the canvas input.
            input_image2 = gr.Image(
                label="画像入力",
                type="pil",
                image_mode="L",
                shape=(28, 28),
                invert_colors=True,
            )
            send_btn2 = gr.Button("予測する")
            gr.Examples(
                ["examples/example02.png", "examples/example04.png"],
                inputs=input_image2,
            )
        # Shared output: the top-5 class probabilities returned by predict().
        output_label = gr.Label(label="予測確率", num_top_classes=5)
    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)

# demo.queue(concurrency_count=3)
demo.launch()
### EOF ### |