# SFM_Inference_Demo / inference.py
# Author: Anirudh Bhalekar
# Last commit: b7ab39c ("sliders")
import os
import random

import huggingface_hub
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from matplotlib import cm
from PIL import Image
from PIL import ImageFilter

import models_Facies
import models_Fault
from util.datasets import ThebeSet, P3DFaciesSet
from util.pos_embed import interpolate_pos_embed
# Checkpoint filenames of the fine-tuned models inside the "Ani24/SFM_Finetuned"
# Hugging Face repository (downloaded by `predict`).
HFACE_FAULTS = "checkpoint-24.pth"
HFACE_FACIES = "checkpoint-49.pth"

# Local dataset roots used by `random_sample`. The defaults are the author's
# machine-specific paths; set SFM_FAULT_DATA_PATH / SFM_FACIES_DATA_PATH to
# point at a local copy without editing this file.
FAULT_DATA_PATH = os.environ.get(
    "SFM_FAULT_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\Thebe_DATASET\\crossline_combined_data",
)
FACIES_DATA_PATH = os.environ.get(
    "SFM_FACIES_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\P3D_Vol_DATASET",
)
def _build_model(task, model_type):
    """Instantiate the architecture for *task* and download its fine-tuned checkpoint.

    Returns:
        ``(model, checkpoint_path)``.

    Raises:
        ValueError: if *task* is not 'Fault' or 'Facies'.
    """
    if task == 'Fault':
        model = models_Fault.__dict__[model_type](
            img_size=768,
            num_classes=1,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FAULTS, subfolder="ckpts-Tversky-Neut")
    elif task == 'Facies':
        model = models_Facies.__dict__[model_type](
            img_size=128,
            num_classes=6,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FACIES, subfolder="ckpts-RSVSFacies-P3D")
    else:
        raise ValueError(f"Task not configured yet: {task}")
    return model, checkpoint_path


