소스 검색

Quick fix signal model for autoencoder

Min 5 년 전
부모
커밋
fe3da6cd95
2개의 변경된 파일9개의 추가작업 그리고 7개의 파일을 삭제
  1. 6 4
      defs.py
  2. 3 3
      models/autoencoder.py

+ 6 - 4
defs.py

@@ -1,6 +1,6 @@
 import math
 import numpy as np
-
+import tensorflow as tf
 
 class Signal:
 
@@ -39,11 +39,13 @@ class Channel(COMComponent):
     """
 
     def forward(self, values: Signal) -> Signal:
+        raise NotImplemented("Need to define forward function")
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
         """
-        :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
-        :return: affected tuple of (amplitude, phase, frequency)
+        Forward operation optimised for tensorflow tensors
         """
-        raise NotImplemented("Need to define forward function")
+        raise NotImplemented("Need to define forward_tensor function")
 
 
 class ModComponent(COMComponent):

+ 3 - 3
models/autoencoder.py

@@ -24,7 +24,7 @@ class AutoencoderMod(defs.Modulator):
         super().__init__(2 ** autoencoder.N)
         self.autoencoder = autoencoder
 
-    def forward(self, binary: np.ndarray) -> defs.Signal:
+    def forward(self, binary: np.ndarray):
         reshaped = binary.reshape((-1, self.N))
         reshaped_ho = misc.bit_matrix2one_hot(reshaped)
         encoded = self.autoencoder.encoder(reshaped_ho)
@@ -33,7 +33,7 @@ class AutoencoderMod(defs.Modulator):
 
         f = np.zeros(x2.shape[0])
         x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
-        return defs.Signal(x3)
+        return basic.RFSignal(x3)
 
 
 class AutoencoderDemod(defs.Demodulator):
@@ -140,7 +140,7 @@ class Autoencoder(Model):
         x = x.reshape((-1, 2))
         f = np.zeros(x.shape[0])
         xf = np.c_[x[:, 0], x[:, 1], f]
-        y = demod.forward(defs.Signal(misc.rect2polar(xf)))
+        y = demod.forward(basic.RFSignal(misc.rect2polar(xf)))
         y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
 
         X_train, X_test, y_train, y_test = train_test_split(x, y_ho)