# Hugging Face Spaces page status: Build error
import os
import random

import huggingface_hub
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from matplotlib import cm
from PIL import Image
from PIL import ImageFilter

import models_Facies
import models_Fault
from util.datasets import ThebeSet, P3DFaciesSet
from util.pos_embed import interpolate_pos_embed
# Checkpoint file names inside the "Ani24/SFM_Finetuned" Hugging Face repo,
# passed to hf_hub_download by predict().
HFACE_FAULTS = "checkpoint-24.pth"
HFACE_FACIES = "checkpoint-49.pth"

# Local dataset roots used by random_sample(). The fallbacks are
# machine-specific Windows paths, so each can be overridden through an
# environment variable of the same name when running on any other host.
FAULT_DATA_PATH = os.environ.get(
    "FAULT_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\Thebe_DATASET\\crossline_combined_data",
)
FACIES_DATA_PATH = os.environ.get(
    "FACIES_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\P3D_Vol_DATASET",
)
def _build_model(task, model_type):
    """Instantiate the network for *task* and download its fine-tuned checkpoint.

    Returns:
        ``(model, checkpoint_path)``.

    Raises:
        ValueError: If ``task`` is neither ``'Fault'`` nor ``'Facies'``.
    """
    if task == 'Fault':
        model = models_Fault.__dict__[model_type](
            img_size=768,
            num_classes=1,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(
            repo_id="Ani24/SFM_Finetuned",
            filename=HFACE_FAULTS,
            subfolder="ckpts-Tversky-Neut",
        )
    elif task == 'Facies':
        model = models_Facies.__dict__[model_type](
            img_size=128,
            num_classes=6,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(
            repo_id="Ani24/SFM_Finetuned",
            filename=HFACE_FACIES,
            subfolder="ckpts-RSVSFacies-P3D",
        )
    else:
        raise ValueError(f"Task not configured yet: {task}")
    return model, checkpoint_path


def _load_checkpoint(model, checkpoint_path, device):
    """Load the checkpoint's ``'model'`` weights into *model* (non-strict).

    Head parameters whose shapes disagree with *model* are dropped, and the
    positional embedding is adapted by ``interpolate_pos_embed`` before loading.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects. That is only
    # acceptable because the checkpoint comes from a fixed, trusted repo.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    for k in ('head.weight', 'head.bias'):
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)


def _to_image(output):
    """Scale a 2-D array by its maximum into 0-255 and return an RGB PIL image."""
    peak = output.max()
    # An all-zero map (e.g. every pixel predicted as facies class 0) would
    # otherwise divide by zero, producing NaNs and garbage uint8 pixels.
    if peak > 0:
        scaled = output / peak
    else:
        scaled = np.zeros(output.shape)
    return Image.fromarray((scaled * 255).astype(np.uint8)).convert("RGB")


def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16',
            device='cpu', hface=True, thresh=0.5):
    """Run a fine-tuned seismic model on one section and render the result.

    ``'Fault'`` builds ``models_Fault.<model_type>`` (img_size=768, one output
    channel) and applies a sigmoid to its output. ``'Facies'`` builds
    ``models_Facies.<model_type>`` (img_size=128, six classes) and takes the
    per-pixel argmax. The checkpoint is downloaded from the
    ``Ani24/SFM_Finetuned`` Hugging Face repo, and the model is rebuilt on
    every call.

    Args:
        seismic: One input section without a batch axis (one is added here).
            Presumably ``(1, H, W)`` to match ``in_chans=1``; TODO confirm
            against the dataset.
        task: ``'Fault'`` or ``'Facies'``.
        model_type: Name of a model constructor in ``models_Fault`` /
            ``models_Facies``.
        device: Torch device used for both the model and the input.
        hface: Unused; kept for API compatibility.
        thresh: Unused; kept for API compatibility.

    Returns:
        ``(image, output)``. ``output`` is a 2-D numpy array: fault
        probabilities or facies class ids. ``image`` is an RGB PIL rendering
        of ``output`` divided by its maximum.

    Raises:
        ValueError: If ``task`` is not recognised.
    """
    model, checkpoint_path = _build_model(task, model_type)
    model.to(device)
    _load_checkpoint(model, checkpoint_path, device)
    # Inference mode: turns off stochastic depth (drop_path_rate=0.1) and any
    # dropout, which otherwise make predictions differ from call to call.
    model.eval()

    print("Seismic data shape:", seismic.shape)
    with torch.no_grad():
        # The input must live on the same device as the model weights.
        output = model(seismic.to(device).unsqueeze(0)).squeeze(0)

    if task == 'Fault':
        # Single-channel logits -> probabilities; keep that channel as a 2-D map.
        output = torch.sigmoid(output).cpu().numpy()[0, :, :]
    else:
        # 'Facies' (any other task was already rejected by _build_model).
        output = output.argmax(dim=0).cpu().numpy()

    return _to_image(output), output
def random_sample(task='Fault', data_path=None, batch_size=1, num_workers=0):
    """Draw one random input section for *task* from its local dataset.

    ``'Fault'`` reads ``ThebeSet(data_path, [768, 768], 'test')``.
    ``'Facies'`` reads ``P3DFaciesSet(data_path, mode='train')``.

    Args:
        task: ``'Fault'`` or ``'Facies'``.
        data_path: Dataset root. ``None`` (the default) selects
            ``FAULT_DATA_PATH`` or ``FACIES_DATA_PATH`` for the task.
        batch_size: Unused; kept for API compatibility.
        num_workers: Unused; kept for API compatibility.

    Returns:
        ``(image, seis)``. ``seis`` is the dataset's input tensor, unchanged
        (the label is discarded). ``image`` is a PIL rendering of it, min-max
        scaled and coloured with matplotlib's ``seismic`` colormap.

    Raises:
        ValueError: If ``task`` is not recognised or the dataset is empty.
    """
    if task == 'Fault':
        root = FAULT_DATA_PATH if data_path is None else data_path
        dataset = ThebeSet(root, [768, 768], 'test')
    elif task == 'Facies':
        root = FACIES_DATA_PATH if data_path is None else data_path
        dataset = P3DFaciesSet(root, mode='train')
    else:
        raise ValueError(f"Task not configured yet: {task}")

    seis, _label = dataset[random.randrange(len(dataset))]

    section = seis.detach().cpu().numpy().squeeze(0)
    lo, hi = section.min(), section.max()
    span = hi - lo
    # A constant section would divide by zero; render it as all-zero instead.
    if span > 0:
        section = (section - lo) / span
    else:
        section = np.zeros(section.shape)
    seis_image = Image.fromarray(np.uint8(cm.seismic(section) * 255))
    return seis_image, seis
def overlay_images(seismic_image: "Image.Image", prediction_image: "Image.Image",
                   alpha=0.5) -> "Image.Image":
    """Blend a prediction map over a seismic image.

    Both inputs go through ``np.array(...).astype(np.uint8)`` and are
    converted to RGBA, so PIL images and array-likes are both accepted.

    Args:
        seismic_image: Background image.
        prediction_image: Foreground image; it is resized with nearest-neighbour
            sampling if its size differs from ``seismic_image``.
        alpha: Uniform opacity of the foreground, clamped to ``[0, 1]``.

    Returns:
        The RGBA composite.
    """
    base = Image.fromarray(np.array(seismic_image).astype(np.uint8)).convert("RGBA")
    top = Image.fromarray(np.array(prediction_image).astype(np.uint8)).convert("RGBA")
    if top.size != base.size:
        # alpha_composite requires equal sizes. Nearest-neighbour keeps
        # label/class colours from being blended into new values.
        top = top.resize(base.size, Image.NEAREST)
    # Clamp so putalpha always receives a valid 0-255 value.
    top.putalpha(int(255 * min(max(alpha, 0.0), 1.0)))
    return Image.alpha_composite(base, top)
def post_process(processed_prediction_image: "Image.Image", prediction_image: "Image.Image",
                 method: str = 'None', value=None) -> "Image.Image":
    """Apply one named post-processing step to a prediction image.

    Args:
        processed_prediction_image: Image the step is applied to.
        prediction_image: Not used by this function; kept for API compatibility.
        method: ``'None'`` (input returned unchanged), ``'Thresholding'``,
            ``'Closing'``, ``'Opening'``, ``'Canny Edge'``,
            ``'Gaussian Smoothing'`` or ``'Hysteresis'``.
        value: Parameter forwarded to the matching ``apply_*`` helper
            (threshold level, kernel size or blur radius).

    Returns:
        The processed image.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method == 'None':
        return processed_prediction_image
    if method == 'Thresholding':
        # apply_thresholding requires the threshold; it must be forwarded.
        return apply_thresholding(processed_prediction_image, value)
    if method == 'Closing':
        return apply_closing(processed_prediction_image, value)
    if method == 'Opening':
        return apply_opening(processed_prediction_image, value)
    if method == 'Canny Edge':
        return apply_canny_edge(processed_prediction_image, value)
    if method == 'Gaussian Smoothing':
        return apply_gaussian_smoothing(processed_prediction_image, value)
    if method == 'Hysteresis':
        return apply_hysteresis(processed_prediction_image, value)
    raise ValueError(f"Unknown post-processing method: {method}")
def apply_thresholding(image: "Image.Image", value: "int | float | None" = None) -> "Image.Image":
    """Binarise *image*: pixel values strictly above *value* become 255, others 0.

    Applied through ``Image.point``, i.e. per band for multi-band images.

    Args:
        image: Source image.
        value: Threshold on the 0-255 pixel scale. ``None`` falls back to
            mid-scale (127).

    Returns:
        The binarised image.
    """
    cutoff = 127 if value is None else value
    return image.point(lambda p: 255 if p > cutoff else 0)
def apply_closing(image: "Image.Image", value: int) -> "Image.Image":
    """Morphological closing: a max filter (dilation) followed by a min filter (erosion).

    This closes small dark gaps between bright regions.

    Args:
        image: Source image.
        value: Side length of the square kernel in pixels. PIL's rank filters
            need an odd integer; floats (e.g. from a UI slider) are truncated
            to ``int``.

    Returns:
        The closed image.
    """
    size = int(value)
    dilated = image.filter(ImageFilter.MaxFilter(size=size))
    return dilated.filter(ImageFilter.MinFilter(size=size))
def apply_opening(image: "Image.Image", value: int) -> "Image.Image":
    """Morphological opening: a min filter (erosion) followed by a max filter (dilation).

    This removes small bright specks while keeping larger bright regions.

    Args:
        image: Source image.
        value: Side length of the square kernel in pixels. PIL's rank filters
            need an odd integer; floats (e.g. from a UI slider) are truncated
            to ``int``.

    Returns:
        The opened image.
    """
    size = int(value)
    eroded = image.filter(ImageFilter.MinFilter(size=size))
    return eroded.filter(ImageFilter.MaxFilter(size=size))
def apply_canny_edge(image: Image, value: int) -> Image:
    """Highlight edges using PIL's built-in ``ImageFilter.FIND_EDGES`` kernel.

    NOTE(review): despite the name, this is not a Canny detector. It is one
    fixed convolution with no smoothing, non-maximum suppression or
    hysteresis. ``value`` is accepted so the signature matches the other
    post-processing helpers, but it is ignored.
    """
    edge_filter = ImageFilter.FIND_EDGES
    edges = image.filter(edge_filter)
    return edges
def apply_gaussian_smoothing(image: Image, value: int) -> Image:
    """Blur *image* with PIL's ``GaussianBlur`` using ``radius=value``.

    PIL documents this radius as the standard deviation of the Gaussian kernel.
    """
    blur = ImageFilter.GaussianBlur(radius=value)
    smoothed = image.filter(blur)
    return smoothed
def apply_hysteresis(image: Image, value: int) -> Image:
    """Binarise *image* at *value*: pixels strictly above it become 255, others 0.

    NOTE(review): this is a placeholder for hysteresis thresholding. It uses a
    single global threshold with no low threshold and no connectivity
    tracking, so it currently behaves like plain thresholding.
    """
    def _binarise(level):
        if level > value:
            return 255
        return 0

    return image.point(_binarise)