
loss

Module: loss.py

This module implements loss functions for neural networks. Currently, the cross-entropy loss is built on the optax library (documentation: https://optax.readthedocs.io/en/latest/api/losses.html), while the mean squared error is computed directly with jax.numpy.

Authors
  • Lokesh Mohanty (lokeshm@iisc.ac.in)
Version Info
  • 01/01/2025: Initial version

cross_entropy_loss(output, y)

Compute the cross-entropy loss
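For a 2-D `output` holding a batch of $B$ logit vectors $z_i \in \mathbb{R}^C$ (with $C$ = `n_classes`, read from the last axis) and integer labels $y_i$, the value returned is the batch mean of the softmax cross-entropy. This follows from one-hot encoding the labels before calling `optax.softmax_cross_entropy` and then taking `.mean()`:

```latex
\mathcal{L}_{\mathrm{CE}}
  = -\frac{1}{B} \sum_{i=1}^{B}
    \log \frac{\exp(z_{i,y_i})}{\sum_{c=1}^{C} \exp(z_{i,c})}
```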

Parameters:

Name    Type   Description                                                 Default
output  Array  Model output as unnormalized logits (classes on last axis)  required
y       Array  Batched target labels as integer class indices              required
Source code in scirex/core/dl/nn/loss.py
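import jax
import optax


def cross_entropy_loss(output: jax.Array, y: jax.Array) -> jax.Array:
    """
    Compute the cross-entropy loss

    Args:
        output: output of the model (unnormalized logits, classes on the last axis)
        y: Batched target labels (integer class indices)
    """

    # The number of classes is read from the last axis of the logits.
    n_classes = output.shape[-1]
    # One-hot encode the labels, compute per-example losses, then average them.
    loss = optax.softmax_cross_entropy(output, jax.nn.one_hot(y, n_classes)).mean()
    return loss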
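To make the computation concrete, the sketch below reproduces the same quantity in plain JAX, assuming `output` holds unnormalized logits of shape `(batch, n_classes)` and `y` holds integer class indices of shape `(batch,)`. The helper name `manual_cross_entropy` is hypothetical and not part of the module; it takes the negative log-softmax probability of each true class and averages over the batch, which is what the one-hot plus `optax.softmax_cross_entropy` path amounts to.

```python
import jax
import jax.numpy as jnp


def manual_cross_entropy(logits: jax.Array, labels: jax.Array) -> jax.Array:
    """Hypothetical helper: mean softmax cross-entropy from logits and integer labels."""
    n_classes = logits.shape[-1]
    # Numerically stable log-probabilities over the class axis.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # One-hot targets pick out the log-probability of the true class in each row.
    one_hot = jax.nn.one_hot(labels, n_classes)
    per_example = -jnp.sum(one_hot * log_probs, axis=-1)
    # Average over the batch, like the .mean() in cross_entropy_loss.
    return per_example.mean()


# Two examples, three classes: uniform logits give log(3) for every example.
logits = jnp.zeros((2, 3))
labels = jnp.array([0, 2])
loss = manual_cross_entropy(logits, labels)
print(float(loss))  # approximately 1.0986, i.e. log(3)
```

As a design note, optax also provides `optax.softmax_cross_entropy_with_integer_labels(logits, labels)`, which computes the same per-example values without materializing the one-hot matrix; adding `.mean()` to it would give the same batch loss.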

mse_loss(output, y)

Compute mean squared error loss
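Written out, with $\hat{y}$ = `output`, the loss is a plain mean over every one of the $N$ elements of `output - y` (after broadcasting), not a per-sample sum followed by a batch mean:

```latex
\mathcal{L}_{\mathrm{MSE}} = \frac{1}{N} \sum_{j=1}^{N} \left( \hat{y}_j - y_j \right)^2
```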

Parameters:

Name    Type   Description          Default
output  Array  output of the model  required
y       Array  target values        required
Source code in scirex/core/dl/nn/loss.py
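import jax
import jax.numpy as jnp


def mse_loss(output: jax.Array, y: jax.Array) -> jax.Array:
    """
    Compute mean squared error loss

    Args:
        output: output of the model
        y: target values
    """
    # Element-wise squared error, averaged over all elements (batch and feature axes).
    return jnp.mean(jnp.square(output - y))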
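A quick numerical check of `mse_loss`, assuming JAX is installed. The function body is repeated from the listing above so the snippet runs on its own; the example also differentiates it with `jax.grad`, which is how a loss like this would typically feed a training step.

```python
import jax
import jax.numpy as jnp


def mse_loss(output: jax.Array, y: jax.Array) -> jax.Array:
    # Same computation as in loss.py: mean of the squared error over every element.
    return jnp.mean(jnp.square(output - y))


output = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0, 0.0], [3.0, 2.0]])

# Squared errors are 0, 4, 0, 4; their mean over 4 elements is 2.0.
print(float(mse_loss(output, y)))  # 2.0

# Gradient w.r.t. output is 2 * (output - y) / N with N = 4,
# which gives the values [[0, 1], [0, 1]] here.
grad = jax.grad(mse_loss)(output, y)
print(grad)
```

Because the mean runs over all elements, scaling the feature dimension changes the loss scale; a per-sample sum of squares would need an explicit `jnp.sum(..., axis=-1)` before averaging.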