"""Main module."""
from functools import partial
from pathlib import Path
from typing import Tuple, Union
import mrcfile
import numpy as np
import torch
[docs]
def load_mrc(
load_path: Union[str, Path],
as_tensor: bool = False,
get_voxel_size: bool = False,
) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], np.ndarray]]:
"""Loads data from an MRC file and optionally retrieves voxel size.
Args:
load_path (Union[str, Path]): Path to the MRC file.
as_tensor (bool, optional): If True, returns tensor (default=False).
get_voxel_size (bool, optional): If True, returns voxel size (default=False).
Returns:
Tuple[Union[np.ndarray, torch.Tensor], np.ndarray]: The MRC data and voxel size.
Raises:
FileNotFoundError: If the input file does not exist.
"""
if not Path(load_path).is_file():
raise FileNotFoundError(f"Error! {load_path} does not exist.")
with mrcfile.open(load_path, permissive=True) as mrc:
voxel_size = np.array(
[mrc.voxel_size.z, mrc.voxel_size.y, mrc.voxel_size.x], dtype=np.float32
)
data = mrc.data
data = torch.from_numpy(data) if as_tensor else data
return (data, voxel_size) if get_voxel_size else data
[docs]
def save_mrc(
save_path: Union[str, Path],
save_data: Union[np.ndarray, torch.Tensor],
voxel_size: float = 1.0,
) -> None:
"""Saves data to an MRC file.
Args:
save_path (Union[str, Path]): Path to save the MRC file.
save_data (Union[np.ndarray, torch.Tensor]): Data to save.
voxel_size (float, optional): Voxel size of the data (default=1.0).
"""
if isinstance(save_data, torch.Tensor):
save_data = save_data.detach().cpu().numpy()
with mrcfile.new(save_path, overwrite=True) as mrc:
mrc.set_data(save_data)
mrc.voxel_size = (voxel_size,) * 3
[docs]
class FourierCrop:
"""Enables downsampling and other operations based on Fourier domain cropping.
It supports 2D tensors in the BCHW format and 3D tensors in the BCDHW format.
"""
def __init__(
self,
pad_mode: int = 0,
dim: Tuple = (-3, -2, -1),
epsilon: float = 1e-6,
) -> None:
"""Initializes with specified padding mode, dimensions, and epsilon value.
Args:
pad_mode (int, optional): Determines the cropping function to use.
dim (Tuple, optional):
Dimensions over which to perform operations (default=(-3, -2, -1)).
epsilon (float, optional):
Small value to avoid division by zero (default=1e-6).
"""
super().__init__()
self.crop_func = {
0: self.crop_center,
1: self.pad_center,
2: self.crop_center_pad,
}[pad_mode]
self.norm_func = partial(self.norm, dim=dim, epsilon=epsilon)
self.fft_func = partial(self.fft, dim=dim)
self.ifft_func = partial(self.ifft, dim=dim)
[docs]
@staticmethod
def crop_center(x: torch.Tensor, bin_factor: int = 2) -> torch.Tensor:
"""Crops the central region of a tensor based on a specified bin_factor factor.
Args:
x (torch.Tensor):
Input 2D tensors in the BCHW format or 3D tensors in the BCDHW format.
bin_factor (int, optional):
Factor determining the size of the cropped region (default=2).
Returns:
torch.Tensor: Cropped tensor.
"""
input_shape = x.shape[2:]
if len(input_shape) not in [2, 3]:
raise ValueError("Unsupported dimension. Supported values are 2 or 3.")
input_center = [s // 2 for s in input_shape]
target_center = [s // (2 * bin_factor) for s in input_shape]
crop_slice = tuple(slice(i - t, i + t) for i, t in zip(input_center, target_center))
return x[(..., *crop_slice)]
[docs]
@staticmethod
def crop_center_pad(x: torch.Tensor, bin_factor: int = 2) -> torch.Tensor:
"""Crops the central region of a tensor and pads it back to its original size.
Args:
x (torch.Tensor):
Input 2D tensors in the BCHW format or 3D tensors in the BCDHW format.
bin_factor (int, optional):
Factor determining the size of the cropped region (default=2).
Returns:
torch.Tensor: Cropped and padded tensor.
"""
x_pad = torch.zeros_like(x)
x_crop = FourierCrop.crop_center(x, bin_factor=bin_factor)
input_shape = x_pad.shape[2:]
input_center = [s // 2 for s in input_shape]
target_center = [s // (2 * bin_factor) for s in input_shape]
crop_slice = tuple(slice(i - t, i + t) for i, t in zip(input_center, target_center))
x_pad[(..., *crop_slice)] = x_crop
return x_pad
[docs]
@staticmethod
def pad_center(x: torch.Tensor, bin_factor: int = 2) -> torch.Tensor:
"""Centers the original tensor within the new padded tensor."""
pad_shape = [s * 2 for s in x.shape[2:]]
pad_shape = list(x.shape[:2]) + pad_shape
x_pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
input_shape = x_pad.shape[2:]
input_center = [s // 2 for s in input_shape]
target_center = [s // (2 * bin_factor) for s in input_shape]
crop_slice = tuple(slice(i - t, i + t) for i, t in zip(input_center, target_center))
x_pad[(..., *crop_slice)] = x
return x_pad
[docs]
@staticmethod
def norm(x: torch.Tensor, dim: Tuple = (-3, -2, -1), epsilon: float = 1e-6) -> torch.Tensor:
"""Normalizes a tensor by its mean and standard deviation."""
mean = x.mean(dim=dim, keepdim=True)
std = x.std(dim=dim, keepdim=True)
return (x - mean) / (std + epsilon)
[docs]
@staticmethod
def fft(x: torch.Tensor, dim: Tuple = (-3, -2, -1), norm: str = "ortho") -> torch.Tensor:
"""Applies 3D Fast Fourier Transform (FFT) to input data."""
return torch.fft.fftshift(torch.fft.fftn(x, dim=dim, norm=norm), dim=dim)
[docs]
@staticmethod
def ifft(x: torch.Tensor, dim: Tuple = (-3, -2, -1), norm: str = "ortho") -> torch.Tensor:
"""Applies Inverse Fast Fourier Transform (IFFT) to input data."""
return torch.fft.ifftn(torch.fft.ifftshift(x, dim=dim), dim=dim, norm=norm)
def __call__(
self,
x: torch.Tensor,
bin_factor: int = 2,
norm_flag: bool = False,
) -> torch.Tensor:
"""Applies Fourier transform, crop, and inverse transform."""
if norm_flag:
x = self.norm_func(x)
x = self.fft_func(x)
x = self.crop_func(x, bin_factor)
x = self.ifft_func(x).real
x = self.norm_func(x)
else:
x = self.fft_func(x)
x = self.crop_func(x, bin_factor)
x = self.ifft_func(x).real
return x