"""Gradio demo that classifies an image into the seven Litton landscape types."""

from io import BytesIO
from typing import Any, Literal

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms

# Class names, in the order of the model outputs that `predict` reads.
LABELS = [
    "Panoramic",
    "Feature",
    "Detail",
    "Enclosed",
    "Focal",
    "Ephemeral",
    "Canopied",
]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so this
# checkpoint must come from a trusted source.  The `.module` access presumably
# unwraps a DataParallel-style container saved with the model -- confirm.
model = torch.load(
    "Litton-7type-visual-landscape-model.pth", map_location=device, weights_only=False
).module
model.eval()

# Resize / crop to 224x224 and normalise per channel (the mean/std values match
# the commonly used ImageNet statistics).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


@spaces.GPU
def predict(image: Image.Image | None) -> Figure:
    """Classify *image* and return a bar chart of the class probabilities.

    Args:
        image: The uploaded picture, or ``None`` when the user clicked the
            button without uploading anything.

    Returns:
        A matplotlib figure with one bar per entry of ``LABELS``, in percent.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("請先上傳影像")

    rgb_image = image.convert("RGB")
    input_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)

    # Only the first len(LABELS) outputs are used; any further columns the
    # model may emit are ignored before the softmax.
    probs = F.softmax(logits[:, : len(LABELS)], dim=1).cpu()

    return draw_bar_chart(
        {
            "class": LABELS,
            # Plain Python floats keep the chart code free of tensor handling.
            "probs": (probs[0] * 100).tolist(),
        }
    )


def draw_bar_chart(data: dict[str, list[str | float]]) -> Figure:
    """Draw a labelled bar chart of class probabilities.

    Args:
        data: Mapping with ``"class"`` (bar labels) and ``"probs"`` (bar
            heights in percent, same length and order as the labels).

    Returns:
        The figure holding the chart.
    """
    classes = data["class"]
    # float() also accepts 0-dim tensors, so callers passing a tensor still work.
    probabilities = [float(p) for p in data["probs"]]

    # Build the figure without pyplot: figures made through plt.subplots() stay
    # registered in pyplot's global state until closed, which leaks memory in a
    # long-running server that draws one chart per request.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()

    ax.bar(classes, probabilities, color="skyblue")
    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")

    # Print the numeric value just above each bar.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")

    fig.tight_layout()
    return fig


def get_layout():
    """Build and return the Gradio ``Blocks`` app (not yet launched)."""
    css = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    """

    theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Source Sans Pro"),
    ).set(
        background_fill_primary="*neutral_950",  # main background (near black)
        button_primary_background_fill="*primary_500",  # button colour (orange)
        body_text_color="*neutral_200",  # text colour (light)
    )

    # NOTE(review): the widget nesting below was reconstructed from a copy of
    # this file whose indentation was lost -- confirm it matches the intended
    # layout.
    with gr.Blocks(css=css, theme=theme) as demo:
        with gr.Column():
            # NOTE(review): the HTML markup in these literals looks like it was
            # stripped (the CSS above styles .main-title, .reference and
            # .reference a, and the empty segments sit where tags would be).
            # Restore the tags from the canonical source; the text is kept
            # exactly as found.
            gr.HTML(
                value=(
                    "\nLitton7景觀分類模型\n"
                    "\n引用資料:"
                    ""
                    "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
                    ""
                    "\n"
                ),
            )

            with gr.Row(equal_height=True):
                image_input = gr.Image(label="上傳影像", type="pil")
                chart = gr.Plot(label="分類結果")

            start_button = gr.Button("開始分類", variant="primary")

            # NOTE(review): empty as found; presumably the .footer content was
            # lost together with the markup above -- confirm.
            gr.HTML(
                "",
            )

            start_button.click(
                fn=predict,
                inputs=image_input,
                outputs=chart,
            )

    return demo


if __name__ == "__main__":
    get_layout().queue().launch()