# Test MSSSIMLoss & SSIMLoss
# Random (batch, channel, height, width) noise images — the point is only to
# check that both losses run and reduce to a scalar, not the loss values.
batch_size, channels, height, width = 2, 1, 128, 128
input_tensor = torch.randn(batch_size, channels, height, width)
target_tensor = torch.randn(batch_size, channels, height, width)
msssim_loss = MSSSIMLoss(spatial_dims=2, levels=3)  # 3-scale MS-SSIM on 2D images
ssim_loss = SSIMLoss(spatial_dims=2)  # single-scale SSIM, for comparison
loss_msssim = msssim_loss(input_tensor, target_tensor)
loss_ssim = ssim_loss(input_tensor, target_tensor)
# Both should return scalar tensors
test_is(loss_msssim.dim(), 0) # MS-SSIM loss should return a scalar
test_is(loss_ssim.dim(), 0) # SSIM loss should return a scalar
Loss functions
MSELoss
def MSELoss(
inp:Any, targ:Any
)->Tensor:
L1Loss
def L1Loss(
inp:Any, targ:Any
)->Tensor:
SSIMLoss
def SSIMLoss(
spatial_dims:int, data_range:float=1.0, kernel_type:KernelType | str=gaussian, win_size:int | Sequence[int]=11,
kernel_sigma:float | Sequence[float]=1.5, k1:float=0.01, k2:float=0.03, reduction:LossReduction | str=mean
):
Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.
For more info, visit https://vicuesoft.com/glossary/term/ssim-ms-ssim/
SSIM reference paper: Wang, Zhou, et al. “Image quality assessment: from error visibility to structural similarity.” IEEE transactions on image processing 13.4 (2004): 600-612.
Combined Losses
CombinedLoss
def CombinedLoss(
spatial_dims:int=2, # Number of spatial dimensions (2 for 2D images, 3 for 3D images)
mse_weight:float=0.33, # Weight for the MSE loss component
mae_weight:float=0.33, # Weight for the MAE loss component
):
CombinedLoss computes a weighted combination of SSIM, MSE, and MAE losses.
This class allows for the combination of three different loss functions: Structural Similarity Index (SSIM), Mean Squared Error (MSE), and Mean Absolute Error (MAE). The weights for MSE and MAE can be adjusted, and the weight for SSIM is automatically calculated as the remaining weight.
CombinedLoss reference paper: Shah, Z. H., Müller, M., Hammer, B., Huser, T., & Schenck, W. (2022, July). Impact of different loss functions on denoising of microscopic images. In 2022 International Joint Conference on Neural Networks (IJCNN) (pp. 1-10). IEEE.
MSSSIMLoss
def MSSSIMLoss(
spatial_dims:int=2, # Number of spatial dimensions (2 for 2D images, 3 for 3D images).
window_size:int=8, # Size of the Gaussian filter for SSIM.
sigma:float=1.5, # Standard deviation of the Gaussian filter.
reduction:str='mean', # Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
levels:int=3, # Number of scales to use for MS-SSIM.
weights:NoneType=None, # Weights to apply to each scale. If None, default values are used.
):
Multi-Scale Structural Similarity (MSSSIM) Loss using MONAI’s SSIMLoss as the base.
MSSSIML1Loss
def MSSSIML1Loss(
spatial_dims:int=2, # Number of spatial dimensions.
alpha:float=0.025, # Weighting factor between MS-SSIM and L1 loss.
window_size:int=8, # Size of the Gaussian filter for SSIM.
sigma:float=1.5, # Standard deviation of the Gaussian filter.
reduction:str='mean', # Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
levels:int=3, # Number of scales to use for MS-SSIM.
weights:NoneType=None, # Weights to apply to each scale. If None, default values are used.
):
Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L1 Loss.
Reference paper: Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). Loss functions for image restoration with neural networks. IEEE Transactions on computational imaging, 3(1), 47-57.
# Example: MS-SSIM combined with Gaussian-weighted L1 (alpha mixes the two terms)
msssiml1_loss = MSSSIML1Loss(alpha=0.025, window_size=11, sigma=1.5, levels=3)
input_image = torch.randn(4, 1, 128, 128) # Batch of 4 grayscale images (1 channel)
target_image = torch.randn(4, 1, 128, 128)
# Compute MSSSIM + Gaussian-weighted L1 loss
loss = msssiml1_loss(input_image, target_image)
loss2 = ssim_loss(input_image, target_image)  # plain SSIM on the same pair, for comparison
print("ms-ssim: ", loss, '\nssim: ', loss2)
test_is(loss.dim(), 0) # combined loss reduces to a scalar
test_is(loss2.dim(), 0)
ms-ssim: tensor(0.0249)
ssim: tensor(0.9940)
MSSSIML2Loss
def MSSSIML2Loss(
spatial_dims:int=2, # Number of spatial dimensions.
alpha:float=0.1, # Weighting factor between MS-SSIM and L2 loss.
window_size:int=11, # Size of the Gaussian window for SSIM.
sigma:float=1.5, # Standard deviation of the Gaussian.
reduction:str='mean', # Specifies the reduction to apply to the output ('mean', 'sum', or 'none').
levels:int=3, # Number of scales to use for MS-SSIM.
weights:NoneType=None, # Weights to apply to each scale. If None, default values are used.
):
Multi-Scale Structural Similarity (MSSSIM) with Gaussian-weighted L2 Loss.
Reference paper: Zhao, H., Gallo, O., Frosio, I., & Kautz, J. (2016). Loss functions for image restoration with neural networks. IEEE Transactions on computational imaging, 3(1), 47-57.
# Example: MS-SSIM combined with Gaussian-weighted L2, default settings
msssim_l2_loss = MSSSIML2Loss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU when available
output = torch.rand(10, 3, 64, 64).to(device) # Example output with even dimensions
target = torch.rand(10, 3, 64, 64).to(device) # Example target with even dimensions
loss = msssim_l2_loss(output, target)
print(loss)
test_is(loss.dim(), 0)
tensor(0.0936, device='cuda:0')
CellLoss
def CellLoss(
):
Combined classification and regression loss for cell prediction tasks.
The first column of pred/targ is treated as a binary classification logit (cell presence), while the remaining columns are treated as continuous regression targets.
The total loss is computed as:
0.5 * MSE(pred[:, 1:], 5 * targ[:, 1:]) +
BCEWithLogitsLoss(pred[:, 0], targ[:, 0])
Attributes: MSE_loss (nn.MSELoss): Mean squared error loss for regression targets. BCE_loss (nn.BCEWithLogitsLoss): Binary cross-entropy loss with logits for cell presence prediction.
Reference paper: Stringer, C., Wang, T., Michaelos, M., & Pachitariu, M. (2021). Cellpose: a generalist algorithm for cellular segmentation. Nature methods, 18(1), 100-106.
class BCELoss:
    """
    Multi-channel binary instance segmentation loss.

    All channels of `pred`/`targ` are treated as independent binary
    instance masks. Each channel represents one object instance and
    is trained using Binary Cross-Entropy with Logits.

    The total loss is computed as:

        BCEWithLogitsLoss(pred, targ)

    This formulation is appropriate for multi-label instance segmentation
    where instances may overlap and predictions are raw logits.

    Attributes:
        BCE_loss (nn.BCEWithLogitsLoss): Binary cross-entropy loss with logits
            applied independently across all channels and pixels.
    """
    def __init__(self):
        self.BCE_loss = nn.BCEWithLogitsLoss()

    def __call__(self,
                 pred,  # Model logits of shape (bs, n_instances, h, w)
                 targ,  # Binary ground truth masks of same shape as `pred`
                 ) -> torch.Tensor:  # The segmentation loss as a scalar tensor
        """
        Compute the instance segmentation loss.

        Returns the mean BCE-with-logits over all channels and pixels
        (scalar, since nn.BCEWithLogitsLoss defaults to reduction='mean').
        """
        # Fixed: annotation was `torchTensor` (undefined name) and the next
        # section heading had been fused onto this return line.
        return self.BCE_loss(pred, targ)
