Browse Source

Quantised net fix

Min 4 years ago
parent
commit
86474cdd4f
1 changed file with 6 additions and 4 deletions
  1. 6 4
      models/quantized_net.py

+ 6 - 4
models/quantized_net.py

@@ -208,6 +208,7 @@ class QuantizedNeuralNetwork:
         else:
             # Define functions which will give you the output of the previous hidden layer
             # for both networks.
+            # try:
             prev_trained_output = Kfunction(
                 [self.trained_net_layers[0].input],
                 [self.trained_net_layers[layer_idx - 1].output],
@@ -219,7 +220,8 @@ class QuantizedNeuralNetwork:
             input_layer = self.trained_net_layers[0]
             input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
             batch = zeros((self.batch_size, *input_shape))
-
+            # except Exception:
+            #     pass
             # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
             # to reconsider how much memory you preallocate for batch, wX, and qX.
             feed_foward_batch_size = 500
@@ -275,12 +277,12 @@ class QuantizedNeuralNetwork:
         rad = self.alphabet_scalar * median(abs(W.flatten()))
 
         for neuron_idx in range(N_ell_plus_1):
-            self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
+            # self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
             tic = time()
             qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
             Q[:, neuron_idx] = qNeuron.q
 
-            self._log(f"\tdone. {time() - tic :.2f} seconds.")
+            # self._log(f"\tdone. {time() - tic :.2f} seconds.")
 
             self._update_weights(layer_idx, Q)
 
@@ -296,4 +298,4 @@ class QuantizedNeuralNetwork:
                 # Only quantize dense layers.
                 self._log(f"Quantizing layer {layer_idx}...")
                 self._quantize_layer(layer_idx)
-                self._log(f"done. {layer_idx}...")
+                self._log(f"done #{layer_idx} {layer}...")