// ----------------------------------------------------------------------------
//
//  Copyright (C) 2015-2021 Fons Adriaensen <fons@linuxaudio.org>
//    
//  This program is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 3 of the License, or
//  (at your option) any later version.
//
//  This program is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
// ----------------------------------------------------------------------------


#include <unistd.h>
#include <string.h>
#include "binconv.h"


// -------------------------------------------------------------------------------


// Allocate storage for 'npar' partitions of 'nbin' complex bins each.
// The bin arrays are obtained from fftwf_malloc () and are zeroed by
// clear () before the constructor returns.
Fdata::Fdata (int npar, int nbin):
    _npar (npar),
    _nbin (nbin)
{
    // Table of per-partition pointers. It is indexed by partition
    // number everywhere, so it needs 'npar' entries (not 'nbin').
    _data = new fftwf_complex* [npar];
    for (int i = 0; i < npar; i++)
    {
        _data [i] = (fftwf_complex *)(fftwf_malloc (nbin * sizeof (fftwf_complex)));
    }
    clear ();
}


// Release every partition buffer, then the table of pointers to them.
Fdata::~Fdata (void)
{
    int k = 0;

    while (k < _npar)
    {
        fftwf_free (_data [k]);
        ++k;
    }
    delete[] _data;
}


// Zero the data of all partitions and mark none of them as active.
void Fdata::clear ()
{
    for (int i = 0; i < _npar; i++)
    {
        memset (_data [i], 0, _nbin * sizeof (fftwf_complex));
    }
    // Reset unconditionally, so that _nact has a defined value even
    // when there are no partitions to clear.
    _nact = 0;
}


// -------------------------------------------------------------------------------


// Create a convolver.
//
//   degree  Input degree, limited to MAXDEGR. The number of inputs
//           is (degree + 1) * (degree + 1).
//   size    Maximum filter length in samples, limited to MAXSIZE.
//   frag    Number of samples handled by each call to process ().
//
// Filters are split into partitions of 'frag' samples, each one
// transformed with an FFT of length 2 * frag.
Binconv::Binconv (int degree, int size, int frag):
    _degr (degree),
    _size (size),
    _frag (frag),
    _tfilt (0),
    _tdomS (0),
    _tdomD (0),
    _faccS (0),
    _faccD (0),
    _saveS (0),
    _saveD (0)
{
    // Apply the upper limits, then derive all working dimensions.
    if (_degr > MAXDEGR) _degr = MAXDEGR;
    if (_size > MAXSIZE) _size = MAXSIZE;
    _ninp = (_degr + 1) * (_degr + 1);      // Number of inputs.
    _lfft = 2 * _frag;                      // FFT length.
    _nbin = _frag + 1;                      // Complex bins per transform.
    _npar = (_size + _frag - 1) / _frag;    // Partitions, rounded up.
    _ipar = 0;

    // T-domain and F-domain work buffers. These are allocated
    // using fftw to ensure correct alignment.
    const size_t tbytes = _lfft * sizeof (float);
    const size_t fbytes = _nbin * sizeof (fftwf_complex);
    _tfilt = static_cast<float *> (fftwf_malloc (tbytes));
    _tdomS = static_cast<float *> (fftwf_malloc (tbytes));
    _tdomD = static_cast<float *> (fftwf_malloc (tbytes));
    _faccS = static_cast<fftwf_complex *> (fftwf_malloc (fbytes));
    _faccD = static_cast<fftwf_complex *> (fftwf_malloc (fbytes));

    // Overlap buffers, holding the second half of each inverse
    // transform until the next call to process ().
    _saveS = new float [_frag];
    _saveD = new float [_frag];

    // Forward and inverse FFTW plans, created with planner flags 0.
    // They are planned on _tdomS and _faccS, and executed on the
    // other buffers as well.
    _plan_r2c = fftwf_plan_dft_r2c_1d (_lfft, _tdomS, _faccS, 0);
    _plan_c2r = fftwf_plan_dft_c2r_1d (_lfft, _faccS, _tdomS, 0);

    // F-domain storage per input: A for the filter, B for the input.
    int k = 0;
    while (k < _ninp)
    {
        _fdataA [k] = new Fdata (_npar, _nbin);
        _fdataB [k] = new Fdata (_npar, _nbin);
        ++k;
    }

    // Start with cleared input data and overlap buffers.
    reset ();
}


// Release everything allocated by the constructor.
Binconv::~Binconv (void)
{
    // F-domain data, one filter/input pair per input.
    int k = 0;
    while (k < _ninp)
    {
        delete _fdataA [k];
        delete _fdataB [k];
        ++k;
    }

    // Overlap buffers.
    delete[] _saveD;
    delete[] _saveS;

    // Buffers obtained from fftwf_malloc ().
    fftwf_free (_faccD);
    fftwf_free (_faccS);
    fftwf_free (_tdomD);
    fftwf_free (_tdomS);
    fftwf_free (_tfilt);

    // FFTW plans.
    fftwf_destroy_plan (_plan_c2r);
    fftwf_destroy_plan (_plan_r2c);
}


