from PIL import Image from io import BytesIO from torchvision import transforms from typing import Literal, Any import gradio as gr from matplotlib.figure import Figure import matplotlib.pyplot as plt import spaces import torch import torch.nn.functional as F LABELS = [ "Panoramic", "Feature", "Detail", "Enclosed", "Focal", "Ephemeral", "Canopied", ] device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = torch.load( "Litton-7type-visual-landscape-model.pth", map_location=device, weights_only=False ).module model.eval() preprocess = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) @spaces.GPU def predict(image: Image.Image) -> Figure: image = image.convert("RGB") input_tensor = preprocess(image).unsqueeze(0).to(device) with torch.no_grad(): logits = model(input_tensor) probs = F.softmax(logits[:, :7], dim=1).cpu() return draw_bar_chart( { "class": LABELS, "probs": probs[0] * 100, } ) def draw_bar_chart(data: dict[str, list[str | float]]): classes = data["class"] probabilities = data["probs"] fig, ax = plt.subplots(figsize=(8, 6)) ax.bar(classes, probabilities, color="skyblue") ax.set_xlabel("Class") ax.set_ylabel("Probability (%)") ax.set_title("Class Probability") for i, prob in enumerate(probabilities): ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom") fig.tight_layout() return fig def get_layout(): css = """ .main-title { font-size: 24px; font-weight: bold; text-align: center; margin-bottom: 20px; } .reference { text-align: center; font-size: 1.2em; color: #d1d5db; margin-bottom: 20px; } .reference a { color: #FB923C; text-decoration: none; } .reference a:hover { text-decoration: underline; color: #FB923C; } .title { border-bottom: 1px solid; } .footer { text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #ddd; color: #d1d5db; font-size: 14px; } """ theme = gr.themes.Base( primary_hue="orange", secondary_hue="orange", neutral_hue="gray", font=gr.themes.GoogleFont("Source Sans Pro"), ).set( background_fill_primary="*neutral_950", # 主背景色(深黑) button_primary_background_fill="*primary_500", # 按鈕顏色(橘色) body_text_color="*neutral_200", # 文字顏色(淺色) ) with gr.Blocks(css=css, theme=theme) as demo: with gr.Column(): gr.HTML( value=( '