|
@@ -26,6 +26,35 @@ class AwgnChannel(layers.Layer):
|
|
|
return self.noise_layer.call(inputs, training=True)
|
|
return self.noise_layer.call(inputs, training=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+class BitsToSymbols(layers.Layer):
|
|
|
|
|
+ def __init__(self, cardinality):
|
|
|
|
|
+ super(BitsToSymbols, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ self.cardinality = cardinality
|
|
|
|
|
+
|
|
|
|
|
+ n = int(tf.math.log(self.cardinality, 2))
|
|
|
|
|
+ self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
|
|
|
|
|
+ out = tf.one_hot(idx, self.cardinality)
|
|
|
|
|
+ return layers.Reshape((9, 32))(out)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class SymbolsToBits(layers.Layer):
|
|
|
|
|
+ def __init__(self, cardinality):
|
|
|
|
|
+ super(SymbolsToBits, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ n = int(tf.math.log(cardinality, 2))
|
|
|
|
|
+ lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
|
|
|
|
|
+
|
|
|
|
|
+ # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
|
|
|
|
|
+ self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ return tf.matmul(inputs, self.all_syms)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class ScaleAndOffset(layers.Layer):
|
|
class ScaleAndOffset(layers.Layer):
|
|
|
"""
|
|
"""
|
|
|
Scales and offsets a tensor
|
|
Scales and offsets a tensor
|