File size: 1,611 Bytes
c90d799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import numpy as np
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import nibabel as nib  # For loading .nii files
from PIL import Image  # For loading .jpg and .jpeg files

# Function to preprocess images based on their file format
def preprocess_image(image_path):
    """Load an image file and return it as a float tensor divided by 255.

    The loader is selected from the case-insensitive file name suffix:

    * ``.nii`` / ``.nii.gz`` -- read with nibabel. A 3-D array gets a leading
      singleton dimension (``[1, d0, d1, d2]``); arrays of any other rank
      are returned with the shape they were loaded with.
    * ``.jpg`` / ``.jpeg`` -- read with PIL, converted to RGB, resized to
      224x224 and rearranged to ``[C, H, W]`` (i.e. ``[3, 224, 224]``).

    Args:
        image_path: Path to the image file.

    Returns:
        A float32 ``torch.Tensor`` whose values have been divided by 255.

    Raises:
        ValueError: If the file suffix is not one of the supported formats.
    """
    stem, ext = os.path.splitext(image_path)
    ext = ext.lower()

    # splitext only returns the last suffix, so "scan.nii.gz" yields ".gz".
    # Recombine the compound suffix so gzip-compressed NIfTI files are
    # recognised instead of being rejected as an unsupported ".gz" file.
    if ext == '.gz' and stem.lower().endswith('.nii'):
        ext = '.nii.gz'

    # Case 1: NIfTI volumes (plain or gzip-compressed)
    if ext in ('.nii', '.nii.gz'):
        nii_image = nib.load(image_path)
        image_data = nii_image.get_fdata()

        image_tensor = torch.tensor(image_data).float()

        # A 3-D volume has no explicit channel axis; add one at the front.
        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)

    # Case 2: JPEG images
    elif ext in ('.jpg', '.jpeg'):
        # The context manager releases the underlying file handle as soon
        # as the pixel data has been converted and resized.
        with Image.open(image_path) as raw:
            # 224x224 is the input size expected by ViT.
            img = raw.convert('RGB').resize((224, 224))

        # PIL yields [H, W, C]; rearrange to [C, H, W].
        img_np = np.array(img)
        image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()

    else:
        raise ValueError(f"Unsupported file format: {ext}")

    # Scale by 1/255. For 8-bit JPEG data this maps values into [0, 1].
    # NOTE(review): NIfTI intensities are not necessarily in [0, 255], so the
    # same scaling does not guarantee a [0, 1] range there -- confirm intent.
    image_tensor /= 255.0

    return image_tensor