/*
 * Decompiled with CFR 0.152.
 */
package beast.evolution.tree.coalescent;

import beast.core.CalculationNode;
import beast.core.Description;
import beast.core.Input;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.coalescent.IntervalList;
import beast.evolution.tree.coalescent.IntervalType;
import beast.util.HeapSort;
import java.util.ArrayList;
import java.util.List;

/**
 * Extracts a list of intervals from a {@link Tree}. Interval boundaries are
 * the node heights of the tree, visited in ascending order.
 *
 * <p>The interval data is computed lazily: every accessor first checks
 * {@link #intervalsKnown} and calls {@link #calculateIntervals()} when the
 * cached data has been invalidated.
 *
 * <p>NOTE(review): {@link #store()} and {@link #restore()} only cover
 * {@code intervals}, {@code lineageCounts} and {@code intervalCount};
 * {@code times}, {@code indices}, {@code lineagesAdded} and
 * {@code lineagesRemoved} are not stored or restored -- confirm that callers
 * do not rely on those after a restore.
 */
@Description(value="Extracts the intervals from a tree. Points in the intervals are defined by the heights of nodes in the tree.")
public class TreeIntervals extends CalculationNode implements IntervalList {
    /** The tree whose node heights define the intervals. */
    public final Input<Tree> treeInput = new Input<Tree>("tree", "tree for which to calculate the intervals", Input.Validate.REQUIRED);

    /** Interval lengths; only the first {@link #intervalCount} entries are written by {@link #calculateIntervals()}. */
    protected double[] intervals;
    /** Copy of {@link #intervals} made by {@link #store()}. */
    protected double[] storedIntervals;
    /** Node heights, indexed by position in {@code Tree.getNodesAsArray()}. */
    double[] times;
    /** Index permutation produced by {@code HeapSort.sort(times, indices)}; read as {@code times[indices[i]]}. */
    int[] indices;
    /** Lineage count recorded for each interval. */
    protected int[] lineageCounts;
    /** Copy of {@link #lineageCounts} made by {@link #store()}. */
    protected int[] storedLineageCounts;
    /** Nodes added as lineages, per interval index; entries are created lazily and may be {@code null}. */
    protected List<Node>[] lineagesAdded;
    /** Nodes removed as lineages, per interval index; entries are created lazily and may be {@code null}. */
    protected List<Node>[] lineagesRemoved;
    /** Number of valid entries in {@link #intervals} and {@link #lineageCounts}. */
    protected int intervalCount = 0;
    /** Value of {@link #intervalCount} captured by {@link #store()}. */
    protected int storedIntervalCount = 0;
    /** Whether the cached interval data is up to date. */
    protected boolean intervalsKnown = false;
    /**
     * Height tolerance used to merge nodes into a single event batch, see
     * {@link #calculateIntervals()}. A negative value (the default) means no
     * two nodes are ever merged.
     */
    protected double multifurcationLimit = -1.0;

    /** Creates an uninitialised instance; the tree is supplied through {@link #treeInput}. */
    public TreeIntervals() {
    }

    /**
     * Creates an instance and initialises it with the given tree via the
     * inherited {@code init} method.
     *
     * @param tree the tree to extract intervals from
     */
    public TreeIntervals(Tree tree) {
        this.init(tree);
    }

    /**
     * Runs one interval calculation so that the arrays used by
     * {@link #store()} and {@link #restore()} are allocated, then marks the
     * result as unknown so the first accessor recalculates.
     */
    @Override
    public void initAndValidate() {
        this.calculateIntervals();
        this.intervalsKnown = false;
    }

    /**
     * Invalidates the cached intervals and always reports that a
     * recalculation is required.
     *
     * @return always {@code true}
     */
    @Override
    protected boolean requiresRecalculation() {
        this.intervalsKnown = false;
        return true;
    }

    /**
     * Swaps the current intervals, lineage counts and interval count with
     * the stored ones, then delegates to the superclass.
     * {@link #intervalsKnown} is left unchanged.
     */
    @Override
    protected void restore() {
        double[] swappedIntervals = this.storedIntervals;
        this.storedIntervals = this.intervals;
        this.intervals = swappedIntervals;

        int[] swappedCounts = this.storedLineageCounts;
        this.storedLineageCounts = this.lineageCounts;
        this.lineageCounts = swappedCounts;

        int swappedIntervalCount = this.storedIntervalCount;
        this.storedIntervalCount = this.intervalCount;
        this.intervalCount = swappedIntervalCount;

        super.restore();
    }

