diff options
Diffstat (limited to 'gr-wavelet/python/qa_classify.py')
-rwxr-xr-x | gr-wavelet/python/qa_classify.py | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/gr-wavelet/python/qa_classify.py b/gr-wavelet/python/qa_classify.py new file mode 100755 index 000000000..5701bce8c --- /dev/null +++ b/gr-wavelet/python/qa_classify.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# +# Copyright 2008,2010 Free Software Foundation, Inc. +# +# This file is part of GNU Radio +# +# GNU Radio is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3, or (at your option) +# any later version. +# +# GNU Radio is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with GNU Radio; see the file COPYING. If not, write to +# the Free Software Foundation, Inc., 51 Franklin Street, +# Boston, MA 02110-1301, USA. +# + +import numpy +from gnuradio import gr, gr_unittest +import copy +#import pygsl.wavelet as wavelet # FIXME: pygsl not checked for in config +import math +import wavelet_swig + +def sqr(x): + return x*x + +def np2(k): + m = 0 + n = k - 1 + while n > 0: + m += 1 + return m + + +class test_classify(gr_unittest.TestCase): + + def setUp(self): + self.tb = gr.top_block() + + def tearDown(self): + self.tb = None + +# def test_000_(self): +# src_data = numpy.zeros(10) +# trg_data = numpy.zeros(10) +# src = gr.vector_source_f(src_data) +# dst = gr.vector_sink_f() +# self.tb.connect(src, dst) +# self.tb.run() +# rsl_data = dst.data() +# sum = 0 +# for (u,v) in zip(trg_data, rsl_data): +# w = u - v +# sum += w * w +# sum /= float(len(trg_data)) +# assert sum < 1e-6 + + def test_001_(self): + src_data = numpy.array([-1.0, 1.0, -1.0, 1.0]) + trg_data = src_data * 0.5 + src = gr.vector_source_f(src_data) + dst = gr.vector_sink_f() + rail = gr.rail_ff(-0.5, 0.5) + self.tb.connect(src, rail) + self.tb.connect(rail, dst) + self.tb.run() + rsl_data = dst.data() + sum = 0 + for (u, v) in zip(trg_data, rsl_data): + w = u - v + sum += w * w + sum /= float(len(trg_data)) + assert sum < 1e-6 + + def test_002_(self): + src_data = numpy.array([-1.0, + -1.0/2.0, + -1.0/3.0, + -1.0/4.0, + -1.0/5.0]) + trg_data = copy.deepcopy(src_data) + + src = gr.vector_source_f(src_data, False, len(src_data)) + st = gr.stretch_ff(-1.0/5.0, len(src_data)) + dst = gr.vector_sink_f(len(src_data)) + self.tb.connect(src, st) + self.tb.connect(st, dst) + self.tb.run() + rsl_data = dst.data() + sum = 0 + for (u, v) in zip(trg_data, rsl_data): + w = u - v + sum += w * w + sum /= float(len(trg_data)) + assert sum < 1e-6 + + def test_003_(self): + src_grid = (0.0, 1.0, 2.0, 3.0, 4.0) + trg_grid = copy.deepcopy(src_grid) + src_data = (0.0, 1.0, 0.0, 1.0, 0.0) + + src = gr.vector_source_f(src_data, False, len(src_grid)) + sq = wavelet_swig.squash_ff(src_grid, trg_grid) + dst = gr.vector_sink_f(len(trg_grid)) + self.tb.connect(src, sq) + self.tb.connect(sq, dst) + self.tb.run() + rsl_data = dst.data() + sum = 0 + for (u, v) in zip(src_data, rsl_data): + w = u - v + sum += w * w + sum /= float(len(src_data)) + assert sum < 1e-6 + +# def test_004_(self): # FIXME: requires pygsl +# +# n = 256 +# o = 4 +# ws = wavelet.workspace(n) +# w = wavelet.daubechies(o) +# +# a = numpy.arange(n) +# b = numpy.sin(a*numpy.pi/16.0) +# c = w.transform_forward(b, ws) +# d = w.transform_inverse(c, ws) +# +# src = gr.vector_source_f(b, False, n) +# wv = wavelet_swig.wavelet_ff(n, o, True) +# +# dst = gr.vector_sink_f(n) +# self.tb.connect(src, wv) +# self.tb.connect(wv, dst) +# self.tb.run() +# e = dst.data() +# +# sum = 0 +# for (u, v) in zip(c, e): +# w = u - v +# sum += w * w +# sum /= float(len(c)) +# assert sum < 1e-6 + + def test_005_(self): + + src_data = (1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + dwav = numpy.array(src_data) + wvps = numpy.zeros(3) + # wavelet power spectrum + scl = 1.0/sqr(dwav[0]) + k = 1 + for e in range(len(wvps)): + wvps[e] = scl*sqr(dwav[k:k+(01<<e)]).sum() + k += 01<<e + + src = gr.vector_source_f(src_data, False, len(src_data)) + kon = wavelet_swig.wvps_ff(len(src_data)) + dst = gr.vector_sink_f(int(math.ceil(math.log(len(src_data), 2)))) + + self.tb.connect(src, kon) + self.tb.connect(kon, dst) + + self.tb.run() + snk_data = dst.data() + + sum = 0 + for (u,v) in zip(snk_data, wvps): + w = u - v + sum += w * w + sum /= float(len(snk_data)) + assert sum < 1e-6 + +if __name__ == '__main__': + gr_unittest.run(test_classify, "test_classify.xml") |