Loss functions

A set of custom loss functions

MSELoss

 MSELoss (inp:Any, targ:Any)

L1Loss

 L1Loss (inp:Any, targ:Any)

SSIMLoss

 SSIMLoss (spatial_dims:int, data_range:float=1.0,
           kernel_type:monai.metrics.regression.KernelType|str=gaussian,
           win_size:int|collections.abc.Sequence[int]=11,
           kernel_sigma:float|collections.abc.Sequence[float]=1.5,
           k1:float=0.01, k2:float=0.03,
           reduction:monai.utils.enums.LossReduction|str=mean)

Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.

For more info, visit https://vicuesoft.com/glossary/term/ssim-ms-ssim/

SSIM reference paper: Wang, Zhou, et al. “Image quality assessment: from error visibility to structural similarity.” IEEE Transactions on Image Processing 13.4 (2004): 600-612.
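The following minimal NumPy sketch makes the SSIM definition concrete. It computes single-scale SSIM on one 2-D grayscale image, using a normalized Gaussian window and the constants C1 = (k1·data_range)² and C2 = (k2·data_range)² from Wang et al. It is not MONAI's implementation. The helper names are hypothetical, filtering uses no padding, and the loss is assumed to be 1 - SSIM, so identical images give 0.

```python
import numpy as np


def gaussian_window_1d(win_size=11, sigma=1.5):
    """Normalized 1-D Gaussian window (hypothetical helper)."""
    x = np.arange(win_size) - (win_size - 1) / 2.0
    g = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return g / g.sum()


def gaussian_filter_valid(img, win):
    """Separable 2-D Gaussian filtering with no padding ('valid' mode)."""
    rows = np.apply_along_axis(lambda r: np.convolve(r, win, mode="valid"), 1, img)
    return np.apply_along_axis(lambda c: np.convolve(c, win, mode="valid"), 0, rows)


def ssim(x, y, data_range=1.0, win_size=11, sigma=1.5, k1=0.01, k2=0.03):
    """Mean SSIM between two 2-D arrays, following Wang et al. (2004)."""
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    win = gaussian_window_1d(win_size, sigma)
    mu_x = gaussian_filter_valid(x, win)
    mu_y = gaussian_filter_valid(y, win)
    # local variances and covariance as E[xy] - E[x]E[y]
    var_x = gaussian_filter_valid(x * x, win) - mu_x ** 2
    var_y = gaussian_filter_valid(y * y, win) - mu_y ** 2
    cov_xy = gaussian_filter_valid(x * y, win) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return float(ssim_map.mean())


def ssim_loss(x, y, **kwargs):
    """Assumed loss form: 1 - SSIM."""
    return 1.0 - ssim(x, y, **kwargs)


rng = np.random.default_rng(0)
a = rng.random((64, 64))
b = rng.random((64, 64))
print(ssim_loss(a, a))  # identical images: 0.0
print(ssim_loss(a, b))  # unrelated noise: close to 1
```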

Combined Losses


CombinedLoss

 CombinedLoss (spatial_dims=2, mse_weight=0.33, mae_weight=0.33)

CombinedLoss computes a weighted combination of SSIM, MSE, and MAE losses.

This class combines three loss functions: Structural Similarity Index (SSIM), Mean Squared Error (MSE), and Mean Absolute Error (MAE). The MSE and MAE weights are set explicitly, and the SSIM weight is the remainder, 1 - mse_weight - mae_weight (0.34 with the defaults).

CombinedLoss reference paper: Shah, Z. H., Müller, M., Hammer, B., Huser, T., & Schenck, W. (2022, July). Impact of different loss functions on denoising of microscopic images. In 2022 International Joint Conference on Neural Networks (IJCNN) (pp. 1-10). IEEE.

Type Default Details
spatial_dims int 2 Number of spatial dimensions (2 for 2D images, 3 for 3D images)
mse_weight float 0.33 Weight for the MSE loss component
mae_weight float 0.33 Weight for the MAE loss component
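The weighting rule fits in a few lines. The sketch below is a hypothetical NumPy helper, not the CombinedLoss source. It assumes the caller supplies the SSIM term as a loss (1 - SSIM) through `ssim_loss_fn`, for example MONAI's SSIMLoss or a custom function, and gives that term the remaining weight 1 - mse_weight - mae_weight.

```python
import numpy as np


def combined_loss(pred, target, ssim_loss_fn, mse_weight=0.33, mae_weight=0.33):
    """Weighted SSIM + MSE + MAE loss; SSIM gets the remaining weight (assumed formula)."""
    ssim_weight = 1.0 - mse_weight - mae_weight
    if ssim_weight < -1e-9:
        raise ValueError("mse_weight + mae_weight must not exceed 1")
    diff = pred - target
    mse = float(np.mean(diff ** 2))
    mae = float(np.mean(np.abs(diff)))
    return ssim_weight * ssim_loss_fn(pred, target) + mse_weight * mse + mae_weight * mae


# hypothetical stand-in for an SSIM loss, used only to show the arithmetic
pred, target = np.zeros((8, 8)), np.ones((8, 8))
print(combined_loss(pred, target, lambda p, t: 0.5))  # 0.34*0.5 + 0.33*1 + 0.33*1, about 0.83
```

Because the SSIM weight is derived from the other two, the three weights always sum to 1. Raising either pixel-wise weight therefore shifts emphasis away from structural similarity.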

MSSSIMLoss

 MSSSIMLoss (spatial_dims=2, window_size:int=8, sigma:float=1.5,
             reduction:str='mean', levels:int=3, weights=None)

Multi-Scale Structural Similarity (MSSSIM) Loss using MONAI’s SSIMLoss as the base.

Type Default Details
spatial_dims int 2 Number of spatial dimensions (2 for 2D images, 3 for 3D images).
window_size int 8 Size of the Gaussian filter for SSIM.
sigma float 1.5 Standard deviation of the Gaussian filter.
reduction str mean Specifies the reduction to apply to the output (‘mean’, ‘sum’, or ‘none’).
levels int 3 Number of scales to use for MS-SSIM.
weights NoneType None Weights to apply to each scale. If None, default values are used.
import torch

msssim_loss = MSSSIMLoss(levels=3)
ssim_loss = SSIMLoss(2)
# .cuda() requires a CUDA-capable GPU; remove it to run on the CPU
output = torch.rand(10, 3, 64, 64).cuda()  # Example output
target = torch.rand(10, 3, 64, 64).cuda()  # Example target
loss = msssim_loss(output, target)
loss2 = ssim_loss(output, target)
print("ms-ssim: ", loss, '\nssim: ', loss2)
ms-ssim:  tensor(0.9686, device='cuda:0') 
ssim:  tensor(0.9949, device='cuda:0')
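This page does not say exactly how MSSSIMLoss combines the scales. One simple reading of "SSIMLoss as the base" is a weighted average of a single-scale loss over an image pyramid built by 2x2 average pooling. The hypothetical NumPy helper below sketches that reading. The equal default weights are an assumption. The original MS-SSIM of Wang, Simoncelli and Bovik (2003) works differently: it multiplies per-scale contrast-structure terms and uses luminance only at the coarsest scale.

```python
import numpy as np


def avg_pool2(img):
    """Downsample a 2-D array by 2 with 2x2 average pooling (odd edges are cropped)."""
    h, w = (img.shape[0] // 2) * 2, (img.shape[1] // 2) * 2
    img = img[:h, :w]
    return 0.25 * (img[0::2, 0::2] + img[1::2, 0::2] + img[0::2, 1::2] + img[1::2, 1::2])


def multiscale_loss(pred, target, loss_fn, levels=3, weights=None):
    """Weighted average of a single-scale loss over an image pyramid (hypothetical helper)."""
    if weights is None:
        weights = [1.0 / levels] * levels  # assumption: equal weight per scale
    if len(weights) != levels:
        raise ValueError("need one weight per level")
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    total = 0.0
    for level in range(levels):
        total += weights[level] * loss_fn(pred, target)
        if level < levels - 1:
            pred, target = avg_pool2(pred), avg_pool2(target)
    return float(total)


# usage with a plain MAE standing in for the single-scale SSIM loss
rng = np.random.default_rng(0)
a, b = rng.random((64, 64)), rng.random((64, 64))
mae = lambda p, t: float(np.mean(np.abs(p - t)))
print(multiscale_loss(a, b, mae, levels=3))
```

With `levels=3` and 64x64 inputs, the loss is evaluated at 64, 32 and 16 pixels. The window size must therefore still fit inside the coarsest scale.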

MSSSIML1Loss

 MSSSIML1Loss (spatial_dims=2, alpha:float=0.025, window_size:int=8,
               sigma:float=1.5, reduction:str='mean', levels:int=3,
               weights=None)

Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L1 Loss.

Reference paper: Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). Loss functions for image restoration with neural networks. IEEE Transactions on Computational Imaging, 3(1), 47-57.

Type Default Details
spatial_dims int 2 Number of spatial dimensions.
alpha float 0.025 Weighting factor between MS-SSIM and L1 loss.
window_size int 8 Size of the Gaussian filter for SSIM.
sigma float 1.5 Standard deviation of the Gaussian filter.
reduction str mean Specifies the reduction to apply to the output (‘mean’, ‘sum’, or ‘none’).
levels int 3 Number of scales to use for MS-SSIM.
weights NoneType None Weights to apply to each scale. If None, default values are used.
msssiml1_loss = MSSSIML1Loss(alpha=0.025, window_size=11, sigma=1.5, levels=3)
input_image = torch.randn(4, 1, 128, 128)  # Batch of 4 grayscale images (1 channel)
target_image = torch.randn(4, 1, 128, 128)

# Compute MSSSIM + Gaussian-weighted L1 loss
loss = msssiml1_loss(input_image, target_image)
loss2 = ssim_loss(input_image, target_image)  # SSIMLoss(2) instance from the previous example
print("ms-ssim: ", loss, '\nssim: ', loss2)
ms-ssim:  tensor(0.0250) 
ssim:  tensor(0.9955)
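Zhao et al. mix the two terms as L = α·L_MS-SSIM + (1 - α)·G·L_ℓ1, where a Gaussian window G weights the ℓ1 error. The NumPy sketch below implements that mix. It uses a hypothetical `gaussian_weighted_error` helper and takes the MS-SSIM loss as a precomputed number. This page does not state which term MSSSIML1Loss's `alpha` scales, so the assignment here is an assumption. Passing `p=2` gives the squared-error variant.

```python
import numpy as np


def gaussian_window_2d(win_size=11, sigma=1.5):
    """Normalized 2-D Gaussian window (weights sum to 1)."""
    x = np.arange(win_size) - (win_size - 1) / 2.0
    g = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    g2 = np.outer(g, g)
    return g2 / g2.sum()


def gaussian_weighted_error(pred, target, p=1, win_size=11, sigma=1.5):
    """Mean over all valid windows of the Gaussian-weighted |pred - target|**p."""
    err = np.abs(pred - target) ** p
    win = gaussian_window_2d(win_size, sigma)
    windows = np.lib.stride_tricks.sliding_window_view(err, win.shape)
    return float(np.mean(np.sum(windows * win, axis=(-2, -1))))


def mix_loss(msssim_loss_value, pred, target, alpha, p=1, **kwargs):
    """alpha * MS-SSIM loss + (1 - alpha) * Gaussian-weighted l_p error (assumed form)."""
    return alpha * msssim_loss_value + (1.0 - alpha) * gaussian_weighted_error(
        pred, target, p, **kwargs
    )


rng = np.random.default_rng(0)
x, y = rng.random((64, 64)), rng.random((64, 64))
print(mix_loss(0.9, x, y, alpha=0.5))  # 0.9 stands in for a precomputed MS-SSIM loss
```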

MSSSIML2Loss

 MSSSIML2Loss (spatial_dims=2, alpha:float=0.1, window_size:int=11,
               sigma:float=1.5, reduction:str='mean', levels:int=3,
               weights=None)

Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L2 Loss.

Reference paper: Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). Loss functions for image restoration with neural networks. IEEE Transactions on Computational Imaging, 3(1), 47-57.

Type Default Details
spatial_dims int 2 Number of spatial dimensions.
alpha float 0.1 Weighting factor between MS-SSIM and L2 loss.
window_size int 11 Size of the Gaussian window for SSIM.
sigma float 1.5 Standard deviation of the Gaussian.
reduction str mean Specifies the reduction to apply to the output (‘mean’, ‘sum’, or ‘none’).
levels int 3 Number of scales to use for MS-SSIM.
weights NoneType None Weights to apply to each scale. If None, default values are used.
msssim_l2_loss = MSSSIML2Loss()
output = torch.rand(10, 3, 64, 64).cuda()  # Example output with even dimensions
target = torch.rand(10, 3, 64, 64).cuda()  # Example target with even dimensions
loss = msssim_l2_loss(output, target)
print(loss)
tensor(0.0956, device='cuda:0')
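Under the same reading as Zhao et al.'s ℓ1 mix, with the squared error in place of the absolute error, the loss would take the form below. Here G_σ is a normalized Gaussian window of size `window_size` over offsets q, and P is the set of valid window positions. Assigning `alpha` to the MS-SSIM term is an assumption, not something this page states.

```latex
\mathcal{L}(x, y) = \alpha \, \mathcal{L}_{\text{MS-SSIM}}(x, y)
  + (1 - \alpha) \, \frac{1}{|P|} \sum_{p \in P} \sum_{q} G_\sigma(q) \, \big( x(p + q) - y(p + q) \big)^2,
\qquad \sum_{q} G_\sigma(q) = 1
```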

CrossEntropy and Dice Loss


CrossEntropyLossFlat3D

 CrossEntropyLossFlat3D (*args, axis:int=-1, weight=None,
                         ignore_index=-100, reduction='mean',
                         flatten:bool=True, floatify:bool=False,
                         is_2d:bool=True)

Same as nn.CrossEntropyLoss, but flattens input and target for 3D inputs.
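Flattening turns each voxel into one classification sample, so the loss reduces to an ordinary mean cross-entropy over N·D·H·W rows. The NumPy sketch below shows that reshaping. It is not the library's code. It assumes the class scores sit on axis 1 of an (N, C, D, H, W) array and, for brevity, ignores `weight` and `ignore_index`.

```python
import numpy as np


def cross_entropy_flat(logits, target, axis=1):
    """Mean cross-entropy after flattening every non-class axis (sketch)."""
    n_classes = logits.shape[axis]
    # move the class axis last, then treat every voxel as one sample
    flat_logits = np.moveaxis(logits, axis, -1).reshape(-1, n_classes)
    flat_target = target.reshape(-1)
    # numerically stable log-softmax over the class axis
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(flat_target.size), flat_target].mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 4, 8, 8))       # (N, C, D, H, W) class scores
target = rng.integers(0, 3, size=(2, 4, 8, 8))  # (N, D, H, W) class indices
print(cross_entropy_flat(logits, target))
```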


DiceLoss

 DiceLoss (smooth=1)

DiceLoss computes the Sørensen–Dice coefficient loss. The Dice coefficient is widely used to evaluate image segmentation, and the corresponding loss is used to train segmentation models.

The Dice coefficient is a measure of overlap between two samples. It ranges from 0 (no overlap) to 1 (perfect overlap). The Dice loss is computed as 1 - Dice coefficient, so it ranges from 1 (no overlap) to 0 (perfect overlap).

Attributes: smooth (float): A smoothing factor to avoid division by zero and ensure numerical stability.

Methods: forward(inputs, targets): Computes the Dice loss between the predicted probabilities (inputs) and the ground truth (targets).

Type Default Details
smooth int 1 Smoothing factor to avoid division by zero
# inputs and targets must have the same shape
from torch import randn, randint
inputs = randn((1, 1, 256, 256))  # Input
targets = randint(0, 2, (1, 1, 256, 256)).float()  # Ground Truth

# Initialize
dice_loss = DiceLoss()

# Compute loss
loss = dice_loss(inputs, targets)
print('Dice Loss:', loss.item())
Dice Loss: 0.4982335567474365
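The formula from the description fits in a few lines, with `smooth` added to both numerator and denominator. This NumPy sketch assumes that exact form. It expects probabilities in [0, 1] and applies no activation. This page does not document whether DiceLoss applies an activation such as a sigmoid to its inputs, so apply one yourself before using this sketch with raw scores.

```python
import numpy as np


def dice_loss(probs, targets, smooth=1.0):
    """1 - soft Dice coefficient over all elements (assumed form)."""
    probs = probs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = np.sum(probs * targets)
    dice = (2.0 * intersection + smooth) / (np.sum(probs) + np.sum(targets) + smooth)
    return float(1.0 - dice)


probs = np.array([[0.9, 0.1], [0.8, 0.2]])
targets = np.array([[1.0, 0.0], [1.0, 0.0]])
print(dice_loss(probs, targets))  # 1 - (2*1.7 + 1) / (2 + 2 + 1), about 0.12
```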

Fourier Ring Correlation


FRCLoss

 FRCLoss (image1, image2)

Compute the Fourier Ring Correlation (FRC) loss between two images.

Returns: torch.Tensor, the FRC loss.

Details
image1 The first input image.
image2 The second input image.
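FRC compares two images ring by ring in Fourier space. For each radius r, it computes the normalized cross-correlation Σ F1·F2* / sqrt(Σ|F1|² · Σ|F2|²) over the frequencies on that ring. The NumPy sketch below computes that curve. This page does not state how FRCLoss reduces the curve to a scalar, so the 1 - mean(FRC) form used here is a hypothetical choice.

```python
import numpy as np


def frc_curve(img1, img2):
    """Fourier ring correlation per integer ring radius for two same-shape 2-D images."""
    f1 = np.fft.fftshift(np.fft.fft2(img1))
    f2 = np.fft.fftshift(np.fft.fft2(img2))
    h, w = img1.shape
    yy, xx = np.indices((h, w))
    radius = np.hypot(yy - h // 2, xx - w // 2).round().astype(int)
    n_rings = min(h, w) // 2  # rings up to the Nyquist radius
    curve = np.zeros(n_rings)
    for r in range(n_rings):
        ring = radius == r
        num = np.sum(f1[ring] * np.conj(f2[ring])).real
        den = np.sqrt(np.sum(np.abs(f1[ring]) ** 2) * np.sum(np.abs(f2[ring]) ** 2))
        curve[r] = num / den if den > 0 else 0.0
    return curve


def frc_loss(img1, img2):
    """Hypothetical scalar loss: 1 - mean FRC over rings (identical images give 0)."""
    return float(1.0 - frc_curve(img1, img2).mean())


rng = np.random.default_rng(0)
clean = rng.random((64, 64))
noisy = clean + 0.1 * rng.normal(size=clean.shape)
print(frc_loss(clean, noisy))
```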

FCRCutoff

 FCRCutoff (image1, image2)

Calculate the cutoff frequency at which the Fourier ring correlation drops to 1/7.

Returns: float, the cutoff frequency.

Details
image1 The first input image.
image2 The second input image.
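Given an FRC curve with one value per integer ring radius, the 1/7 (about 0.143) cutoff can be read off as the first ring whose correlation falls below the threshold. The hypothetical helper below does exactly that. It converts the ring index r to a spatial frequency of r / N cycles per pixel for an N×N image. FCRCutoff may interpolate between rings or report different units.

```python
import numpy as np


def frc_cutoff(curve, image_size, threshold=1.0 / 7.0):
    """Spatial frequency (cycles/pixel) of the first ring whose FRC falls below threshold."""
    curve = np.asarray(curve, dtype=float)
    below = np.nonzero(curve < threshold)[0]
    # assumption: if the curve never drops below the threshold, report the last ring
    ring = int(below[0]) if below.size else len(curve) - 1
    return ring / float(image_size)


print(frc_cutoff([1.0, 0.9, 0.5, 0.1, 0.05], image_size=8))  # ring 3 -> 3/8 = 0.375
```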