/*
 * Decompiled with CFR 0.152.
 */
package org.javaseis.tests;

import java.util.Arrays;
import mpi.MPI;
import mpi.MPIException;
import org.javaseis.fft.SeisFft3d;
import org.javaseis.parallel.DistributedArray;
import org.javaseis.parallel.DistributedArrayPositionIterator;
import org.javaseis.parallel.IParallelContext;
import org.javaseis.parallel.MPIContext;

/**
 * MPI-parallel round-trip tests for {@link SeisFft3d}.
 *
 * <p>Each test fills a {@link DistributedArray} with a known, position-dependent pattern, applies
 * {@link SeisFft3d#forward()} followed by {@link SeisFft3d#inverse()}, and checks that the
 * original pattern is recovered to within a small tolerance. Every process runs all tests; the
 * master rank prints progress messages.
 *
 * <p>Checks throw {@link AssertionError} explicitly instead of using the {@code assert}
 * statement. A run launched without {@code -ea} therefore cannot report success without having
 * verified anything.
 */
public class SeisFft3dTestMPI {
    /** Parallel context shared by all tests; initialized in the constructor. */
    public IParallelContext _pc = new MPIContext();

    /**
     * Initializes the parallel context, runs every test, prints a success banner and shuts the
     * context down.
     *
     * <p>If a check fails, the resulting {@link AssertionError} propagates out of this
     * constructor and {@code finish()} is not reached.
     *
     * @param args command-line arguments forwarded to {@code IParallelContext.init}
     * @throws MPIException if an MPI call made here fails
     */
    public SeisFft3dTestMPI(String[] args) throws MPIException {
        this._pc.init(args);
        this._pc.masterPrint("SeisFft3dTestMPI - 3D FFT Parallel Tests");
        this._pc.serialPrint(MPI.Get_processor_name() + " Task " + this._pc.rank() + " Size " + this._pc.size());
        this.testSimple();
        this.testShapes();
        this._pc.masterPrint("*** org.javaseis.tests.SeisFft3dTestMPI SUCCESS ***");
        this._pc.finish();
    }

    /**
     * Entry point: all work happens in the constructor.
     *
     * @param args command-line arguments forwarded to the parallel context
     * @throws MPIException if an MPI call fails
     */
    public static void main(String[] args) throws MPIException {
        new SeisFft3dTestMPI(args);
    }

    /**
     * Round-trips a grid whose dimensions scale with the number of ranks.
     *
     * <p>The grid is {@code 11*size x 3*size x 2*size}. Each trace is filled with
     * {@link #simpleValue}, transformed forward and back, and compared sample by sample.
     *
     * @throws AssertionError if any recovered sample differs from the expected value by 0.1 or
     *     more
     */
    public void testSimple() {
        this._pc.masterPrint("\nSimple Test ... ");
        int nt = 11 * this._pc.size();
        int nx = 3 * this._pc.size();
        int ny = 2 * this._pc.size();
        int[] len = new int[]{nt, nx, ny};
        float[] pad = new float[]{0.0f, 0.0f, 0.0f};
        // NOTE(review): testShapes doubles shape[0] before allocating its array, this test does
        // not. Confirm which sizing SeisFft3d expects.
        int[] shape = SeisFft3d.getTransformShape(len, pad, this._pc);
        DistributedArray a = new DistributedArray(this._pc, 2, shape);
        a.setShape(1, len);
        float[] trc = new float[nt];
        // The iterator writes the current trace position into this array on every next().
        int[] position = new int[3];
        DistributedArrayPositionIterator mapi = new DistributedArrayPositionIterator(a, position);
        while (mapi.hasNext()) {
            mapi.next();
            for (int i = 0; i < nt; ++i) {
                trc[i] = simpleValue(i, position[1], position[2]);
            }
            a.putTrace(trc, position);
        }
        SeisFft3d f3d = new SeisFft3d(a);
        this._pc.masterPrint("Forward transform ... ");
        f3d.forward();
        this._pc.masterPrint("Inverse transform ... ");
        f3d.inverse();
        mapi.reset();
        while (mapi.hasNext()) {
            mapi.next();
            a.getTrace(trc, position);
            for (int i = 0; i < nt; ++i) {
                float val = simpleValue(i, position[1], position[2]);
                if (!((double) Math.abs(val - trc[i]) < 0.1)) {
                    // Build the reported position from a copy so the iterator's array is never
                    // modified by the check itself.
                    int[] failedAt = position.clone();
                    failedAt[0] = i;
                    throw new AssertionError("Fft3d test failed: position " + Arrays.toString(failedAt) + " Expected " + val + " got " + trc[i]);
                }
            }
        }
        this._pc.serialPrint("Rank " + this._pc.rank() + " Simple Test Completed");
    }

