Min 4 лет назад
Родитель
Сommit
a49e1fc4c3
1 измененных файлов с 299 добавлено и 0 удалено
  1. 299 0
      models/quantized_net.py

+ 299 - 0
models/quantized_net.py

@@ -0,0 +1,299 @@
+from numpy import (
+    array,
+    zeros,
+    dot,
+    median,
+    log2,
+    linspace,
+    argmin,
+    abs,
+)
+from scipy.linalg import norm
+from tensorflow.keras.backend import function as Kfunction
+from tensorflow.keras.models import Model, clone_model
+from collections import namedtuple
+from typing import List, Generator
+from time import time
+
+from models.autoencoder import Autoencoder
+
+QuantizedNeuron = namedtuple("QuantizedNeuron", ["layer_idx", "neuron_idx", "q"])
+QuantizedFilter = namedtuple(
+    "QuantizedFilter", ["layer_idx", "filter_idx", "channel_idx", "q_filtr"]
+)
+SegmentedData = namedtuple("SegmentedData", ["wX_seg", "qX_seg"])
+
+
+class QuantizedNeuralNetwork:
+    def __init__(
+            self,
+            network: Model,
+            batch_size: int,
+            get_data: Generator[array, None, None],
+            logger=None,
+            ignore_layers=[],
+            bits=log2(3),
+            alphabet_scalar=1,
+    ):
+
+        self.get_data = get_data
+
+        # The pre-trained network.
+        self.trained_net = network
+
+        # This copies the network structure but not the weights.
+        if isinstance(network, Autoencoder):
+            # The pre-trained network.
+            self.trained_net_layers = network.all_layers
+            self.quantized_net = Autoencoder(network.N, network.channel, bipolar=network.bipolar)
+            self.quantized_net.set_weights(network.get_weights())
+            self.quantized_net_layers = self.quantized_net.all_layers
+            # self.quantized_net.layers = self.quantized_net.layers
+        else:
+            # The pre-trained network.
+            self.trained_net_layers = network.layers
+            self.quantized_net = clone_model(network)
+            # Set all the weights to be the same a priori.
+            self.quantized_net.set_weights(network.get_weights())
+            self.quantized_net_layers = self.quantized_net.layers
+
+        self.batch_size = batch_size
+
+        self.alphabet_scalar = alphabet_scalar
+
+        # Create a dictionary encoding which layers are Dense, and what their dimensions are.
+        self.layer_dims = {
+            layer_idx: layer.get_weights()[0].shape
+            for layer_idx, layer in enumerate(network.layers)
+            if layer.__class__.__name__ == "Dense"
+        }
+
+        # This determines the alphabet. There will be 2**bits atoms in our alphabet.
+        self.bits = bits
+
+        # Construct the (unscaled) alphabet. Layers will scale this alphabet based on the
+        # distribution of that layer's weights.
+        self.alphabet = linspace(-1, 1, num=int(round(2 ** (bits))))
+
+        self.logger = logger
+
+        self.ignore_layers = ignore_layers
+
+    def _log(self, msg: str):
+        if self.logger:
+            self.logger.info(msg)
+        else:
+            print(msg)
+
+    def _bit_round(self, t: float, rad: float) -> float:
+        """Rounds a quantity to the nearest atom in the (scaled) quantization alphabet.
+
+        Parameters
+        -----------
+        t : float
+            The value to quantize.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        bit : float
+            The quantized value.
+        """
+
+        # Scale the alphabet appropriately.
+        layer_alphabet = rad * self.alphabet
+        return layer_alphabet[argmin(abs(layer_alphabet - t))]
+
+    def _quantize_weight(
+            self, w: float, u: array, X: array, X_tilde: array, rad: float
+    ) -> float:
+        """Quantizes a single weight of a neuron.
+
+        Parameters
+        -----------
+        w : float
+            The weight.
+        u : array ,
+            Residual vector.
+        X : array
+            Vector from the analog network's random walk.
+        X_tilde : array
+            Vector from the quantized network's random walk.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        bit : float
+            The quantized value.
+        """
+
+        if norm(X_tilde, 2) < 10 ** (-16):
+            return 0
+
+        if abs(dot(X_tilde, u)) < 10 ** (-10):
+            return self._bit_round(w, rad)
+
+        return self._bit_round(dot(X_tilde, u + w * X) / (norm(X_tilde, 2) ** 2), rad)
+
+    def _quantize_neuron(
+            self,
+            layer_idx: int,
+            neuron_idx: int,
+            wX: array,
+            qX: array,
+            rad=1,
+    ) -> QuantizedNeuron:
+        """Quantizes a single neuron in a Dense layer.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Dense layer.
+        neuron_idx : int,
+            Index of the neuron in the Dense layer.
+        wX : array
+            Layer input for the analog convolutional neural network.
+        qX : array
+            Layer input for the quantized convolutional neural network.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        QuantizedNeuron: NamedTuple
+            A tuple with the layer and neuron index, as well as the quantized neuron.
+        """
+
+        N_ell = wX.shape[1]
+        u = zeros(self.batch_size)
+        w = self.trained_net_layers[layer_idx].get_weights()[0][:, neuron_idx]
+        q = zeros(N_ell)
+        for t in range(N_ell):
+            q[t] = self._quantize_weight(w[t], u, wX[:, t], qX[:, t], rad)
+            u += w[t] * wX[:, t] - q[t] * qX[:, t]
+
+        return QuantizedNeuron(layer_idx=layer_idx, neuron_idx=neuron_idx, q=q)
+
+    def _get_layer_data(self, layer_idx: int, hf=None):
+        """Gets the input data for the layer at a given index.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the layer.
+        hf: hdf5 File object in write mode.
+            If provided, will write output to hdf5 file instead of returning directly.
+
+        Returns
+        -------
+        tuple: (array, array)
+            A tuple of arrays, with the first entry being the input for the analog network
+            and the latter being the input for the quantized network.
+        """
+
+        layer = self.trained_net_layers[layer_idx]
+        layer_data_shape = layer.input_shape[1:] if layer.input_shape[0] is None else layer.input_shape
+        wX = zeros((self.batch_size, *layer_data_shape))
+        qX = zeros((self.batch_size, *layer_data_shape))
+        if layer_idx == 0:
+            for sample_idx in range(self.batch_size):
+                try:
+                    wX[sample_idx, :] = next(self.get_data)
+                except StopIteration:
+                    # No more samples!
+                    break
+            qX = wX
+        else:
+            # Define functions which will give you the output of the previous hidden layer
+            # for both networks.
+            prev_trained_output = Kfunction(
+                [self.trained_net_layers[0].input],
+                [self.trained_net_layers[layer_idx - 1].output],
+            )
+            prev_quant_output = Kfunction(
+                [self.quantized_net_layers[0].input],
+                [self.quantized_net_layers[layer_idx - 1].output],
+            )
+            input_layer = self.trained_net_layers[0]
+            input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
+            batch = zeros((self.batch_size, *input_shape))
+
+            # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
+            # to reconsider how much memory you preallocate for batch, wX, and qX.
+            feed_foward_batch_size = 500
+            ctr = 0
+            for sample_idx in range(self.batch_size):
+                try:
+                    batch[sample_idx, :] = next(self.get_data)
+                except StopIteration:
+                    # No more samples!
+                    break
+
+            wX = prev_trained_output([batch])[0]
+            qX = prev_quant_output([batch])[0]
+
+        return (wX, qX)
+
+    def _update_weights(self, layer_idx: int, Q: array):
+        """Updates the weights of the quantized neural network given a layer index and
+        quantized weights.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Conv2D layer.
+        Q : array
+            The quantized weights.
+        """
+
+        # Update the quantized network. Use the same bias vector as in the analog network for now.
+        if self.trained_net_layers[layer_idx].use_bias:
+            bias = self.trained_net_layers[layer_idx].get_weights()[1]
+            self.quantized_net_layers[layer_idx].set_weights([Q, bias])
+        else:
+            self.quantized_net_layers[layer_idx].set_weights([Q])
+
+    def _quantize_layer(self, layer_idx: int):
+        """Quantizes a Dense layer of a multi-layer perceptron.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Dense layer.
+        """
+
+        W = self.trained_net_layers[layer_idx].get_weights()[0]
+        N_ell, N_ell_plus_1 = W.shape
+        # Placeholder for the weight matrix in the quantized network.
+        Q = zeros(W.shape)
+        N_ell_plus_1 = W.shape[1]
+        wX, qX = self._get_layer_data(layer_idx)
+
+        # Set the radius of the alphabet.
+        rad = self.alphabet_scalar * median(abs(W.flatten()))
+
+        for neuron_idx in range(N_ell_plus_1):
+            self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
+            tic = time()
+            qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
+            Q[:, neuron_idx] = qNeuron.q
+
+            self._log(f"\tdone. {time() - tic :.2f} seconds.")
+
+            self._update_weights(layer_idx, Q)
+
+    def quantize_network(self):
+        """Quantizes all Dense layers that are not specified by the list of ignored layers."""
+
+        # This must be done sequentially.
+        for layer_idx, layer in enumerate(self.trained_net_layers):
+            if (
+                    layer.__class__.__name__ == "Dense"
+                    and layer_idx not in self.ignore_layers
+            ):
+                # Only quantize dense layers.
+                self._log(f"Quantizing layer {layer_idx}...")
+                self._quantize_layer(layer_idx)
+                self._log(f"done. {layer_idx}...")