def InstanceSegLoss(
mse_weight:float=1.0, # Weight applied to the MSE component
ssim_weight:float=0.0, # Weight applied to the SSIM component
spatial_dims:int=2, # Number of spatial dimensions (2 for 2D, 3 for 3D)
):
Combined Binary Cross-Entropy and Mean Squared Error loss for multi-channel instance segmentation tasks.
All channels of pred/targ are treated as independent binary instance masks. Each channel represents one object instance.
The total loss is computed as:
BCEWithLogitsLoss(sigmoid(pred), targ) +
mse_weight * MSE(pred, targ) +
ssim_weight * SSIMLoss(pred, targ)
Binary Cross-Entropy provides stable pixel-wise supervision, while the MSE term penalizes probability deviations and can encourage smoother mask predictions.
Attributes: BCE_loss (nn.BCEWithLogitsLoss): Binary cross-entropy loss with logits. MSE_loss (nn.MSELoss): Mean squared error loss applied to probabilities. mse_weight (float): Weight applied to the MSE component.
This is a variation on CellLoss proposed in: Stringer, C., Wang, T., Michaelos, M., & Pachitariu, M. (2021). Cellpose: a generalist algorithm for cellular segmentation. Nature methods, 18(1), 100-106.
DiceBCELoss
def DiceBCELoss(
dice_weight:float=1.0, # Weight applied to Dice component
smooth:float=1e-06, # Smoothing constant for Dice stability
):
Hybrid Binary Cross-Entropy + Dice loss for multi-channel instance segmentation.
All channels of pred/targ are treated as independent binary instance masks. Each channel represents one object instance.
The total loss is computed as:
BCEWithLogitsLoss(pred, targ) +
dice_weight * DiceLoss(sigmoid(pred), targ)
where DiceLoss is defined as:
1 - (2 * intersection + smooth) / (union + smooth)
This formulation improves stability (BCE) while directly optimizing overlap quality (Dice), which is particularly beneficial for imbalanced foreground/background segmentation tasks.
Attributes: BCE_loss (nn.BCEWithLogitsLoss): Binary cross-entropy loss with logits. dice_weight (float): Weight applied to the Dice loss component. smooth (float): Smoothing constant to avoid division by zero.
CrossEntropy and Dice Loss
CrossEntropyLossFlat3D
def CrossEntropyLossFlat3D(
args:VAR_POSITIONAL, axis:int=-1, # Class axis
weight:NoneType=None, ignore_index:int=-100, reduction:str='mean',
flatten:bool=True, # Flatten `inp` and `targ` before calculating loss
floatify:bool=False, # Convert `targ` to `float`
is_2d:bool=True, # Whether `flatten` keeps one or two channels when applied
):
Same as nn.CrossEntropyLoss, but flattens input and target for 3D inputs.
BCELoss
def BCELoss(
):
Multi-channel binary instance segmentation loss.
All channels of pred/targ are treated as independent binary instance masks. Each channel represents one object instance and is trained using Binary Cross-Entropy with Logits.
The total loss is computed as:
BCEWithLogitsLoss(pred, targ)
This formulation is appropriate for multi-label instance segmentation where instances may overlap and predictions are raw logits.
Attributes: BCE_loss (nn.BCEWithLogitsLoss): Binary cross-entropy loss with logits applied independently across all channels and pixels.
DiceLoss
def DiceLoss(
smooth:int=1, # Smoothing factor to avoid division by zero
):
DiceLoss computes the Sørensen–Dice coefficient loss, which is often used for evaluating the performance of image segmentation algorithms.
The Dice coefficient is a measure of overlap between two samples. It ranges from 0 (no overlap) to 1 (perfect overlap). The Dice loss is computed as 1 - Dice coefficient, so it ranges from 1 (no overlap) to 0 (perfect overlap).
Attributes: smooth (float): A smoothing factor to avoid division by zero and ensure numerical stability.
Methods: forward(inputs, targets): Computes the Dice loss between the predicted probabilities (inputs) and the ground truth (targets).
# inputs and targets must be equally dimensional tensors
from torch import randn, randint
inputs = randn((1, 1, 256, 256)) # Input
targets = randint(0, 2, (1, 1, 256, 256)).float() # Ground Truth (binary mask cast to float)
# Initialize
dice_loss = DiceLoss()
# Compute loss
loss = dice_loss(inputs, targets)
print('Dice Loss:', loss.item())
test_is(loss.dim(), 0)
Dice Loss: 0.500794529914856
Fourier Ring Correlation
FRCLoss
def FRCLoss(
image1, # The first input image.
image2, # The second input image.
):
Compute the Fourier Ring Correlation (FRC) loss between two images.
Returns: - torch.Tensor: The FRC loss.
FCRCutoff
def FCRCutoff(
image1, # The first input image.
image2, # The second input image.
):
Calculate the cutoff frequency at when Fourier ring correlation drops to 1/7.
Returns: - float: The cutoff frequency.