# KN2024DockerFinal / pages / ⭐ Thoracic Classification.py
# Uploaded by datnguyentien204 ("Upload 338 files", commit 8e0b903, verified).
import streamlit as st
from PIL import Image
import cv2
import pydicom
import numpy as np
from streamlit_image_zoom import image_zoom
import time
import pandas as pd
import os
import subprocess
import sys
# Install this page's optional packages when they are missing from the
# environment (torchmcubes is built from its GitHub repository).
try:
    import torchmcubes
    import torch
    import torchvision
    import fpdf
except ImportError:
    # Install with the interpreter that runs this app: a bare "pip" on PATH may
    # belong to a different Python installation, leaving this process without
    # the packages.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'git+https://github.com/tatsy/torchmcubes.git'])
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'fpdf'])
    # Packages installed after start-up may be invisible to the import system
    # until its finder caches are cleared; the import below needs them.
    import importlib
    importlib.invalidate_caches()
from fpdf import FPDF
############### Import PATH
# Make the sibling chestXray14/ directory importable (presumably needed by the
# chestXray14.test module's own imports). Streamlit re-executes this page on
# every interaction, so add the entry only once instead of growing sys.path.
script_dir = os.path.dirname(os.path.abspath(__file__))
chestXray14_path = os.path.normpath(os.path.join(script_dir, '..', 'chestXray14'))
if chestXray14_path not in sys.path:
    sys.path.append(chestXray14_path)
def convert_dcm_to_png(input_image_path, output_image_path='a.png'):
    """Decode a DICOM file and write its pixels to an 8-bit PNG.

    Pixel values are min-max stretched to 0..255 with ``cv2.normalize``.
    MONOCHROME1 images (where low values mean white) are inverted so the PNG
    looks like an ordinary MONOCHROME2 radiograph.

    Args:
        input_image_path: Path or binary file object (e.g. a Streamlit
            UploadedFile) accepted by ``pydicom.dcmread``.
        output_image_path: Destination PNG; overwritten on every call.

    Returns:
        ``output_image_path``. Callers that ignore the return value are
        unaffected.
    """
    # Deliberately not cached: the useful output is the file written below.
    # With @st.cache_resource a cache hit skipped the write, so a page showing
    # several DICOMs (or another session's upload) displayed whatever image had
    # been written to the shared output file last.
    if hasattr(input_image_path, 'seek'):
        # Rewind in case this stream has already been read.
        input_image_path.seek(0)
    ds = pydicom.dcmread(input_image_path)
    img = ds.pixel_array
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    if getattr(ds, 'PhotometricInterpretation', '') == 'MONOCHROME1':
        img = 255 - img
    cv2.imwrite(output_image_path, img)
    return output_image_path
from chestXray14.test import process_image
def predictAll(image_path):
    """Run ``process_image`` on one image and log what it produced.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        The ``(result, segment_result_path, cam_result_path)`` tuple returned
        by ``process_image``. It used to be discarded; returning it lets callers
        use the real output paths, and existing callers that ignore the return
        value are unaffected.
    """
    result, segment_result_path, cam_result_path = process_image(image_path)
    print("Prediction Results:", result)
    print(f"Segmentation Result Saved to: {segment_result_path}")
    print(f"CAM Result Saved to: {cam_result_path}")
    return result, segment_result_path, cam_result_path