    /**
     * Expected value for sample {@code i} of the trace at (j, k) in {@link #testSimple}.
     *
     * @param i sample index along axis 0
     * @param j index along axis 1
     * @param k index along axis 2
     * @return {@code i + 10*j + 100*k} evaluated in float arithmetic
     */
    private static float simpleValue(int i, int j, int k) {
        return (float) i + 10.0f * (float) j + 100.0f * (float) k;
    }

    /**
     * Round-trips randomly sized grids through the 3D FFT.
     *
     * <p>There are ten iterations. In each one the master draws random lengths and computes the
     * padded transform shape. It then broadcasts both arrays so every rank allocates the same
     * array. The array is filled via {@link #initArrayFloat}, transformed forward and back, and
     * verified with {@link #checkContents}.
     *
     * @throws AssertionError if any recovered sample is outside tolerance
     */
    void testShapes() {
        final int ntest = 10;
        final int ndim = 3;
        int[] maxLengths = new int[]{25, 17, 13};
        int[] lengths = new int[ndim];
        int[] shape = new int[ndim];
        this._pc.masterPrint("\nFFT Shape Tests ...");
        for (int itest = 0; itest < ntest; ++itest) {
            if (this._pc.isMaster()) {
                for (int i = 0; i < ndim; ++i) {
                    // Random length in [0, max), but never smaller than the number of ranks.
                    lengths[i] = Math.max((int) (Math.random() * (double) maxLengths[i]), this._pc.size());
                }
                shape = SeisFft3d.getTransformShape(lengths, new float[ndim], this._pc);
                // NOTE(review): presumably room for complex samples along axis 0 — confirm.
                shape[0] = 2 * shape[0];
            }
            // All ranks, master included, take part in both broadcasts in the same order.
            this._pc.bcastInt(99, shape, 0, ndim, 0);
            this._pc.bcastInt(99, lengths, 0, ndim, 0);
            this._pc.masterPrint("Input Shape " + Arrays.toString(lengths) + " -> Padded Shape " + Arrays.toString(shape));
            DistributedArray sa = SeisFft3dTestMPI.initArrayFloat(this._pc, lengths, shape);
            SeisFft3d fft3d = new SeisFft3d(sa);
            fft3d.forward();
            fft3d.inverse();
            this.checkContents("FFT Shape Test iteration " + itest, sa, lengths);
            // Drop references before the GC request so this iteration's buffers can be reclaimed.
            sa = null;
            fft3d = null;
            System.gc();
        }
        this._pc.masterPrint("FFT Shape Tests Completed");
    }