// Return to the initial processing state: the stored input spectra and
// the overlap buffers are zeroed and the partition index restarts at 0.
// The filter data set by setimp () is not touched.
void Binconv::reset (void)
{
    const size_t nbytes = _frag * sizeof (float);

    _ipar = 0;
    memset (_saveS, 0, nbytes);
    memset (_saveD, 0, nbytes);
    for (int k = 0; k < _ninp; k++)
    {
        _fdataB [k]->clear ();
    }
}


// Set the impulse response used for one input.
//
//   inp   Input index, valid range is 0 ... _ninp - 1.
//   gain  Linear gain applied to all samples.
//   data  First sample, or a null pointer to only clear the filter.
//   size  Number of samples available. Zero or negative means none.
//   step  Distance, in floats, between successive samples in 'data'.
//
// Returns 1 if 'inp' is out of range, 0 otherwise. The previous filter
// for this input is always cleared first. Samples that do not fit in
// the _npar partitions of _frag samples each are ignored.
int Binconv::setimp (int inp, float gain, const float *data, int size, int step)
{
    int     i, j, n;
    Fdata   *FA;

    // Check valid input.
    if ((inp < 0) || (inp >= _ninp)) return 1;

    // Clear current filter data.
    FA = _fdataA [inp];
    FA->clear ();
    if (! data) return 0;

    // Transform filter to F-domain and store. The factor 1 / _lfft
    // is folded into the gain here; FFTW transforms are unnormalised,
    // so a forward plus inverse transform scales by the FFT length.
    gain /= _lfft;
    for (i = 0; i < _npar; i++)
    {
        // Stop when no samples remain. Testing for <= 0 also covers
        // a negative size, which would otherwise make n negative and
        // let the memset () below write before the start of _tfilt.
        if (size <= 0) break;
        n = (size < _frag) ? size : _frag;
        // Copy to first half of _tfilt, clear the
        // remaining part and transform to F-domain.
        for (j = 0; j < n; j++)
        {
            _tfilt [j] = gain * data [j * step];
        }
        memset (_tfilt + n, 0, (_lfft - n) * sizeof (float));
        fftwf_execute_dft_r2c (_plan_r2c, _tfilt, FA->_data [i]);
        // Prepare for next partition.
        data += n * step;
        size -= n;
    }
    // Remember how many partitions are used.
    FA->_nact = i;
    return 0;
}


void Binconv::process (float *inp [], float *out [2])
{
    int     i, inext, dnext;
    float   s, d;
    bool    sigma;
    float   *L, *R;

    // Clear F-domain accumulators.
    memset (_faccS, 0, _nbin * sizeof (fftwf_complex));
    memset (_faccD, 0, _nbin * sizeof (fftwf_complex));

    // Loop over all inputs.
    sigma = true;
    inext = 1;
    dnext = 1;
    for (i = 0; i < _ninp; i++)
    {
        // Switch between sigma and delta accumulators.
        if (i == inext)
        {
            sigma = !sigma;
            if (sigma) dnext++;
            inext += dnext;
        }
        // Add F-domain product to selected accumulator.
        convadd (inp [i], i, sigma ? _faccS : _faccD);
    }

    // Transform to T-domain, overlap and save,
    // convert S,D to L,R.
    fftwf_execute_dft_c2r (_plan_c2r, _faccS, _tdomS); 
    fftwf_execute_dft_c2r (_plan_c2r, _faccD, _tdomD); 
    L = out [0];
    R = out [1];
    for (i = 0; i < _frag; i++)
    {
        s = _saveS [i] + _tdomS [i];
        d = _saveD [i] + _tdomD [i];
        L [i] = s + d;
        R [i] = s - d;
    }
    memcpy (_saveS, _tdomS + _frag, _frag * sizeof (float));
    memcpy (_saveD, _tdomD + _frag, _frag * sizeof (float));

    // Increment current partition index.
    if (++_ipar == _npar) _ipar = 0;
}


void Binconv::convadd (float *inp, int ind, fftwf_complex *F)
{
    int            i, j, k;
    fftwf_complex  *A, *B;
    Fdata          *FA, *FB;
    
    FA = _fdataA [ind];  // F-domain filter.
    FB = _fdataB [ind];  // F-domain input.

    // Copy input to first half of _tdata, clear
    // second half, and transform to F-domain.
    memcpy (_tdomS, inp, _frag * sizeof (float));
    memset (_tdomS + _frag, 0, _frag * sizeof (float));
    fftwf_execute_dft_r2c (_plan_r2c, _tdomS, FB->_data [_ipar]); 

    // Loop over active partitions.
    j = _ipar;
    for (k = 0; k < FA->_nact; k++)
    {
        // Add product to accumulator.
        A = FA->_data [k];
        B = FB->_data [j];
        for (i = 0; i < _nbin; i++)
        {
            F [i][0] += A [i][0] * B [i][0] - A [i][1] * B [i][1];
            F [i][1] += A [i][0] * B [i][1] + A [i][1] * B [i][0];
        }
        // Next partition.
        if (--j < 0) j += _npar;
    }
}