def load_report():
df = pd.read_csv('pages/images/prediction_results.csv')
df.columns = ['Bệnh lý', 'Xác suất']
translation_dict = {
'Infiltration': 'Thâm nhiễm',
'Nodule': 'Nốt',
'Pleural Thickening': 'Dày màng phổi',
'Cardiomegaly': 'Tim to',
'Effusion': 'Tràn dịch',
'Pneumonia': 'Viêm phổi',
'Atelectasis': 'Xẹp phổi',
'Mass': 'Khối u',
'Fibrosis': 'Xơ phổi',
'Pneumothorax': 'Tràn khí màng phổi'
}
df['Bệnh lý'] = df['Bệnh lý'].map(translation_dict)
df['Xác suất'] = df['Xác suất'].astype(float) * 100
df['Xác suất'] = df['Xác suất'].round(2).astype(str) + '%'
def highlight_rows(row):
if row.name == 0:
return ['background-color: darkred; color: white'] * len(row)
if row.name == 1:
return ['background-color: darkblue; color: white'] * len(row)
if row.name == 2:
return ['background-color: lightblue; color: white'] * len(row)
else:
return [''] * len(row)
df_styled = df.style.apply(highlight_rows, axis=1).set_table_styles(
[{'selector': 'thead th', 'props': [('background-color', '#d3d3d3')]}]
)
return df_styled
st.markdown("<h1 style='text-align: center;'>Welcome to Thoracic Classification 🎈</h1>", unsafe_allow_html=True)

# Sidebar: scan uploader plus a short usage guide (in Vietnamese).
with st.sidebar:
    st.markdown("## Upload your scans")
    # "dcm" is the usual DICOM extension. The viewer already handles it, but
    # st.file_uploader rejects any extension not listed here.
    uploaded_files = st.file_uploader(
        "Choose scans...",
        type=["jpg", "jpeg", "png", "dicom", "dcm"],
        accept_multiple_files=True,
    )
    # "Hướng dẫn" means "Instructions".
    with st.expander("Hướng dẫn"):
        st.markdown("1. Tải lên ảnh Scan của bạn bằng cách ấn vào **Browse files** hoặc có thể **Kéo và thả** file ảnh của bạn vào phần browse files. Các định dạng cho phép bao gồm **DICOM, PNG, JPG, JPEG**, các định dạng khác cần phải chuyển về các định dạng được chấp nhận.")
        st.markdown("2. Sau đó ảnh sẽ tự được mở lên")
        st.markdown("3. Để phóng to ảnh, bạn chuyển chuột trái vào trong ảnh, dùng lăn chuột để thực hiện phóng to- thu nhỏ ảnh")
        st.markdown("4. Để kéo xuống xem ảnh phía dưới, bạn di chuột ra ngoài vùng ảnh và dùng lăn chuột cuộn trang như bình thường.")
