pruning
Example Script: pruning.py
This script implements model compression for neural network architectures using the pruning capabilities of the TensorFlow Model Optimization Toolkit. It provides a reusable implementation for applying and evaluating model pruning with a polynomial decay sparsity schedule.
This example includes:
- Implementation of a reusable ModelPruning class for model compression
- Building and training baseline CNN models
- Applying progressive pruning from 50% to 80% sparsity
- Training and evaluation workflows for both baseline and pruned models
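The 50% to 80% progression follows a polynomial decay schedule. The sketch below reproduces the formula documented for `tfmot.sparsity.keras.PolynomialDecay` (default `power=3`) in plain Python, so the sparsity targets can be inspected without TensorFlow. The `begin_step=0` and `end_step=1000` values are illustrative placeholders, not the values `ModelPruning` uses.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.50, final_sparsity=0.80,
                              begin_step=0, end_step=1000, power=3):
    """Target sparsity at a training step under a polynomial decay schedule (sketch)."""
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Sparsity moves from initial_sparsity toward final_sparsity; power > 1
    # makes it rise quickly at first and flatten out near end_step.
    return (initial_sparsity - final_sparsity) * (1.0 - progress) ** power + final_sparsity

for step in (0, 250, 500, 750, 1000):
    print(step, round(polynomial_decay_sparsity(step), 4))
```

In tfmot the mask is only recomputed every `frequency` steps (default 100), so the actual sparsity rises in small increments toward each of these targets.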
Authors:
- Nithyashree R (nithyashreer@iisc.ac.in)
Version Info:
- 06/01/2024: Initial version
ModelPruning
A reusable class for performing pruning on any model and dataset.
Attributes:
Name | Type | Description
---|---|---
model | Model | The base model architecture that will undergo pruning.
pruned_model | Model | The model after pruning.
baseline_model_accuracy | float | Accuracy of the baseline model evaluated on test data.
pruned_model_accuracy | float | Accuracy of the pruned model evaluated on test data.
Source code in scirex/core/model_compression/pruning.py
__init__(input_shape=(28, 28), num_classes=10, epochs=10, batch_size=35, validation_split=0.1)
Initializes the pruning process for a model.
:param input_shape: Shape of the input data. Default is (28, 28).
:type input_shape: tuple
:param num_classes: Number of output classes. Default is 10.
:type num_classes: int
:param epochs: Number of epochs to train the pruned model. Default is 10.
:type epochs: int
:param batch_size: Size of the training batch. Default is 35.
:type batch_size: int
:param validation_split: Fraction of the training data to use for validation. Default is 0.1.
:type validation_split: float
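A polynomial decay schedule needs an `end_step`. The TensorFlow pruning guide derives it from the same quantities this constructor accepts (`epochs`, `batch_size`, `validation_split`). Whether `ModelPruning` does exactly this is an assumption; `compute_end_step` is a hypothetical helper that shows the arithmetic.

```python
import math

def compute_end_step(num_train_images, batch_size=35, epochs=10, validation_split=0.1):
    """Hypothetical helper: optimizer steps until the schedule reaches final sparsity."""
    # Keras holds out the last `validation_split` fraction of the data,
    # so only the remaining samples produce training batches.
    num_fit_images = int(num_train_images * (1 - validation_split))
    steps_per_epoch = math.ceil(num_fit_images / batch_size)
    return steps_per_epoch * epochs

# Assumed example: a 60,000-image training set (MNIST-sized, matching input_shape=(28, 28)).
print(compute_end_step(60000))
```

With the defaults this gives ceil(54000 / 35) = 1543 steps per epoch, or 15430 steps over 10 epochs.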
apply_pruning()
Applies pruning to the base model.
:return: A pruned model.
:rtype: tf.keras.Model
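The TensorFlow Model Optimization Toolkit prunes by weight magnitude (`tfmot.sparsity.keras.prune_low_magnitude`). The NumPy sketch below shows that idea on a single weight tensor: the smallest-magnitude entries are zeroed until the target sparsity is met. It is a conceptual illustration, not the implementation behind `apply_pruning`; `magnitude_prune` is a hypothetical name.

```python
import numpy as np

def magnitude_prune(weights, target_sparsity):
    """Zero the smallest-magnitude entries so that `target_sparsity` of them are zero."""
    flat = np.abs(weights).ravel()
    k = int(round(target_sparsity * flat.size))  # number of weights to remove
    if k == 0:
        return weights.copy(), np.ones_like(weights, dtype=bool)
    # Indices of the k smallest magnitudes; these connections are removed.
    prune_idx = np.argsort(flat, kind="stable")[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    mask = mask.reshape(weights.shape)
    return weights * mask, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 5))
pruned, mask = magnitude_prune(w, 0.8)
print(f"sparsity: {1 - mask.mean():.2f}")  # fraction of weights set to zero
```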
evaluate_baseline(test_images, test_labels)
Evaluates the baseline model.
:param test_images: Test data features.
:param test_labels: Test data labels.
:return: Accuracy of the baseline model.
:rtype: float
evaluate_pruned_model(test_images, test_labels)
Evaluates the pruned model.
:param test_images: Test data features.
:param test_labels: Test data labels.
:return: Accuracy of the pruned model.
:rtype: float
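Both evaluation methods return a single accuracy value. Assuming integer class labels and a model that outputs one score per class, that accuracy is the fraction of samples whose highest-scoring class matches the label. The sketch computes it, plus the accuracy drop caused by pruning, on synthetic logits made up for illustration (not results from this module). `accuracy` is a hypothetical helper.

```python
import numpy as np

def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the integer label."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == labels))

# Synthetic scores for 4 samples and 3 classes (illustrative only).
labels = np.array([0, 1, 2, 1])
baseline_logits = np.array([[2.0, 0.1, 0.1], [0.1, 3.0, 0.2], [0.3, 0.2, 1.5], [0.2, 2.2, 0.1]])
pruned_logits = np.array([[1.8, 0.2, 0.1], [0.4, 2.5, 0.2], [1.1, 0.2, 0.9], [0.3, 1.9, 0.2]])

baseline_acc = accuracy(baseline_logits, labels)
pruned_acc = accuracy(pruned_logits, labels)
print(f"baseline={baseline_acc:.2f} pruned={pruned_acc:.2f} drop={baseline_acc - pruned_acc:.2f}")
```

Reporting the drop alongside the final sparsity makes the compression/accuracy trade-off explicit.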
save_baseline_model()
Saves the baseline model to a temporary file using the .keras format.
:return: Path to the saved model file.
:rtype: str
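Zeroed weights do not shrink a saved file by themselves. The TensorFlow pruning guide therefore compares gzipped file sizes to show the benefit of pruning. The sketch below assumes that approach: it creates a temporary `.keras` path (as `save_baseline_model` returns) and measures a file's gzipped size. Both helpers are hypothetical, and raw byte buffers stand in for real model files.

```python
import gzip
import os
import shutil
import tempfile

def make_temp_keras_path():
    """Create an empty temporary file with a .keras suffix and return its path."""
    fd, path = tempfile.mkstemp(suffix=".keras")
    os.close(fd)  # mkstemp returns an open descriptor; callers reopen by path
    return path

def gzipped_size(path):
    """Size in bytes of the file at `path` after gzip compression."""
    fd, zipped = tempfile.mkstemp(suffix=".gz")
    os.close(fd)
    with open(path, "rb") as src, gzip.open(zipped, "wb") as dst:
        shutil.copyfileobj(src, dst)
    size = os.path.getsize(zipped)
    os.remove(zipped)
    return size

# Stand-ins: all-zero bytes (like highly sparse weights) vs. random bytes (dense weights).
sparse_path, dense_path = make_temp_keras_path(), make_temp_keras_path()
with open(sparse_path, "wb") as f:
    f.write(bytes(100_000))
with open(dense_path, "wb") as f:
    f.write(os.urandom(100_000))
print(gzipped_size(sparse_path), gzipped_size(dense_path))
```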
train_baseline_model(train_images, train_labels)
Trains the baseline model without pruning.
:param train_images: Training data features.
:param train_labels: Training data labels.
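The constructor's `validation_split` is presumably passed to Keras `Model.fit` (an assumption about this class). Keras takes the validation data from the last samples of the arrays, before any shuffling, so data sorted by class gives an unrepresentative validation set. The sketch below mimics that split in NumPy; `keras_style_validation_split` is a hypothetical helper.

```python
import numpy as np

def keras_style_validation_split(images, labels, validation_split=0.1):
    """Split arrays like Keras' `validation_split`: the LAST fraction becomes validation data."""
    split_at = int(len(images) * (1.0 - validation_split))
    return (images[:split_at], labels[:split_at]), (images[split_at:], labels[split_at:])

images = np.arange(20).reshape(10, 2)  # 10 toy samples with 2 features each
labels = np.arange(10)
(train_x, train_y), (val_x, val_y) = keras_style_validation_split(images, labels, 0.1)
print(train_x.shape, val_x.shape, val_y)
```

Shuffling `train_images` and `train_labels` together before training avoids the ordering problem.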
train_pruned_model(train_images, train_labels)
Trains the pruned model.
:param train_images: Training data features.
:param train_labels: Training data labels.
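When fitting a model wrapped with `prune_low_magnitude`, tfmot requires the `tfmot.sparsity.keras.UpdatePruningStep()` callback so that the schedule advances; `train_pruned_model` presumably supplies it. Conceptually, each pruned layer keeps a binary mask that is applied to its weights, so pruned connections stay at zero during training. The NumPy sketch below illustrates only that masking idea with a plain gradient step. It is not tfmot's internal code, and `masked_sgd_step` is a hypothetical name.

```python
import numpy as np

def masked_sgd_step(weights, grads, mask, lr=0.1):
    """One gradient step that keeps pruned (masked-out) weights at zero."""
    updated = weights - lr * grads
    # Re-apply the binary mask so the update cannot revive pruned connections.
    return updated * mask

weights = np.array([0.0, 0.9, 0.0, -1.2])
mask = np.array([0.0, 1.0, 0.0, 1.0])  # 50% of the weights already pruned
grads = np.array([0.5, 0.2, -0.3, 0.1])
result = masked_sgd_step(weights, grads, mask)
print(result)
```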