# binary_net.py
  1. """
  2. Adopted from https://github.com/uranusx86/BinaryNet-on-tensorflow
  3. """
  4. # coding=UTF-8
  5. import tensorflow as tf
  6. from tensorflow import keras
  7. from tensorflow.python.framework import tensor_shape, ops
  8. from tensorflow.python.ops import standard_ops, nn, variable_scope, math_ops, control_flow_ops
  9. from tensorflow.python.eager import context
  10. from tensorflow.python.training import optimizer, training_ops
  11. import numpy as np
  12. # Warning: if you have a @property getter/setter function in a class, must inherit from object class
  13. all_layers = []


def hard_sigmoid(x):
    # Piecewise-linear approximation of a sigmoid: clip((x + 1) / 2, 0, 1).
    return tf.clip_by_value((x + 1.) / 2., 0., 1.)


def round_through(x):
    """
    Element-wise rounding to the closest integer with full gradient propagation.
    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182):
    an op that behaves as f(x) in the forward pass
    but as g(x) (here the identity) in the backward pass.
    """
    rounded = tf.round(x)
    return x + tf.stop_gradient(rounded - x)
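
# Straight-through estimator in action -- a minimal sketch, assuming a TF1 graph
# session; the input values are illustrative:
#   x = tf.constant([0.2, 0.7, -1.3])
#   y = round_through(x)        # forward value: [0., 1., -1.]
#   g = tf.gradients(y, x)[0]   # backward: the identity gradient, [1., 1., 1.]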


# The neurons' activation binarization function.
# It behaves like the sign function during forward propagation
# and like
#     hard_tanh(x) = 2 * hard_sigmoid(x) - 1
# during back propagation.
def binary_tanh_unit(x):
    return 2. * round_through(hard_sigmoid(x)) - 1.


def binary_sigmoid_unit(x):
    return round_through(hard_sigmoid(x))


# The weights' binarization function,
# taken directly from the BinaryConnect GitHub repository
# (made available by its authors).
def binarization(W, H, binary=True, deterministic=False, stochastic=False, srng=None):
    # (deterministic == True) <-> test time <-> inference time
    if not binary or (deterministic and stochastic):
        # print("not binary")
        Wb = W
    else:
        # [-1, 1] -> [0, 1]
        # Wb = hard_sigmoid(W / H)
        # Wb = T.clip(W / H, -1, 1)

        # Stochastic BinaryConnect
        '''
        if stochastic:
            # print("stoch")
            Wb = tf.cast(srng.binomial(n=1, p=Wb, size=tf.shape(Wb)), tf.float32)
        '''
        # Deterministic BinaryConnect (round to nearest)
        # else:
        #     # print("det")
        #     Wb = tf.round(Wb)

        # 0 or 1 -> -1 or 1
        # Wb = tf.where(tf.equal(Wb, 1.0), tf.ones_like(W), -tf.ones_like(W))  # not differentiable
        Wb = H * binary_tanh_unit(W / H)
    return Wb
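
# Example -- a minimal sketch; the weight values are illustrative:
#   W = tf.constant([[0.3, -0.8], [1.7, -0.1]])
#   Wb = binarization(W, H=1.0)   # forward value: [[1., -1.], [1., -1.]]
#   # Gradients pass straight through to W, except where |W| > H,
#   # where the hard_sigmoid clip zeroes them.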


class DenseBinaryLayer(keras.layers.Dense):
    def __init__(self, output_dim,
                 activation=None,
                 use_bias=True,
                 binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
                 kernel_initializer=tf.glorot_normal_initializer(),
                 bias_initializer=tf.zeros_initializer(),
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 trainable=True,
                 name=None,
                 **kwargs):
        super(DenseBinaryLayer, self).__init__(
            units=output_dim,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            trainable=trainable,
            name=name,
            **kwargs)
        self.binary = binary
        self.stochastic = stochastic
        self.H = H
        self.W_LR_scale = W_LR_scale
        all_layers.append(self)

    def build(self, input_shape):
        num_inputs = tensor_shape.TensorShape(input_shape).as_list()[-1]
        num_units = self.units
        print(num_units)

        if self.H == "Glorot":
            self.H = np.float32(np.sqrt(1.5 / (num_inputs + num_units)))  # weight init scale
            self.W_LR_scale = np.float32(1. / np.sqrt(1.5 / (num_inputs + num_units)))  # per-layer learning-rate scale
        print("H = ", self.H)
        print("LR scale = ", self.W_LR_scale)

        self.kernel_initializer = tf.random_uniform_initializer(-self.H, self.H)
        self.kernel_constraint = lambda w: tf.clip_by_value(w, -self.H, self.H)

        # add_variable must be called before the parent class's build()
        self.b_kernel = self.add_variable('binary_weight',
                                          shape=[input_shape[-1], self.units],
                                          initializer=tf.random_uniform_initializer(-self.H, self.H),
                                          regularizer=None,
                                          constraint=None,
                                          dtype=self.dtype,
                                          trainable=False)

        super(DenseBinaryLayer, self).build(input_shape)
        # tf.add_to_collection('real', self.trainable_variables)
        # Register the real-valued kernel in a layer-wise collection so that
        # compute_grads() below can look it up.
        tf.add_to_collection(self.name + '_binary', self.kernel)  # layer-wise group
        # tf.add_to_collection('binary', self.kernel)  # global group

    def call(self, inputs):
        inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
        shape = inputs.get_shape().as_list()

        # Binarize the real-valued kernel for this forward pass
        # (b_kernel is re-bound to a tensor here).
        self.b_kernel = binarization(self.kernel, self.H)
        # r_kernel = self.kernel
        # self.kernel = self.b_kernel
        print("shape: ", len(shape))

        if len(shape) > 2:
            # Broadcasting is required for the inputs.
            outputs = standard_ops.tensordot(inputs, self.b_kernel, [[len(shape) - 1], [0]])
            # Reshape the output back to the original ndim of the input.
            if context.in_graph_mode():
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            outputs = standard_ops.matmul(inputs, self.b_kernel)
        # restore weight
        # self.kernel = r_kernel

        if self.use_bias:
            outputs = nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs


# Functional interface for the DenseBinaryLayer class.
def dense_binary(
        inputs, units,
        activation=None,
        use_bias=True,
        binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
        kernel_initializer=tf.glorot_normal_initializer(),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        reuse=None):
    layer = DenseBinaryLayer(units,
                             activation=activation,
                             use_bias=use_bias,
                             binary=binary, stochastic=stochastic, H=H, W_LR_scale=W_LR_scale,
                             kernel_initializer=kernel_initializer,
                             bias_initializer=bias_initializer,
                             kernel_regularizer=kernel_regularizer,
                             bias_regularizer=bias_regularizer,
                             activity_regularizer=activity_regularizer,
                             kernel_constraint=kernel_constraint,
                             bias_constraint=bias_constraint,
                             trainable=trainable,
                             name=name,
                             dtype=inputs.dtype.base_dtype,
                             _scope=name,
                             _reuse=reuse)
    return layer.apply(inputs)
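
# Example usage -- a minimal sketch, assuming a TF1 graph; the placeholder and
# layer sizes are illustrative (batch_normalization is defined further below):
#   x = tf.placeholder(tf.float32, [None, 784])
#   h = dense_binary(x, 512, name="bin_dense_1")
#   h = binary_tanh_unit(batch_normalization(h, training=True))
#   logits = dense_binary(h, 10, name="bin_dense_out")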


# Not yet binarized
class BatchNormalization(keras.layers.BatchNormalization):
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer=tf.zeros_initializer(),
                 gamma_initializer=tf.ones_initializer(),
                 moving_mean_initializer=tf.zeros_initializer(),
                 moving_variance_initializer=tf.ones_initializer(),
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 renorm=False,
                 renorm_clipping=None,
                 renorm_momentum=0.99,
                 fused=None,
                 trainable=True,
                 name=None,
                 **kwargs):
        super(BatchNormalization, self).__init__(
            axis=axis,
            momentum=momentum,
            epsilon=epsilon,
            center=center,
            scale=scale,
            beta_initializer=beta_initializer,
            gamma_initializer=gamma_initializer,
            moving_mean_initializer=moving_mean_initializer,
            moving_variance_initializer=moving_variance_initializer,
            beta_regularizer=beta_regularizer,
            gamma_regularizer=gamma_regularizer,
            beta_constraint=beta_constraint,
            gamma_constraint=gamma_constraint,
            renorm=renorm,
            renorm_clipping=renorm_clipping,
            renorm_momentum=renorm_momentum,
            fused=fused,
            trainable=trainable,
            name=name,
            **kwargs)
        # all_layers.append(self)

    def build(self, input_shape):
        super(BatchNormalization, self).build(input_shape)
        self.W_LR_scale = np.float32(1.)


# Functional interface for the batch normalization layer.
def batch_normalization(
        inputs,
        axis=-1,
        momentum=0.99,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer=tf.zeros_initializer(),
        gamma_initializer=tf.ones_initializer(),
        moving_mean_initializer=tf.zeros_initializer(),
        moving_variance_initializer=tf.ones_initializer(),
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        training=False,
        trainable=True,
        name=None,
        reuse=None,
        renorm=False,
        renorm_clipping=None,
        renorm_momentum=0.99,
        fused=None):
    layer = BatchNormalization(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        renorm=renorm,
        renorm_clipping=renorm_clipping,
        renorm_momentum=renorm_momentum,
        fused=fused,
        trainable=trainable,
        name=name,
        dtype=inputs.dtype.base_dtype,
        _reuse=reuse,
        _scope=name)
    return layer.apply(inputs, training=training)


class AdamOptimizer(optimizer.Optimizer):
    """Optimizer that implements the Adam algorithm, with a per-variable
    learning-rate scale factor for the binary-network kernels.

    See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
    ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
    """

    def __init__(self, weight_scale, learning_rate=0.001, beta1=0.9, beta2=0.999,
                 epsilon=1e-8, use_locking=False, name="Adam"):
        super(AdamOptimizer, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        # BNN weight (learning-rate) scale factors, keyed by variable name.
        self._weight_scale = weight_scale

        # Tensor versions of the constructor arguments, created in _prepare().
        self._lr_t = None
        self._beta1_t = None
        self._beta2_t = None
        self._epsilon_t = None

        # Variables to accumulate the powers of the beta parameters.
        # Created in _create_slots when we know the variables to optimize.
        self._beta1_power = None
        self._beta2_power = None

        # Created in SparseApply if needed.
        self._updated_lr = None

    def _get_beta_accumulators(self):
        return self._beta1_power, self._beta2_power

    def _non_slot_variables(self):
        return self._get_beta_accumulators()

    def _create_slots(self, var_list):
        first_var = min(var_list, key=lambda x: x.name)

        create_new = self._beta1_power is None
        if not create_new and context.in_graph_mode():
            create_new = (self._beta1_power.graph is not first_var.graph)
        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1,
                                                            name="beta1_power",
                                                            trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2,
                                                            name="beta2_power",
                                                            trainable=False)
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
        self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
        self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")

    def _apply_dense(self, grad, var):
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # For the BNN kernel:
        # the original per-layer weight scaling is  new_w = old_w + scale * (new_w - old_w)
        # and the Adam update is                    new_w = old_w - lr_t * m_t / (sqrt(v_t) + epsilon),
        # so substituting the Adam update into the scaling gives
        #                                           new_w = old_w - (scale * lr_t * m_t) / (sqrt(v_t) + epsilon),
        # i.e. the scale can be folded directly into the learning rate.
        scale = self._weight_scale[var.name] / 4

        return training_ops.apply_adam(
            var, m, v,
            math_ops.cast(self._beta1_power, var.dtype.base_dtype),
            math_ops.cast(self._beta2_power, var.dtype.base_dtype),
            math_ops.cast(self._lr_t * scale, var.dtype.base_dtype),
            math_ops.cast(self._beta1_t, var.dtype.base_dtype),
            math_ops.cast(self._beta2_t, var.dtype.base_dtype),
            math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
            grad, use_locking=self._use_locking).op
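
    # Worked illustration (assumed layer sizes, not from the original repo):
    # for a kernel built with H="Glorot", num_inputs=784 and num_units=512,
    #   W_LR_scale = 1 / sqrt(1.5 / (784 + 512)) ~= 29.4,
    # so the effective Adam step size for that kernel is lr_t * 29.4 / 4 ~= 7.35 * lr_t.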

    def _resource_apply_dense(self, grad, var):
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")
        return training_ops.resource_apply_adam(
            var.handle, m.handle, v.handle,
            math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
            math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
            math_ops.cast(self._lr_t, grad.dtype.base_dtype),
            math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
            math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
            math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
            grad, use_locking=self._use_locking)

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t,
                               use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        v_sqrt = math_ops.sqrt(v_t)
        var_update = state_ops.assign_sub(var,
                                          lr * m_t / (v_sqrt + epsilon_t),
                                          use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t])

    def _apply_sparse(self, grad, var):
        return self._apply_sparse_shared(
            grad.values, var, grad.indices,
            lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
                x, i, v, use_locking=self._use_locking))

    def _resource_scatter_add(self, x, i, v):
        with ops.control_dependencies(
                [resource_variable_ops.resource_scatter_add(
                    x.handle, i, v)]):
            return x.value()

    def _resource_apply_sparse(self, grad, var, indices):
        return self._apply_sparse_shared(
            grad, var, indices, self._resource_scatter_add)

    def _finish(self, update_ops, name_scope):
        # Update the power accumulators.
        with ops.control_dependencies(update_ops):
            with ops.colocate_with(self._beta1_power):
                update_beta1 = self._beta1_power.assign(
                    self._beta1_power * self._beta1_t,
                    use_locking=self._use_locking)
                update_beta2 = self._beta2_power.assign(
                    self._beta2_power * self._beta2_t,
                    use_locking=self._use_locking)
        return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
                                      name=name_scope)


def get_all_layers():
    return all_layers


def get_all_LR_scale():
    # Map each layer's real-valued kernel name to its per-layer learning-rate scale,
    # in the format expected by AdamOptimizer(weight_scale=...).
    return {layer.kernel.name: layer.W_LR_scale for layer in get_all_layers()}


# This function computes the gradients of the binary weights.
def compute_grads(loss, opt):
    layers = get_all_layers()
    grads_list = []
    update_weights = []

    for layer in layers:
        # Mirrors the original Lasagne code's self.params[self.W] = set(['binary']):
        # the parameter list can optionally be filtered by tag keyword arguments, e.g.
        # ``trainable=True`` returns only trainable parameters and
        # ``regularizable=True`` returns only parameters that can be regularized
        # (the function returns e.g. [W, b] for a dense layer).
        params = tf.get_collection(layer.name + "_binary")
        if params:
            # print(params[0].name)
            # theano.grad(cost, wrt) -> d(cost)/d(wrt)
            # wrt - with respect to which we want gradients
            # http://blog.csdn.net/shouhuxianjian/article/details/46517143
            # http://blog.csdn.net/qq_33232071/article/details/52806630
            # grad = opt.compute_gradients(loss, layer.b_kernel)  # original version
            grad = opt.compute_gradients(loss, params[0])  # modified version
            print("grad: ", grad)
            grads_list.append(grad[0][0])
            update_weights.extend(params)

    print(grads_list)
    print(update_weights)
    return zip(grads_list, update_weights)
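

# End-to-end usage -- a minimal sketch, assuming a TF1 graph; `images`, `labels`
# and the loss are illustrative placeholders, not part of this module (only the
# binary kernels are updated here; biases and batch-norm variables would need
# their own optimizer step):
#   logits = dense_binary(images, 10, name="bin_out")
#   loss = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
#   opt = AdamOptimizer(get_all_LR_scale(), learning_rate=1e-3)
#   train_op = opt.apply_gradients(compute_grads(loss, opt))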