# Main area: preview each uploaded scan in a zoomable viewer. Each scan is also
# written to 'temp_image.png', which "Predict All Scans" reads; with several
# uploads only the last one remains in that file.
status_images = False
col_1, col_2 = st.columns([7, 5.5])
with col_1:
    if uploaded_files:
        for uploaded_file in uploaded_files:
            file_type = uploaded_file.name.split('.')[-1].lower()
            if file_type in ["jpg", "jpeg", "png"]:
                img = Image.open(uploaded_file)
                try:
                    img.save('temp_image.png')
                except OSError:
                    # PNG cannot store some modes (e.g. CMYK JPEGs) and PIL
                    # raises OSError for them; save an RGB copy instead.
                    img = img.convert('RGB')
                    img.save('temp_image.png')
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                image_zoom(img, mode="both")
                st.markdown("</div>", unsafe_allow_html=True)
                status_images = True
            elif file_type in ["dicom", "dcm"]:
                # The DICOM is first rendered to 'a.png', then loaded as RGB.
                convert_dcm_to_png(uploaded_file)
                img = Image.open('a.png').convert('RGB')
                img.save('temp_image.png')
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                width, height = img.size
                image_zoom(img, mode="both", size=(width // 4, height // 4), keep_aspect_ratio=True, zoom_factor=4.0, increment=0.2)
                st.markdown("</div>", unsafe_allow_html=True)
                status_images = True
    else:
        st.info("Please upload some scans to view them.")
############ CREATE PDF
import io
def generate_pdf(name, age, gender, address, phone):
    """Build the patient report PDF and return it as an in-memory buffer.

    The report contains:
    - the patient details passed in;
    - the disease probabilities from ``pages/images/prediction_results.csv``
      (the CSV the on-page report table is built from), highest first;
    - the CAM image ``pages/images/cam_result.png``, when it exists.

    Args:
        name, age, gender, address, phone: Patient details, printed with
            ``str()``. Characters outside Latin-1 are replaced by ``?``
            because the built-in Arial font cannot encode them.

    Returns:
        ``io.BytesIO`` positioned at offset 0 that holds the PDF bytes.
    """
    def pdf_text(value):
        # FPDF's core fonts only cover Latin-1; e.g. Vietnamese diacritics in
        # a name or address would otherwise make PDF generation raise.
        return str(value).encode('latin-1', 'replace').decode('latin-1')

    pdf = FPDF()
    pdf.add_page()

    # Title
    pdf.set_font("Arial", style='B', size=16)
    pdf.cell(200, 10, txt="Patient Report", ln=True, align='C')
    pdf.ln(10)

    # Patient details
    pdf.set_font("Arial", size=12)
    details = (("Name", name), ("Age", age), ("Gender", gender), ("Address", address), ("Phone", phone))
    for label, value in details:
        pdf.cell(200, 10, txt=pdf_text(f"{label}: {value}"), ln=True, align='L')
    pdf.ln(10)

    # Prediction results (previously hard-coded placeholder diseases).
    pdf.cell(200, 10, txt="Predicted Disease Probabilities:", ln=True, align='L')
    pdf.ln(10)
    csv_path = 'pages/images/prediction_results.csv'
    if os.path.exists(csv_path):
        results = pd.read_csv(csv_path)
        # First column: finding name; second: probability (shown as a percentage).
        rows = sorted(
            zip(results.iloc[:, 0], results.iloc[:, 1].astype(float)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for disease, probability in rows:
            percent = round(probability * 100, 2)
            pdf.cell(200, 10, txt=pdf_text(f"{disease}: {percent}%"), ln=True, align='L')
    else:
        pdf.cell(200, 10, txt="No prediction results available.", ln=True, align='L')

    # Class Activation Map, if it has been produced.
    pdf.ln(10)
    pdf.cell(200, 10, txt="Class Activation Map (CAM):", ln=True, align='L')
    image_path = 'pages/images/cam_result.png'
    if os.path.exists(image_path):
        pdf.image(image_path, x=10, y=pdf.get_y(), w=100)

    # PyFPDF 1.7 (the "fpdf" package) treats a non-empty first argument as a
    # file name, so it cannot write into a BytesIO. Its dest='S' returns a
    # Latin-1 str, while fpdf2 returns a bytearray; handle both.
    raw = pdf.output(dest='S')
    if isinstance(raw, str):
        raw = raw.encode('latin-1')
    pdf_buffer = io.BytesIO(bytes(raw))
    pdf_buffer.seek(0)
    return pdf_buffer
def download_report(name, age, gender, address, phone):
    """Render an on-page patient report with a download button for the results.

    The report shows:
    - the given patient details;
    - the styled prediction table from ``load_report()``;
    - the CAM image ``pages/images/cam_result.png``.

    It ends with a button that downloads
    ``pages/images/prediction_results.csv``.
    """
    st.markdown("<h2 style='text-align: center;'>Patient Report</h2>", unsafe_allow_html=True)
    st.write(f"**Name:** {name}")
    st.write(f"**Age:** {age}")
    st.write(f"**Gender:** {gender}")
    st.write(f"**Address:** {address}")
    st.write(f"**Phone:** {phone}")

    # Prediction table
    df_styled = load_report()
    st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
    st.write(df_styled.to_html(), unsafe_allow_html=True)

    # CAM visualization
    img = Image.open('pages/images/cam_result.png').convert('RGB')
    st.image(img, caption="Class Activation Map (CAM) Visualization", use_column_width=True)

    # A plain <a href='pages/images/...'> link points at a path the Streamlit
    # server does not serve, so the browser cannot fetch the file. Hand the
    # bytes to st.download_button instead.
    with open('pages/images/prediction_results.csv', 'rb') as csv_file:
        csv_bytes = csv_file.read()
    st.download_button(
        label="Click here to download the report",
        data=csv_bytes,
        file_name="prediction_results.csv",
        mime="text/csv",
        key="download_prediction_results_csv",
    )
def _show_result_image(image_path):
    """Display a result image (segmentation / CAM) in a zoomable viewer.

    Shows a warning instead of raising when the file does not exist,
    typically because "Predict All Scans" has not been run yet.
    """
    if not os.path.exists(image_path):
        st.warning("Result not found - please run 'Predict All Scans' first.")
        return
    img = Image.open(image_path).convert('RGB')
    st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
    width, height = img.size
    image_zoom(img, mode="both", size=(width // 4, height // 4), keep_aspect_ratio=True, zoom_factor=4.0, increment=0.2)


if status_images:
    # Action buttons; only shown once at least one scan has been loaded.
    with col_2:
        st.markdown("<h2 style='text-align: center;'>Function</h2>", unsafe_allow_html=True)
        btn_predictAll_Scans = st.button("Predict All Scans")
        btn_CAM_Visualization = st.button("CAM Visualization")
        btn_Segment_Lung = st.button("Segmentation Visualization for Lung")
        btn_View_Report = st.button("View Report")
        btn_Download_Report = st.button("Download Report")
        if btn_predictAll_Scans:
            start_time = time.perf_counter()
            predictAll('temp_image.png')
            elapsed_time = time.perf_counter() - start_time
            st.success(f"Predicted all Scans success - ⏳ {int(elapsed_time)} seconds. You can use CAM, Segmentation, View, and Download Report", icon="✅")

    st.divider()
    col_3, col_4, col_5 = st.columns([4, 7.5, 6])
    with col_4:
        if btn_Segment_Lung:
            st.markdown("<h2 style='text-align: center;color:red;margin-left: 100px'>Segmentation Image for Lung </h2>", unsafe_allow_html=True)
            _show_result_image('pages/images/segment_result.png')
        if btn_CAM_Visualization:
            # "color" (the original "text-color" is not a CSS property), matching
            # the segmentation heading.
            st.markdown("<h2 style='text-align: center;color:red;margin-left: 100px'>Class Activation Map(CAM) Visualization </h2>", unsafe_allow_html=True)
            _show_result_image('pages/images/cam_result.png')
        # st.button is True only on the run triggered by its click, and pressing
        # the form's Submit starts another run. Remember that the form was
        # requested, otherwise it vanishes and `submit` can never be True.
        if btn_Download_Report:
            st.session_state["show_patient_form"] = True
        if st.session_state.get("show_patient_form", False):
            with st.form("patient_info_form"):
                st.write("Please provide patient details before downloading the report:")
                name = st.text_input("Name")
                age = st.number_input("Age", min_value=0, max_value=130)
                gender = st.selectbox("Gender", ["Male", "Female", "Other"])
                address = st.text_input("Address")
                phone = st.text_input("Phone")
                submit = st.form_submit_button("Submit")
            # Outside the form: st.download_button is not allowed inside st.form.
            if submit:
                pdf_buffer = generate_pdf(name, age, gender, address, phone)
                st.download_button(
                    label="Download Report",
                    data=pdf_buffer,
                    file_name="patient_report.pdf",
                    mime="application/pdf"
                )

    col_6, col_7, col_8 = st.columns([7.8, 4.5, 8])
    with col_7:
        if btn_View_Report:
            st.markdown("<h2 style='text-align: center;'>Prediction Report</h2>", unsafe_allow_html=True)
            if os.path.exists('pages/images/prediction_results.csv'):
                df_styled = load_report()
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                st.write(df_styled.to_html(), unsafe_allow_html=True)
            else:
                st.warning("Prediction results not found - please run 'Predict All Scans' first.")