Model hard

Neural Network Model Implementation for FastVPINNs with Hard Constraints.

This module implements a custom neural network model designed specifically for FastVPINNs methodology, incorporating hard constraint enforcement. It extends TensorFlow's Keras Model class to provide efficient PDE residual computation and gradient-based training.

The implementation supports:
  • Flexible neural network architecture definition
  • Hard constraint enforcement through constraint functions
  • Optional exponential learning rate decay
  • Attention mechanisms (optional)
  • Efficient tensor operations for PDE residuals
  • A custom training step with automatic differentiation of the network output
Key classes
  • DenseModel_Hard: Main neural network model with hard constraints
Authors
  • Thivin Anandh (https://thivinanandh.github.io/)
Versions
  • 27-Dec-2024 (Version 0.1): Initial Implementation

DenseModel_Hard

Bases: Model

Neural network model with hard constraint enforcement for FastVPINNs.

This class implements a custom neural network architecture specifically designed for solving PDEs using the FastVPINNs methodology. It supports hard constraint enforcement through custom constraint functions and efficient tensor operations for PDE residual computation.
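As an illustration of the `hard_constraint_function(inputs, outputs)` hook that `call` applies after the dense layers, the sketch below imposes a homogeneous Dirichlet condition on the unit square by multiplying the raw network output by a function that vanishes on the boundary. The domain, the zero boundary value, and the name `unit_square_zero_bc` are assumptions for illustration. The function uses only column slicing and arithmetic, so the same expression should also work on TensorFlow tensors, not just the NumPy arrays used here.

```python
import numpy as np


def unit_square_zero_bc(inputs, outputs):
    """Hypothetical hard constraint for u = 0 on the boundary of [0, 1]^2.

    inputs:  array of shape (N, 2) holding (x, y) coordinates
    outputs: raw network output of shape (N, 1)
    """
    x = inputs[:, 0:1]
    y = inputs[:, 1:2]
    # x(1 - x) y(1 - y) is zero on every edge of the unit square and positive
    # inside, so the product satisfies u = 0 on the boundary for any network output
    return x * (1.0 - x) * y * (1.0 - y) * outputs


points = np.array([[0.0, 0.3], [1.0, 0.7], [0.4, 0.0], [0.5, 0.5]])
raw = np.ones((4, 1))  # stand-in for the raw network output
print(unit_square_zero_bc(points, raw).ravel())
```

Such a function would be passed as hard_constraint_function=unit_square_zero_bc. With the default None, the constructor falls back to the identity lambda x, y: y, so no constraint is applied.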

Attributes:

  • layer_dims: List of neurons per layer, including the input and output layers
  • learning_rate_dict: Learning rate configuration containing:
      - initial_learning_rate: Starting learning rate
      - use_lr_scheduler: Whether to use learning rate decay
      - decay_steps: Number of steps over which the rate is multiplied by decay_rate
      - decay_rate: Factor for learning rate decay
      - staircase: Whether to use staircase (step-wise) decay
  • params_dict: Model parameters, including:
      - n_cells: Number of cells in the domain
  • loss_function: Custom loss function for PDE residuals
  • tensor_dtype: TensorFlow data type for computations
  • use_attention: Whether to use an attention mechanism
  • activation: Activation function for hidden layers
  • hessian: Whether to compute second derivatives
  • hard_constraint_function: Function f(inputs, outputs) applied to the raw network output to impose the boundary conditions; defaults to the identity lambda x, y: y
  • optimizer: Adam optimizer with optional learning rate schedule
  • layer_list: List of neural network layers
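The constructor reads all five keys of learning_rate_dict unconditionally, so each must be present even when use_lr_scheduler is False. When the scheduler is enabled, the rate follows the Keras ExponentialDecay schedule. The pure-Python sketch below reproduces that schedule's documented formula, initial_learning_rate * decay_rate ** (step / decay_steps), with the exponent floored when staircase is set. The helper name decayed_learning_rate and the dictionary values are arbitrary examples.

```python
import math

# Example configuration (values are arbitrary); the constructor indexes all five keys
learning_rate_dict = {
    "initial_learning_rate": 0.001,
    "use_lr_scheduler": True,
    "decay_steps": 1000,
    "decay_rate": 0.5,
    "staircase": True,
}


def decayed_learning_rate(step, cfg):
    """Learning rate at a given optimizer step, following the ExponentialDecay
    formula initial_learning_rate * decay_rate ** (step / decay_steps)."""
    if not cfg["use_lr_scheduler"]:
        # Without a scheduler, Adam uses the constant initial rate
        return cfg["initial_learning_rate"]
    exponent = step / cfg["decay_steps"]
    if cfg["staircase"]:
        # Staircase mode decays in discrete jumps every decay_steps steps
        exponent = math.floor(exponent)
    return cfg["initial_learning_rate"] * cfg["decay_rate"] ** exponent


for step in (0, 500, 1000, 2500):
    print(step, decayed_learning_rate(step, learning_rate_dict))
```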

Example

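>>> model = DenseModel_Hard(
...     layer_dims=[2, 64, 64, 1],
...     learning_rate_dict={'initial_learning_rate': 0.001,
...                         'use_lr_scheduler': False, 'decay_steps': 1000,
...                         'decay_rate': 0.99, 'staircase': False},
...     params_dict={'n_cells': 100},
...     loss_function=custom_loss,
...     input_tensors_list=[input_tensor, dirichlet_input, dirichlet_actual],
...     orig_factor_matrices=[shape_val_mat, grad_x_mat, grad_y_mat],
...     force_function_list=force_matrix,
...     tensor_dtype=tf.float32,
... )
>>> for epoch in range(1000):
...     loss = model.train_step(bilinear_params_dict=bilinear_params)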

Source code in scirex/core/sciml/fastvpinns/model/model_hard.py
class DenseModel_Hard(tf.keras.Model):
    """Neural network model with hard constraint enforcement for FastVPINNs.

    This class implements a custom neural network architecture specifically
    designed for solving PDEs using the FastVPINNs methodology. It supports
    hard constraint enforcement through custom constraint functions and
    efficient tensor operations for PDE residual computation.

    Attributes:
        layer_dims: List of neurons per layer including input/output
        learning_rate_dict: Learning rate configuration containing:
            - initial_learning_rate: Starting learning rate
            - use_lr_scheduler: Whether to use learning rate decay
            - decay_steps: Steps between learning rate updates
            - decay_rate: Factor for learning rate decay
            - staircase: Whether to use staircase decay
        params_dict: Model parameters including:
            - n_cells: Number of cells in the domain
        loss_function: Custom loss function for PDE residuals
        tensor_dtype: TensorFlow data type for computations
        use_attention: Whether to use attention mechanism
        activation: Activation function for hidden layers
        hessian: Whether to compute second derivatives
        optimizer: Adam optimizer with optional learning rate schedule
        layer_list: List of neural network layers

    Example:
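        >>> model = DenseModel_Hard(
        ...     layer_dims=[2, 64, 64, 1],
        ...     learning_rate_dict={'initial_learning_rate': 0.001,
        ...                         'use_lr_scheduler': False, 'decay_steps': 1000,
        ...                         'decay_rate': 0.99, 'staircase': False},
        ...     params_dict={'n_cells': 100},
        ...     loss_function=custom_loss,
        ...     input_tensors_list=[input_tensor, dirichlet_input, dirichlet_actual],
        ...     orig_factor_matrices=[shape_val_mat, grad_x_mat, grad_y_mat],
        ...     force_function_list=force_matrix,
        ...     tensor_dtype=tf.float32,
        ... )
        >>> for epoch in range(1000):
        ...     loss = model.train_step(bilinear_params_dict=bilinear_params)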
    """

    def __init__(
        self,
        layer_dims,
        learning_rate_dict,
        params_dict,
        loss_function,
        input_tensors_list,
        orig_factor_matrices,
        force_function_list,
        tensor_dtype,
        use_attention=False,
        activation="tanh",
        hessian=False,
        hard_constraint_function=None,
    ):
        super(DenseModel_Hard, self).__init__()
        self.layer_dims = layer_dims
        self.use_attention = use_attention
        self.activation = activation
        self.layer_list = []
        self.loss_function = loss_function
        self.hessian = hessian
        if hard_constraint_function is None:
            self.hard_constraint_function = lambda x, y: y
        else:
            self.hard_constraint_function = hard_constraint_function

        self.tensor_dtype = tensor_dtype

        # if dtype is not a valid tensorflow dtype, raise an error
        if not isinstance(self.tensor_dtype, tf.DType):
            raise TypeError("The given dtype is not a valid tensorflow dtype")

        self.orig_factor_matrices = orig_factor_matrices
        self.shape_function_mat_list = copy.deepcopy(orig_factor_matrices[0])
        self.shape_function_grad_x_factor_mat_list = copy.deepcopy(
            orig_factor_matrices[1]
        )
        self.shape_function_grad_y_factor_mat_list = copy.deepcopy(
            orig_factor_matrices[2]
        )

        self.force_function_list = force_function_list

        self.input_tensors_list = input_tensors_list
        self.input_tensor = copy.deepcopy(input_tensors_list[0])
        self.dirichlet_input = copy.deepcopy(input_tensors_list[1])
        self.dirichlet_actual = copy.deepcopy(input_tensors_list[2])

        self.params_dict = params_dict

        self.pre_multiplier_val = self.shape_function_mat_list
        self.pre_multiplier_grad_x = self.shape_function_grad_x_factor_mat_list
        self.pre_multiplier_grad_y = self.shape_function_grad_y_factor_mat_list

        self.force_matrix = self.force_function_list

        self.gradients = None

        print(f"{'-'*74}")
        print(f"| {'PARAMETER':<25} | {'SHAPE':<25} |")
        print(f"{'-'*74}")
        print(
            f"| {'input_tensor':<25} | {str(self.input_tensor.shape):<25} | {self.input_tensor.dtype}"
        )
        print(
            f"| {'force_matrix':<25} | {str(self.force_matrix.shape):<25} | {self.force_matrix.dtype}"
        )
        print(
            f"| {'pre_multiplier_grad_x':<25} | {str(self.pre_multiplier_grad_x.shape):<25} | {self.pre_multiplier_grad_x.dtype}"
        )
        print(
            f"| {'pre_multiplier_grad_y':<25} | {str(self.pre_multiplier_grad_y.shape):<25} | {self.pre_multiplier_grad_y.dtype}"
        )
        print(
            f"| {'pre_multiplier_val':<25} | {str(self.pre_multiplier_val.shape):<25} | {self.pre_multiplier_val.dtype}"
        )
        print(
            f"| {'dirichlet_input':<25} | {str(self.dirichlet_input.shape):<25} | {self.dirichlet_input.dtype}"
        )
        print(
            f"| {'dirichlet_actual':<25} | {str(self.dirichlet_actual.shape):<25} | {self.dirichlet_actual.dtype}"
        )
        print(f"{'-'*74}")

        self.n_cells = params_dict["n_cells"]

        ## ----------------------------------------------------------------- ##
        ## ---------- LEARNING RATE AND OPTIMISER FOR THE MODEL ------------ ##
        ## ----------------------------------------------------------------- ##

        # parse the learning rate dictionary
        self.learning_rate_dict = learning_rate_dict
        initial_learning_rate = learning_rate_dict["initial_learning_rate"]
        use_lr_scheduler = learning_rate_dict["use_lr_scheduler"]
        decay_steps = learning_rate_dict["decay_steps"]
        decay_rate = learning_rate_dict["decay_rate"]
        staircase = learning_rate_dict["staircase"]

        if use_lr_scheduler:
            learning_rate_fn = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate, decay_steps, decay_rate, staircase=staircase
            )
        else:
            learning_rate_fn = initial_learning_rate

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)

        ## ----------------------------------------------------------------- ##
        ## --------------------- MODEL ARCHITECTURE ------------------------ ##
        ## ----------------------------------------------------------------- ##

        # Build dense layers based on the input list
        for dim in range(len(self.layer_dims) - 2):
            self.layer_list.append(
                TensorflowDense.create_layer(
                    units=self.layer_dims[dim + 1],
                    activation=self.activation,
                    dtype=self.tensor_dtype,
                    kernel_initializer="glorot_uniform",
                    bias_initializer="zeros",
                )
            )

        # Add an output layer with no activation
        self.layer_list.append(
            TensorflowDense.create_layer(
                units=self.layer_dims[-1],
                activation=None,
                dtype=self.tensor_dtype,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
            )
        )

        # Add attention layer if required
        if self.use_attention:
            self.attention_layer = layers.Attention()

        # Compile the model
        self.compile(optimizer=self.optimizer)
        self.build(input_shape=(None, self.layer_dims[0]))

        # print the summary of the model
        self.summary()

    # def build(self, input_shape):
    #     super(DenseModel, self).build(input_shape)

    def call(self, inputs) -> tf.Tensor:
        """This method is used to define the forward pass of the model.

        Args:
            inputs: Input tensor to the model

        Returns:
            Output tensor from the model
        """
        x = inputs

        # Apply attention layer after input if flag is True
        if self.use_attention:
            x = self.attention_layer([x, x])

        # Loop through the dense layers
        for layer in self.layer_list:
            x = layer(x)

        x = self.hard_constraint_function(inputs, x)

        return x

    def get_config(self) -> dict:
        """This method is used to get the configuration of the model.

        Args:
            None

        Returns:
            Configuration dictionary of the model
        """
        # Get the base configuration
        base_config = super().get_config()

        # Add the non-serializable arguments to the configuration
        base_config.update(
            {
                "learning_rate_dict": self.learning_rate_dict,
                "loss_function": self.loss_function,
                "input_tensors_list": self.input_tensors_list,
                "orig_factor_matrices": self.orig_factor_matrices,
                "force_function_list": self.force_function_list,
                "params_dict": self.params_dict,
                "use_attention": self.use_attention,
                "activation": self.activation,
                "hessian": self.hessian,
                "layer_dims": self.layer_dims,
                "tensor_dtype": self.tensor_dtype,
                "hard_constraint_function": self.hard_constraint_function,
            }
        )

        return base_config

    @tf.function
    def train_step(
        self, beta=10, bilinear_params_dict=None
    ) -> dict:  # pragma: no cover
        """This method is used to define the training step of the model.

        Args:
            beta: Penalty weight for a boundary loss. Unused here, because the
                boundary conditions are imposed through hard_constraint_function
            bilinear_params_dict: The bilinear parameters dictionary

        Returns:
            Dictionary containing the loss values
        """

        with tf.GradientTape(persistent=True) as tape:
            # Record the forward pass so the loss can be differentiated
            # with respect to the trainable variables

            # Accumulator for the PDE loss; becomes a tensor once the residual is added
            total_pde_loss = 0.0

            with tf.GradientTape(persistent=True) as tape1:
                # tape gradient
                tape1.watch(self.input_tensor)
                # Compute the predicted values from the model
                predicted_values = self(self.input_tensor)

            # compute the gradients of the predicted values wrt the input which is (x, y)
            gradients = tape1.gradient(predicted_values, self.input_tensor)

            # Split the gradients into x and y components and reshape each to
            # (N_cells, N_quadrature_points) to match the tensorised residual computation
            pred_grad_x = tf.reshape(
                gradients[:, 0], [self.n_cells, self.pre_multiplier_grad_x.shape[-1]]
            )  # shape : (N_cells , N_quadrature_points)
            pred_grad_y = tf.reshape(
                gradients[:, 1], [self.n_cells, self.pre_multiplier_grad_y.shape[-1]]
            )  # shape : (N_cells , N_quadrature_points)

            pred_val = tf.reshape(
                predicted_values, [self.n_cells, self.pre_multiplier_val.shape[-1]]
            )  # shape : (N_cells , N_quadrature_points)

            cells_residual = self.loss_function(
                test_shape_val_mat=self.pre_multiplier_val,
                test_grad_x_mat=self.pre_multiplier_grad_x,
                test_grad_y_mat=self.pre_multiplier_grad_y,
                pred_nn=pred_val,
                pred_grad_x_nn=pred_grad_x,
                pred_grad_y_nn=pred_grad_y,
                forcing_function=self.force_matrix,
                bilinear_params=bilinear_params_dict,
            )

            residual = tf.reduce_sum(cells_residual)

            # Compute the total loss for the PDE
            total_pde_loss = total_pde_loss + residual

            # The boundary conditions are imposed by hard_constraint_function,
            # so there is no Dirichlet penalty term
            boundary_loss = 0.0
            # The total loss is the PDE loss alone
            total_loss = total_pde_loss

        trainable_vars = self.trainable_variables
        self.gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(self.gradients, trainable_vars))

        return {
            "loss_pde": total_pde_loss,
            "loss_dirichlet": boundary_loss,
            "loss": total_loss,
        }
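train_step evaluates the network once on self.input_tensor, which stacks the quadrature points of every cell into a single (N_cells * N_quadrature_points, 2) array, and then reshapes each gradient column to (N_cells, N_quadrature_points). A row-major reshape only groups points correctly if they are stored cell by cell, with all quadrature points of cell 0 first; that ordering is an assumption of this NumPy sketch, which uses made-up sizes and synthetic gradient values.

```python
import numpy as np

n_cells, n_quad = 3, 4  # hypothetical mesh: 3 cells with 4 quadrature points each

# Stand-in for tape1.gradient(predicted_values, input_tensor): one (du/dx, du/dy)
# row per quadrature point, stored cell by cell (rows 0-3 are cell 0, and so on)
n_points = n_cells * n_quad
gradients = np.stack(
    [np.arange(n_points, dtype=float), -np.arange(n_points, dtype=float)], axis=1
)

# Same reshapes as train_step: one row per cell, one column per quadrature point
pred_grad_x = np.reshape(gradients[:, 0], (n_cells, n_quad))
pred_grad_y = np.reshape(gradients[:, 1], (n_cells, n_quad))

print(pred_grad_x[1])  # du/dx at the quadrature points of cell 1 (rows 4-7)
```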

call(inputs)

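Defines the forward pass of the model: the optional attention layer, the dense layers, and finally the hard constraint function, which receives both the original inputs and the raw network output.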

Parameters:

  • inputs (required): Input tensor to the model

Returns:

  • Tensor: Output tensor from the model
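The order of operations in call can be mirrored in NumPy as a sketch (attention omitted): hidden tanh layers, a linear output layer, then hard_constraint_function(inputs, x). The Glorot-uniform weights and zero biases follow the initialisers named in the constructor; the helper names glorot_uniform, init_params, and forward, and the [2, 8, 8, 1] layer sizes, are hypothetical.

```python
import numpy as np


def glorot_uniform(rng, fan_in, fan_out):
    # Glorot/Xavier uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


def init_params(layer_dims, seed=0):
    # One (weights, bias) pair per layer; biases start at zero
    rng = np.random.default_rng(seed)
    return [(glorot_uniform(rng, m, n), np.zeros(n))
            for m, n in zip(layer_dims[:-1], layer_dims[1:])]


def forward(params, inputs, hard_constraint_function=lambda x, y: y):
    x = inputs
    # Hidden layers use tanh (the model's default activation)
    for w, b in params[:-1]:
        x = np.tanh(x @ w + b)
    # Output layer has no activation
    w, b = params[-1]
    x = x @ w + b
    # The constraint sees the original inputs and the raw network output
    return hard_constraint_function(inputs, x)


params = init_params([2, 8, 8, 1])
points = np.array([[0.0, 0.5], [0.25, 0.75]])
print(forward(params, points).shape)
```

As in the constructor, the default constraint is the identity, so forward returns the raw network output unless a constraint function is supplied.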


get_config()

This method is used to get the configuration of the model.

Returns:

  • dict: Configuration dictionary of the model
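get_config returns the constructor arguments as they are, including tensors, a tf.DType, and the loss function, so the dictionary cannot be passed to json.dumps directly. If only the plain hyperparameters are needed (for logging, for instance), a hypothetical helper such as json_safe_config below could keep the JSON-serializable entries and drop the rest. It is a sketch, not part of the module, and the config shown is a trimmed example.

```python
import json


def json_safe_config(config):
    """Hypothetical helper: keep only the entries json.dumps can encode."""
    safe = {}
    for key, value in config.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            continue  # e.g. functions, tensors, dtypes
        safe[key] = value
    return safe


def custom_loss(**kwargs):  # placeholder for a real loss function
    return 0.0


config = {
    "layer_dims": [2, 64, 64, 1],
    "activation": "tanh",
    "use_attention": False,
    "loss_function": custom_loss,
}
print(json.dumps(json_safe_config(config)))
```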


train_step(beta=10, bilinear_params_dict=None)

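Runs one training step of the model: evaluates the network and its input gradients at all quadrature points, computes the PDE residual loss through loss_function, and applies one Adam update to the trainable variables.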

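Parameters:

  • beta (default: 10): Penalty weight for a boundary loss. It is not used by this method, because the boundary conditions are imposed through hard_constraint_function and the Dirichlet loss is always 0.0
  • bilinear_params_dict (default: None): Dictionary of PDE coefficients, passed to loss_function as bilinear_params

Returns:

  • dict: Loss values under the keys loss_pde, loss_dirichlet (always 0.0), and loss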
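train_step delegates the physics to loss_function, calling it with the keyword arguments test_shape_val_mat, test_grad_x_mat, test_grad_y_mat, pred_nn, pred_grad_x_nn, pred_grad_y_nn, forcing_function, and bilinear_params, and then summing the returned per-cell values. The NumPy sketch below is a hypothetical variational loss for the Poisson problem -ε Δu = f written against that interface. The array shapes (test-function matrices of shape (N_cells, N_test, N_quadrature_points) with quadrature weights and Jacobians folded in, and a forcing matrix of shape (N_test, N_cells)) and the "eps" key are assumptions, not taken from this module.

```python
import numpy as np


def poisson_loss(test_shape_val_mat, test_grad_x_mat, test_grad_y_mat,
                 pred_nn, pred_grad_x_nn, pred_grad_y_nn,
                 forcing_function, bilinear_params):
    """Hypothetical variational loss for -eps * Laplace(u) = f.

    test_grad_*_mat:  (N_cells, N_test, N_quad), weights and Jacobians folded in
    pred_grad_*_nn:   (N_cells, N_quad), network derivatives at quadrature points
    forcing_function: (N_test, N_cells), precomputed integrals of f * v_j
    test_shape_val_mat and pred_nn are accepted but unused (no reaction term).
    Returns one loss value per cell.
    """
    eps = bilinear_params["eps"]
    # Weak-form diffusion term: sum over quadrature points per (test function, cell)
    diff_x = np.einsum("ktq,kq->tk", test_grad_x_mat, pred_grad_x_nn)
    diff_y = np.einsum("ktq,kq->tk", test_grad_y_mat, pred_grad_y_nn)
    residual = eps * (diff_x + diff_y) - forcing_function  # (N_test, N_cells)
    # Mean squared residual over test functions, one entry per cell
    return np.mean(np.square(residual), axis=0)


n_cells, n_test, n_quad = 2, 3, 4
rng = np.random.default_rng(0)
gx, gy = rng.random((2, n_cells, n_test, n_quad))  # synthetic test-gradient matrices

cells_residual = poisson_loss(
    test_shape_val_mat=np.zeros((n_cells, n_test, n_quad)),
    test_grad_x_mat=gx,
    test_grad_y_mat=gy,
    pred_nn=np.zeros((n_cells, n_quad)),
    pred_grad_x_nn=np.ones((n_cells, n_quad)),
    pred_grad_y_nn=np.ones((n_cells, n_quad)),
    forcing_function=np.zeros((n_test, n_cells)),
    bilinear_params={"eps": 1.0},
)
print(cells_residual.shape, float(np.sum(cells_residual)))
```

In train_step, the equivalent of np.sum(cells_residual) is tf.reduce_sum(cells_residual), which becomes the PDE loss.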