    /**
     * Copies the current intervals, lineage counts and interval count into
     * their stored counterparts, then delegates to the superclass.
     */
    @Override
    protected void store() {
        System.arraycopy(this.lineageCounts, 0, this.storedLineageCounts, 0, this.lineageCounts.length);
        System.arraycopy(this.intervals, 0, this.storedIntervals, 0, this.intervals.length);
        this.storedIntervalCount = this.intervalCount;
        super.store();
    }

    /** Forces the next accessor call to recalculate the intervals. */
    public void setIntervalsUnknown() {
        this.intervalsKnown = false;
    }

    /**
     * Sets the height tolerance used when batching nodes into one event.
     * The cached intervals are invalidated only if the value changes.
     *
     * @param d the new multifurcation limit
     */
    public void setMultifurcationLimit(double d) {
        if (this.multifurcationLimit != d) {
            this.multifurcationLimit = d;
            this.intervalsKnown = false;
        }
    }

    /**
     * @return the internal node count of the input tree
     */
    @Override
    public int getSampleCount() {
        return this.treeInput.get().getInternalNodeCount();
    }

    /**
     * @return the number of intervals produced by the latest calculation
     */
    @Override
    public int getIntervalCount() {
        this.ensureIntervalsKnown();
        return this.intervalCount;
    }

    /**
     * @param n interval index
     * @return the length of interval {@code n}
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, intervalCount)}
     */
    @Override
    public double getInterval(int n) {
        this.ensureIntervalsKnown();
        this.checkIntervalIndex(n);
        return this.intervals[n];
    }

    /**
     * Copies the whole backing interval array into {@code dArray}. Only the
     * first {@link #getIntervalCount()} entries are written by
     * {@link #calculateIntervals()}.
     *
     * @param dArray destination array, or {@code null} to allocate one of
     *               the backing array's length
     * @return the destination array
     */
    public double[] getIntervals(double[] dArray) {
        this.ensureIntervalsKnown();
        if (dArray == null) {
            dArray = new double[this.intervals.length];
        }
        System.arraycopy(this.intervals, 0, dArray, 0, this.intervals.length);
        return dArray;
    }

    /**
     * Fills {@code dArray} with the cumulative interval duration at each
     * coalescent event, one entry per event, in interval order.
     *
     * @param dArray destination array, or {@code null} to allocate one of
     *               length {@link #getSampleCount()}
     * @return the destination array
     */
    public double[] getCoalescentTimes(double[] dArray) {
        this.ensureIntervalsKnown();
        if (dArray == null) {
            dArray = new double[this.getSampleCount()];
        }
        double elapsed = 0.0;
        int next = 0;
        // Iterate over the valid intervals only: the backing array can be
        // longer than intervalCount, and getCoalescentEvents rejects any
        // index at or beyond intervalCount.
        for (int i = 0; i < this.intervalCount; ++i) {
            elapsed += this.intervals[i];
            int events = this.getCoalescentEvents(i);
            for (int j = 0; j < events; ++j) {
                dArray[next] = elapsed;
                ++next;
            }
        }
        return dArray;
    }

    /**
     * @param n interval index
     * @return the lineage count recorded for interval {@code n}
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, intervalCount)}
     */
    @Override
    public int getLineageCount(int n) {
        this.ensureIntervalsKnown();
        this.checkIntervalIndex(n);
        return this.lineageCounts[n];
    }

    /**
     * Returns the drop in lineage count between interval {@code n} and the
     * next one; for the last interval, its lineage count minus one. The
     * result is negative when the lineage count increases.
     *
     * @param n interval index
     * @return the lineage-count difference described above
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, intervalCount)}
     */
    @Override
    public int getCoalescentEvents(int n) {
        this.ensureIntervalsKnown();
        this.checkIntervalIndex(n);
        if (n < this.intervalCount - 1) {
            return this.lineageCounts[n] - this.lineageCounts[n + 1];
        }
        return this.lineageCounts[n] - 1;
    }

