LH-Tech-AI commited on
Commit
e95a6c1
·
verified ·
1 Parent(s): d4c0600

Create use_with_UI.py

Browse files
Files changed (1) hide show
  1. use_with_UI.py +113 -0
use_with_UI.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from transformers import ResNetForImageClassification
8
+
9
+ # --- 1. UI Configuration ---
10
+ # 'centered' ensures the app doesn't stretch across massive screens
11
+ st.set_page_config(page_title="GyroScope Rotation Corrector", layout="centered", page_icon="🔄")
12
+
13
+ # --- 2. Model Caching ---
14
+ # @st.cache_resource prevents reloading the model every time the user interacts with the UI
15
+ @st.cache_resource
16
+ def load_model():
17
+ model = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
18
+ model.eval()
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ model.to(device)
21
+ return model, device
22
+
23
+ model, device = load_model()
24
+
25
+ # --- 3. Preprocessing & Logic ---
26
+ preprocess = transforms.Compose([
27
+ transforms.Resize(256),
28
+ transforms.CenterCrop(224),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
31
+ ])
32
+
33
+ ANGLES = [0, 90, 180, 270]
34
+
35
+ def predict_and_correct(img):
36
+ # Ensure image is RGB
37
+ img = img.convert("RGB")
38
+ tensor = preprocess(img).unsqueeze(0).to(device)
39
+
40
+ with torch.no_grad():
41
+ logits = model(pixel_values=tensor).logits
42
+ probs = torch.softmax(logits, dim=1)[0]
43
+ pred = probs.argmax().item()
44
+
45
+ detected = ANGLES[pred]
46
+ correction = (360 - detected) % 360
47
+
48
+ # Apply correction (PIL rotate is counter-clockwise)
49
+ corrected_img = img.rotate(correction, expand=True)
50
+
51
+ # Format probabilities for the UI
52
+ prob_dict = {f"{a}°": f"{p:.4f}" for a, p in zip(ANGLES, probs)}
53
+
54
+ return corrected_img, detected, correction, prob_dict
55
+
56
+ # --- 4. Frontend Layout ---
57
+ st.title("🔄 Auto Rotation Corrector")
58
+ st.markdown("Upload an image or provide a URL to automatically fix its orientation.")
59
+
60
+ st.divider()
61
+
62
+ # Input Selection
63
+ input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)
64
+
65
+ img = None
66
+
67
+ # Input Handling
68
+ if input_method == "Upload a File":
69
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
70
+ if uploaded_file:
71
+ img = Image.open(uploaded_file)
72
+ else:
73
+ url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg")
74
+ if url:
75
+ try:
76
+ response = requests.get(url, timeout=5)
77
+ img = Image.open(BytesIO(response.content))
78
+ except Exception as e:
79
+ st.error(f"Could not load image from URL. Error: {e}")
80
+
81
+ # Preview & Processing Section
82
+ if img:
83
+ st.divider()
84
+
85
+ # Use columns to keep the UI compact and side-by-side
86
+ col_left, col_right = st.columns(2)
87
+
88
+ with col_left:
89
+ st.subheader("Input Preview")
90
+ st.image(img, use_container_width=True)
91
+
92
+ # The primary action button
93
+ process_btn = st.button("✨ Correct Rotation", type="primary", use_container_width=True)
94
+
95
+ with col_right:
96
+ st.subheader("Output Preview")
97
+
98
+ if process_btn:
99
+ with st.spinner("Analyzing..."):
100
+ corrected_img, detected, correction, prob_dict = predict_and_correct(img)
101
+
102
+ # Show result
103
+ st.image(corrected_img, use_container_width=True)
104
+
105
+ # Show stats
106
+ st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")
107
+
108
+ # Hidden expander for clean UI, but available if the user wants details
109
+ with st.expander("📊 View Probability Details"):
110
+ st.json(prob_dict)
111
+ else:
112
+ # Placeholder container before the button is clicked
113
+ st.info("Waiting for processing... Click the button on the left to correct the rotation.")