def _load_finetuned_weights(model, checkpoint_path, device):
    """Load the ``'model'`` entry of a checkpoint file into *model* (non-strict).

    Classification-head tensors whose shape differs from the model's own are
    dropped first, then ``interpolate_pos_embed`` is applied (presumably to
    resize positional embeddings to the model's patch grid — confirm).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; acceptable
    # only because the checkpoint comes from the author's own Hub repository.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    for key in ('head.weight', 'head.bias'):
        if (
            key in checkpoint_model
            # The model may not define this key at all; indexing would raise KeyError.
            and key in state_dict
            and checkpoint_model[key].shape != state_dict[key].shape
        ):
            print(f"Removing key {key} from pretrained checkpoint")
            del checkpoint_model[key]
    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)


def _prediction_to_image(prediction):
    """Scale a prediction array by its maximum to [0, 255] and return an RGB PIL image."""
    peak = prediction.max()
    if peak > 0:
        scaled = prediction / peak
    else:
        # All-zero map (e.g. every pixel predicted as facies class 0): dividing
        # by the max would produce NaNs, so render it as black instead.
        scaled = np.zeros(prediction.shape, dtype=np.float64)
    return Image.fromarray((scaled * 255).astype(np.uint8)).convert("RGB")


def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16', device='cpu', hface=True, thresh=0.5):
    """Run a fine-tuned SFM model on a single seismic section.

    Args:
        seismic: Input section without a batch axis; a batch axis of size 1 is
            added before the forward pass. Presumably (1, H, W) with H = W equal
            to the model's ``img_size`` (768 for 'Fault', 128 for 'Facies') —
            confirm against callers.
        task: 'Fault' (1-class model, sigmoid fault probability) or 'Facies'
            (6-class model, argmax class index per pixel).
        model_type: Name of a constructor in ``models_Fault`` / ``models_Facies``.
        device: Torch device the model and input are placed on.
        hface: Currently unused — checkpoints are always fetched from the
            Hugging Face Hub. Kept for API compatibility.
        thresh: Currently unused; kept for API compatibility.

    Returns:
        ``(output_image, output)``: ``output`` is a 2-D NumPy array (fault
        probabilities or facies class indices) and ``output_image`` is an RGB
        PIL image of ``output`` scaled by its maximum.

    Raises:
        ValueError: if *task* is not configured.
    """
    model, checkpoint_path = _build_model(task, model_type)
    model.to(device)
    _load_finetuned_weights(model, checkpoint_path, device)
    # Inference mode: without this, dropout / stochastic depth
    # (drop_path_rate=0.1 above) stay active and predictions vary per call.
    model.eval()

    print("Seismic data shape:", seismic.shape)
    with torch.no_grad():
        # Add the batch axis and make sure the input lives on the model's device.
        output = model(seismic.unsqueeze(0).to(device))
    output = output.squeeze(0)

    if task == 'Fault':
        # Single output channel: take channel 0 after the sigmoid.
        output = torch.sigmoid(output).cpu().numpy()[0, :, :]
    else:  # 'Facies' — other tasks were rejected by _build_model.
        output = output.argmax(dim=0).cpu().numpy()

    return _prediction_to_image(output), output
def random_sample(task='Fault', data_path=None, batch_size=1, num_workers=0):
    """Draw one random seismic section from the dataset for *task*.

    Args:
        task: 'Fault' (ThebeSet 'test' split, [768, 768]) or 'Facies'
            (P3DFaciesSet 'train' split).
        data_path: Dataset root. ``None`` (the default) falls back to
            ``FAULT_DATA_PATH`` / ``FACIES_DATA_PATH``; previously any value
            passed here was silently ignored.
        batch_size: Currently unused; kept for API compatibility.
        num_workers: Currently unused; kept for API compatibility.

    Returns:
        ``(seis_image, seis)``: a PIL image of the min-max normalised section
        rendered with matplotlib's ``seismic`` colormap, and the raw tensor
        returned by the dataset.

    Raises:
        ValueError: if *task* is not configured.
    """
    if task == 'Fault':
        if data_path is None:
            data_path = FAULT_DATA_PATH
        dataset = ThebeSet(data_path, [768, 768], 'test')
    elif task == 'Facies':
        if data_path is None:
            data_path = FACIES_DATA_PATH
        dataset = P3DFaciesSet(data_path, mode='train')
    else:
        raise ValueError(f"Task not configured yet: {task}")

    index = random.randint(0, len(dataset) - 1)
    seis, _label = dataset[index]  # label is not needed for display

    # squeeze(0) drops a leading size-1 axis (assumed to be the channel axis — confirm).
    section = seis.detach().cpu().numpy().squeeze(0)
    lo, hi = section.min(), section.max()
    span = hi - lo
    if span > 0:
        section = (section - lo) / span  # Normalize to [0, 1] range
    else:
        # Constant section: 0/0 would give NaNs; use the colormap midpoint instead.
        section = np.full_like(section, 0.5, dtype=np.float64)
    seis_image = Image.fromarray(np.uint8(cm.seismic(section) * 255))  # Convert to PIL Image
    return seis_image, seis
def overlay_images(seismic_image: Image, prediction_image: Image, alpha = 0.5) -> Image:
    """Blend a prediction image over a seismic image.

    Both inputs are round-tripped through ``np.array(...).astype(np.uint8)`` and
    converted to RGBA. The prediction then gets a uniform alpha channel of
    ``int(255 * alpha)`` and is alpha-composited on top of the seismic image.
    Pillow's ``alpha_composite`` requires both images to have the same size.
    """
    def _as_rgba(img):
        return Image.fromarray(np.array(img).astype(np.uint8)).convert("RGBA")

    foreground = _as_rgba(prediction_image)
    background = _as_rgba(seismic_image)
    # One transparency level for the whole prediction layer.
    foreground.putalpha(int(255 * alpha))
    return Image.alpha_composite(background, foreground)
def post_process(processed_prediction_image: "Image.Image", prediction_image: "Image.Image", method: str = 'None', value = None) -> "Image.Image":
    """Apply the post-processing operation named by *method* to a prediction image.

    Args:
        processed_prediction_image: Image the operation is applied to; returned
            unchanged when *method* is 'None'.
        prediction_image: Currently unused; kept for API compatibility.
        method: One of 'None', 'Thresholding', 'Closing', 'Opening',
            'Canny Edge', 'Gaussian Smoothing', 'Hysteresis'.
        value: Parameter forwarded to the chosen operation (threshold level,
            filter size or blur radius). Not used by 'None' or 'Canny Edge'.

    Returns:
        The post-processed image.

    Raises:
        ValueError: if *method* is not one of the names above.
    """
    if method == 'None':
        return processed_prediction_image
    elif method == 'Thresholding':
        # apply_thresholding requires the threshold; omitting it raised TypeError.
        return apply_thresholding(processed_prediction_image, value)
    elif method == 'Closing':
        return apply_closing(processed_prediction_image, value)
    elif method == 'Opening':
        return apply_opening(processed_prediction_image, value)
    elif method == 'Canny Edge':
        return apply_canny_edge(processed_prediction_image, value)
    elif method == 'Gaussian Smoothing':
        return apply_gaussian_smoothing(processed_prediction_image, value)
    elif method == 'Hysteresis':
        return apply_hysteresis(processed_prediction_image, value)
    else:
        raise ValueError(f"Unknown post-processing method: {method}")
def apply_thresholding(image: "Image.Image", value: int) -> "Image.Image":
    """Binarize *image*: pixel values above *value* become 255, all others 0.

    The rule is applied through ``Image.point``, i.e. per band for
    multi-band images.
    """
    return image.point(lambda p: 255 if p > value else 0)
def apply_closing(image: Image, value: int) -> Image:
    """Morphological closing: dilate (max filter), then erode (min filter).

    *value* is the square window size passed to both filters.
    NOTE(review): Pillow's rank filters appear to reject even sizes — confirm
    that callers only pass odd values.
    """
    dilated = image.filter(ImageFilter.MaxFilter(size=value))
    closed = dilated.filter(ImageFilter.MinFilter(size=value))
    return closed
def apply_opening(image: Image, value: int) -> Image:
    """Morphological opening: erode (min filter), then dilate (max filter).

    *value* is the square window size passed to both filters.
    NOTE(review): Pillow's rank filters appear to reject even sizes — confirm
    that callers only pass odd values.
    """
    eroded = image.filter(ImageFilter.MinFilter(size=value))
    opened = eroded.filter(ImageFilter.MaxFilter(size=value))
    return opened
def apply_canny_edge(image: Image, value: int) -> Image:
    """Return an edge map computed with Pillow's ``FIND_EDGES`` kernel.

    Despite the name, this is not a Canny detector: it is a single fixed
    convolution kernel. *value* is accepted for parity with the other
    post-processing functions and is ignored.
    """
    edge_kernel = ImageFilter.FIND_EDGES
    return image.filter(edge_kernel)
def apply_gaussian_smoothing(image: Image, value: int) -> Image:
    """Blur *image* with Pillow's ``GaussianBlur`` using ``radius=value``."""
    blur = ImageFilter.GaussianBlur(radius=value)
    smoothed = image.filter(blur)
    return smoothed
def apply_hysteresis(image: Image, value: int) -> Image:
    """Placeholder for hysteresis thresholding.

    Currently, this is plain single-threshold binarization: pixel values above
    *value* map to 255 and all others to 0. There is no low threshold and no
    connectivity tracking.
    """
    def _binarize(level):
        if level > value:
            return 255
        return 0

    return image.point(_binarize)