    /**
     * Classifies interval {@code n} by the sign of
     * {@link #getCoalescentEvents(int)}: positive is
     * {@code COALESCENT}, negative is {@code SAMPLE}, zero is
     * {@code NOTHING}.
     *
     * @param n interval index
     * @return the interval type
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, intervalCount)}
     */
    @Override
    public IntervalType getIntervalType(int n) {
        this.ensureIntervalsKnown();
        this.checkIntervalIndex(n);
        int events = this.getCoalescentEvents(n);
        if (events > 0) {
            return IntervalType.COALESCENT;
        }
        if (events < 0) {
            return IntervalType.SAMPLE;
        }
        return IntervalType.NOTHING;
    }

    /**
     * @return the sum of all valid interval lengths
     */
    @Override
    public double getTotalDuration() {
        this.ensureIntervalsKnown();
        double total = 0.0;
        for (int i = 0; i < this.intervalCount; ++i) {
            total += this.intervals[i];
        }
        return total;
    }

    /**
     * @return {@code false} if any interval has more than one coalescent
     *         event, {@code true} otherwise
     */
    @Override
    public boolean isBinaryCoalescent() {
        this.ensureIntervalsKnown();
        for (int i = 0; i < this.intervalCount; ++i) {
            if (this.getCoalescentEvents(i) > 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return {@code true} if every interval has at least one coalescent
     *         event
     */
    @Override
    public boolean isCoalescentOnly() {
        this.ensureIntervalsKnown();
        for (int i = 0; i < this.intervalCount; ++i) {
            if (this.getCoalescentEvents(i) < 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Recomputes the interval data from the input tree.
     *
     * <p>Nodes are visited in the order given by {@link #indices}
     * (presumably ascending height -- depends on {@code HeapSort.sort}).
     * Consecutive nodes are grouped into one batch while their height
     * differs from the batch's first node by at most
     * {@link #multifurcationLimit}. A limit of exactly {@code 0.0}
     * additionally ends the batch right after any internal node. Each batch
     * yields up to two intervals: one if it added leaf lineages, and one if
     * it removed lineages through internal nodes.
     */
    protected void calculateIntervals() {
        Tree tree = this.treeInput.get();
        int nodeCount = tree.getNodeCount();

        this.times = new double[nodeCount];
        int[] childCounts = new int[nodeCount];
        TreeIntervals.collectTimes(tree, this.times, childCounts);

        this.indices = new int[nodeCount];
        HeapSort.sort(this.times, this.indices);

        if (this.intervals == null || this.intervals.length != nodeCount) {
            // First run or tree size changed: (re)allocate all per-node arrays.
            this.intervals = new double[nodeCount];
            this.lineageCounts = new int[nodeCount];
            this.lineagesAdded = TreeIntervals.newNodeListArray(nodeCount);
            this.lineagesRemoved = TreeIntervals.newNodeListArray(nodeCount);
            this.storedIntervals = new double[nodeCount];
            this.storedLineageCounts = new int[nodeCount];
        } else {
            // Same size: reuse the arrays, drop lineage records of the previous run.
            TreeIntervals.clearAll(this.lineagesAdded);
            TreeIntervals.clearAll(this.lineagesRemoved);
        }

        double start = this.times[this.indices[0]];
        int numLines = 0;
        int nodeNo = 0;
        this.intervalCount = 0;

        while (nodeNo < nodeCount) {
            int removedCount = 0;
            int addedCount = 0;
            // The batch's reference height stays fixed at its first node.
            double finish = this.times[this.indices[nodeNo]];

            do {
                int nodeIndex = this.indices[nodeNo];
                int childCount = childCounts[nodeIndex];
                ++nodeNo;

                if (childCount == 0) {
                    // Leaf: a new lineage appears.
                    this.addLineage(this.intervalCount, tree.getNode(nodeIndex));
                    ++addedCount;
                } else {
                    // Internal node: its children are replaced by the node itself.
                    removedCount += childCount - 1;
                    Node parent = tree.getNode(nodeIndex);
                    for (int i = 0; i < childCount; ++i) {
                        Node child = i == 0 ? parent.getLeft() : parent.getRight();
                        this.removeLineage(this.intervalCount, child);
                    }
                    this.addLineage(this.intervalCount, parent);
                    if (this.multifurcationLimit == 0.0) {
                        break;
                    }
                }
            } while (nodeNo < nodeCount
                    && Math.abs(this.times[this.indices[nodeNo]] - finish) <= this.multifurcationLimit);

            if (addedCount > 0) {
                // The very first sampling interval is recorded only if it is
                // longer than the multifurcation limit.
                if (this.intervalCount > 0 || finish - start > this.multifurcationLimit) {
                    this.intervals[this.intervalCount] = finish - start;
                    this.lineageCounts[this.intervalCount] = numLines;
                    ++this.intervalCount;
                }
                start = finish;
            }
            numLines += addedCount;

            if (removedCount > 0) {
                this.intervals[this.intervalCount] = finish - start;
                this.lineageCounts[this.intervalCount] = numLines;
                ++this.intervalCount;
                start = finish;
            }
            numLines -= removedCount;
        }
        this.intervalsKnown = true;
    }

    /**
     * Returns {@code times[indices[n]]}, the height of the {@code n}-th
     * node in {@link #indices} order.
     *
     * <p>NOTE(review): {@code n} indexes the sorted nodes, not the interval
     * arrays, and is not range-checked against the interval count -- confirm
     * intended usage with callers.
     *
     * @param n position in the sorted node order
     * @return the node height at that position
     */
    public double getIntervalTime(int n) {
        this.ensureIntervalsKnown();
        return this.times[this.indices[n]];
    }

    /**
     * Records {@code node} as a lineage added at interval index {@code n},
     * creating the per-interval list on first use.
     *
     * @param n    interval index
     * @param node the node to record
     */
    protected void addLineage(int n, Node node) {
        if (this.lineagesAdded[n] == null) {
            this.lineagesAdded[n] = new ArrayList<Node>();
        }
        this.lineagesAdded[n].add(node);
    }

    /**
     * Records {@code node} as a lineage removed at interval index {@code n},
     * creating the per-interval list on first use.
     *
     * @param n    interval index
     * @param node the node to record
     */
    protected void removeLineage(int n, Node node) {
        if (this.lineagesRemoved[n] == null) {
            this.lineagesRemoved[n] = new ArrayList<Node>();
        }
        this.lineagesRemoved[n].add(node);
    }

    /**
     * @return the result of {@code IntervalList.Utils.getDelta(this)}
     */
    public double getDelta() {
        return IntervalList.Utils.getDelta(this);
    }

    /**
     * Fills {@code dArray} with the height of every node of {@code tree} and
     * {@code nArray} with its child count: 0 for a leaf, 2 otherwise. Both
     * arrays are indexed by position in {@code tree.getNodesAsArray()}.
     *
     * @param tree   the tree to read
     * @param dArray receives node heights; must hold at least as many
     *               entries as the tree has nodes
     * @param nArray receives child counts; same length requirement
     */
    protected static void collectTimes(Tree tree, double[] dArray, int[] nArray) {
        Node[] nodes = tree.getNodesAsArray();
        for (int i = 0; i < nodes.length; ++i) {
            Node node = nodes[i];
            dArray[i] = node.getHeight();
            nArray[i] = node.isLeaf() ? 0 : 2;
        }
    }

    /** Recalculates the intervals if the cached data has been invalidated. */
    private void ensureIntervalsKnown() {
        if (!this.intervalsKnown) {
            this.calculateIntervals();
        }
    }

    /**
     * Validates an interval index against the current interval count.
     *
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, intervalCount)}
     */
    private void checkIntervalIndex(int n) {
        if (n < 0 || n >= this.intervalCount) {
            throw new IllegalArgumentException(
                    "interval index " + n + " out of range [0, " + this.intervalCount + ")");
        }
    }

    /** Allocates an array of node lists; generic array creation is inherently unchecked. */
    @SuppressWarnings("unchecked")
    private static List<Node>[] newNodeListArray(int size) {
        return new List[size];
    }

    /** Empties every non-null list in {@code lists}. */
    private static void clearAll(List<Node>[] lists) {
        for (List<Node> list : lists) {
            if (list != null) {
                list.clear();
            }
        }
    }
}

