quantized_net.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. from numpy import (
  2. array,
  3. zeros,
  4. dot,
  5. median,
  6. log2,
  7. linspace,
  8. argmin,
  9. abs,
  10. )
  11. from scipy.linalg import norm
  12. from tensorflow.keras.backend import function as Kfunction
  13. from tensorflow.keras.models import Model, clone_model
  14. from collections import namedtuple
  15. from typing import List, Generator
  16. from time import time
  17. from models.autoencoder import Autoencoder
  18. QuantizedNeuron = namedtuple("QuantizedNeuron", ["layer_idx", "neuron_idx", "q"])
  19. QuantizedFilter = namedtuple(
  20. "QuantizedFilter", ["layer_idx", "filter_idx", "channel_idx", "q_filtr"]
  21. )
  22. SegmentedData = namedtuple("SegmentedData", ["wX_seg", "qX_seg"])
  23. class QuantizedNeuralNetwork:
  24. def __init__(
  25. self,
  26. network: Model,
  27. batch_size: int,
  28. get_data: Generator[array, None, None],
  29. logger=None,
  30. ignore_layers=[],
  31. bits=log2(3),
  32. alphabet_scalar=1,
  33. ):
  34. self.get_data = get_data
  35. # The pre-trained network.
  36. self.trained_net = network
  37. # This copies the network structure but not the weights.
  38. if isinstance(network, Autoencoder):
  39. # The pre-trained network.
  40. self.trained_net_layers = network.all_layers
  41. self.quantized_net = Autoencoder(network.N, network.channel, bipolar=network.bipolar)
  42. self.quantized_net.set_weights(network.get_weights())
  43. self.quantized_net_layers = self.quantized_net.all_layers
  44. # self.quantized_net.layers = self.quantized_net.layers
  45. else:
  46. # The pre-trained network.
  47. self.trained_net_layers = network.layers
  48. self.quantized_net = clone_model(network)
  49. # Set all the weights to be the same a priori.
  50. self.quantized_net.set_weights(network.get_weights())
  51. self.quantized_net_layers = self.quantized_net.layers
  52. self.batch_size = batch_size
  53. self.alphabet_scalar = alphabet_scalar
  54. # Create a dictionary encoding which layers are Dense, and what their dimensions are.
  55. self.layer_dims = {
  56. layer_idx: layer.get_weights()[0].shape
  57. for layer_idx, layer in enumerate(network.layers)
  58. if layer.__class__.__name__ == "Dense"
  59. }
  60. # This determines the alphabet. There will be 2**bits atoms in our alphabet.
  61. self.bits = bits
  62. # Construct the (unscaled) alphabet. Layers will scale this alphabet based on the
  63. # distribution of that layer's weights.
  64. self.alphabet = linspace(-1, 1, num=int(round(2 ** (bits))))
  65. self.logger = logger
  66. self.ignore_layers = ignore_layers
  67. def _log(self, msg: str):
  68. if self.logger:
  69. self.logger.info(msg)
  70. else:
  71. print(msg)
  72. def _bit_round(self, t: float, rad: float) -> float:
  73. """Rounds a quantity to the nearest atom in the (scaled) quantization alphabet.
  74. Parameters
  75. -----------
  76. t : float
  77. The value to quantize.
  78. rad : float
  79. Scaling factor for the quantization alphabet.
  80. Returns
  81. -------
  82. bit : float
  83. The quantized value.
  84. """
  85. # Scale the alphabet appropriately.
  86. layer_alphabet = rad * self.alphabet
  87. return layer_alphabet[argmin(abs(layer_alphabet - t))]
  88. def _quantize_weight(
  89. self, w: float, u: array, X: array, X_tilde: array, rad: float
  90. ) -> float:
  91. """Quantizes a single weight of a neuron.
  92. Parameters
  93. -----------
  94. w : float
  95. The weight.
  96. u : array ,
  97. Residual vector.
  98. X : array
  99. Vector from the analog network's random walk.
  100. X_tilde : array
  101. Vector from the quantized network's random walk.
  102. rad : float
  103. Scaling factor for the quantization alphabet.
  104. Returns
  105. -------
  106. bit : float
  107. The quantized value.
  108. """
  109. if norm(X_tilde, 2) < 10 ** (-16):
  110. return 0
  111. if abs(dot(X_tilde, u)) < 10 ** (-10):
  112. return self._bit_round(w, rad)
  113. return self._bit_round(dot(X_tilde, u + w * X) / (norm(X_tilde, 2) ** 2), rad)
  114. def _quantize_neuron(
  115. self,
  116. layer_idx: int,
  117. neuron_idx: int,
  118. wX: array,
  119. qX: array,
  120. rad=1,
  121. ) -> QuantizedNeuron:
  122. """Quantizes a single neuron in a Dense layer.
  123. Parameters
  124. -----------
  125. layer_idx : int
  126. Index of the Dense layer.
  127. neuron_idx : int,
  128. Index of the neuron in the Dense layer.
  129. wX : array
  130. Layer input for the analog convolutional neural network.
  131. qX : array
  132. Layer input for the quantized convolutional neural network.
  133. rad : float
  134. Scaling factor for the quantization alphabet.
  135. Returns
  136. -------
  137. QuantizedNeuron: NamedTuple
  138. A tuple with the layer and neuron index, as well as the quantized neuron.
  139. """
  140. N_ell = wX.shape[1]
  141. u = zeros(self.batch_size)
  142. w = self.trained_net_layers[layer_idx].get_weights()[0][:, neuron_idx]
  143. q = zeros(N_ell)
  144. for t in range(N_ell):
  145. q[t] = self._quantize_weight(w[t], u, wX[:, t], qX[:, t], rad)
  146. u += w[t] * wX[:, t] - q[t] * qX[:, t]
  147. return QuantizedNeuron(layer_idx=layer_idx, neuron_idx=neuron_idx, q=q)
  148. def _get_layer_data(self, layer_idx: int, hf=None):
  149. """Gets the input data for the layer at a given index.
  150. Parameters
  151. -----------
  152. layer_idx : int
  153. Index of the layer.
  154. hf: hdf5 File object in write mode.
  155. If provided, will write output to hdf5 file instead of returning directly.
  156. Returns
  157. -------
  158. tuple: (array, array)
  159. A tuple of arrays, with the first entry being the input for the analog network
  160. and the latter being the input for the quantized network.
  161. """
  162. layer = self.trained_net_layers[layer_idx]
  163. layer_data_shape = layer.input_shape[1:] if layer.input_shape[0] is None else layer.input_shape
  164. wX = zeros((self.batch_size, *layer_data_shape))
  165. qX = zeros((self.batch_size, *layer_data_shape))
  166. if layer_idx == 0:
  167. for sample_idx in range(self.batch_size):
  168. try:
  169. wX[sample_idx, :] = next(self.get_data)
  170. except StopIteration:
  171. # No more samples!
  172. break
  173. qX = wX
  174. else:
  175. # Define functions which will give you the output of the previous hidden layer
  176. # for both networks.
  177. prev_trained_output = Kfunction(
  178. [self.trained_net_layers[0].input],
  179. [self.trained_net_layers[layer_idx - 1].output],
  180. )
  181. prev_quant_output = Kfunction(
  182. [self.quantized_net_layers[0].input],
  183. [self.quantized_net_layers[layer_idx - 1].output],
  184. )
  185. input_layer = self.trained_net_layers[0]
  186. input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
  187. batch = zeros((self.batch_size, *input_shape))
  188. # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
  189. # to reconsider how much memory you preallocate for batch, wX, and qX.
  190. feed_foward_batch_size = 500
  191. ctr = 0
  192. for sample_idx in range(self.batch_size):
  193. try:
  194. batch[sample_idx, :] = next(self.get_data)
  195. except StopIteration:
  196. # No more samples!
  197. break
  198. wX = prev_trained_output([batch])[0]
  199. qX = prev_quant_output([batch])[0]
  200. return (wX, qX)
  201. def _update_weights(self, layer_idx: int, Q: array):
  202. """Updates the weights of the quantized neural network given a layer index and
  203. quantized weights.
  204. Parameters
  205. -----------
  206. layer_idx : int
  207. Index of the Conv2D layer.
  208. Q : array
  209. The quantized weights.
  210. """
  211. # Update the quantized network. Use the same bias vector as in the analog network for now.
  212. if self.trained_net_layers[layer_idx].use_bias:
  213. bias = self.trained_net_layers[layer_idx].get_weights()[1]
  214. self.quantized_net_layers[layer_idx].set_weights([Q, bias])
  215. else:
  216. self.quantized_net_layers[layer_idx].set_weights([Q])
  217. def _quantize_layer(self, layer_idx: int):
  218. """Quantizes a Dense layer of a multi-layer perceptron.
  219. Parameters
  220. -----------
  221. layer_idx : int
  222. Index of the Dense layer.
  223. """
  224. W = self.trained_net_layers[layer_idx].get_weights()[0]
  225. N_ell, N_ell_plus_1 = W.shape
  226. # Placeholder for the weight matrix in the quantized network.
  227. Q = zeros(W.shape)
  228. N_ell_plus_1 = W.shape[1]
  229. wX, qX = self._get_layer_data(layer_idx)
  230. # Set the radius of the alphabet.
  231. rad = self.alphabet_scalar * median(abs(W.flatten()))
  232. for neuron_idx in range(N_ell_plus_1):
  233. self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
  234. tic = time()
  235. qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
  236. Q[:, neuron_idx] = qNeuron.q
  237. self._log(f"\tdone. {time() - tic :.2f} seconds.")
  238. self._update_weights(layer_idx, Q)
  239. def quantize_network(self):
  240. """Quantizes all Dense layers that are not specified by the list of ignored layers."""
  241. # This must be done sequentially.
  242. for layer_idx, layer in enumerate(self.trained_net_layers):
  243. if (
  244. layer.__class__.__name__ == "Dense"
  245. and layer_idx not in self.ignore_layers
  246. ):
  247. # Only quantize dense layers.
  248. self._log(f"Quantizing layer {layer_idx}...")
  249. self._quantize_layer(layer_idx)
  250. self._log(f"done. {layer_idx}...")