    /**
     * Allocates a distributed array and fills its logical region with {@link #testFloat}.
     *
     * <p>The array is allocated over {@code shape}, and its shape is then set to
     * {@code lengths} via {@code setShape(1, lengths)}. For every axis-2 index held locally (per
     * {@code getLocalLength(2)} / {@code localToGlobal(2, ...)}) that is below
     * {@code lengths[2]}, each sample (i, j, k) with {@code i < lengths[0]} and
     * {@code j < lengths[1]} is set to {@code testFloat(i, j, k)}.
     *
     * @param pc parallel context the array is distributed over
     * @param lengths logical (unpadded) lengths of the three axes
     * @param shape allocated (padded) shape of the array
     * @return the initialized array
     */
    public static DistributedArray initArrayFloat(IParallelContext pc, int[] lengths, int[] shape) {
        DistributedArray sa = new DistributedArray(pc, 2, shape);
        sa.setShape(1, lengths);
        int[] position = new int[]{0, 0, 0};
        for (int kl = 0; kl < sa.getLocalLength(2); ++kl) {
            int k = sa.localToGlobal(2, kl);
            // Skip padded indices beyond the logical extent of axis 2.
            if (k >= lengths[2]) continue;
            position[2] = k;
            for (int j = 0; j < lengths[1]; ++j) {
                position[1] = j;
                for (int i = 0; i < lengths[0]; ++i) {
                    position[0] = i;
                    sa.putSample(SeisFft3dTestMPI.testFloat(i, j, k), position);
                }
            }
        }
        return sa;
    }

    /**
     * Encodes a 2D index as {@code i + 100*j}.
     *
     * @param i first index
     * @param j second index
     * @return the encoded value
     */
    public static int testFn(int i, int j) {
        return i + 100 * j;
    }

    /**
     * Encodes a 3D index as {@code i + 100*j + 10000*k}.
     *
     * @param i first index
     * @param j second index
     * @param k third index
     * @return the encoded value
     */
    public static int testFn(int i, int j, int k) {
        return i + 100 * j + 10000 * k;
    }

    /**
     * Encodes a 4D index as {@code i + 100*j + 10000*k + 1000000*m}.
     *
     * @param i first index
     * @param j second index
     * @param k third index
     * @param m fourth index
     * @return the encoded value
     */
    public static int testFn(int i, int j, int k, int m) {
        return i + 100 * j + 10000 * k + 1000000 * m;
    }

    /**
     * Expected sample value at (i, j, k) for the shape tests.
     *
     * @param i index along axis 0
     * @param j index along axis 1
     * @param k index along axis 2
     * @return {@code i + 100*j + 10000*k} evaluated in float arithmetic
     */
    public static float testFloat(int i, int j, int k) {
        return 1.0f * (float)i + 100.0f * (float)j + 10000.0f * (float)k;
    }

    /**
     * Verifies that the locally held logical region of {@code sa} matches {@link #testFloat}.
     *
     * <p>It visits the same samples that {@link #initArrayFloat} writes.
     *
     * @param title prefix for the failure message
     * @param sa array to check
     * @param lengths logical (unpadded) lengths of the three axes
     * @throws AssertionError if any sample differs from the expected value by 0.1 or more
     */
    public void checkContents(String title, DistributedArray sa, int[] lengths) {
        int[] position = new int[]{0, 0, 0};
        for (int kl = 0; kl < sa.getLocalLength(2); ++kl) {
            int k = sa.localToGlobal(2, kl);
            // Skip padded indices beyond the logical extent of axis 2.
            if (k >= lengths[2]) continue;
            position[2] = k;
            for (int j = 0; j < lengths[1]; ++j) {
                position[1] = j;
                for (int i = 0; i < lengths[0]; ++i) {
                    position[0] = i;
                    float f2 = sa.getFloat(position);
                    float f1 = SeisFft3dTestMPI.testFloat(i, j, k);
                    this.assertEquals(title, f1, f2, 0.1f);
                }
            }
        }
    }

    /**
     * Fails unless {@code |arg2 - arg1| < tol}.
     *
     * <p>The check is explicit rather than an {@code assert} statement, so it runs even when
     * assertions are disabled. A NaN difference also fails.
     *
     * @param title prefix for the failure message
     * @param arg1 expected value
     * @param arg2 actual value
     * @param tol exclusive tolerance
     * @throws AssertionError if the values are not within tolerance
     */
    private void assertEquals(String title, float arg1, float arg2, float tol) {
        if (!(Math.abs(arg2 - arg1) < tol)) {
            throw new AssertionError(title + " Expected " + arg1 + " got " + arg2);
        }
    }
}

