/*
 * Decompiled with CFR 0.152.
 */
package beast.util;

import beast.core.Description;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.PriorityQueue;

@Description(value="Create initial beast.tree by hierarchical clustering, either through one of the classic link methods or by neighbor joining. The following link methods are supported: <br/>o single link, <br/>o complete link, <br/>o UPGMA=average link, <br/>o mean link, <br/>o centroid, <br/>o Ward and <br/>o adjusted complete link <br/>o neighborjoining <br/>o neighborjoining2 - corrects tree for tip data, unlike plain neighborjoining")
public class ClusterTree
extends Tree
implements StateNodeInitialiser {
    // Lower bound used by NodeX.setHeight/setLength for the values passed to them, and the
    // minimum parent-above-child height gap enforced by NodeX.toNode.
    double EPSILON = 1.0E-10;
    // Selects the clustering algorithm; defaults to Type.average.
    public final Input<Type> clusterTypeInput = new Input<Type>("clusterType", "type of clustering algorithm used for generating initial beast.tree. Should be one of " + Arrays.toString((Object[])Type.values()) + " (default " + (Object)((Object)Type.average) + ")", Type.average, Type.values());
    // Alignment supplying taxon names and the pattern data read by the centroid/Ward calculations.
    public final Input<Alignment> dataInput = new Input("taxa", "alignment data used for calculating distances for clustering");
    // Optional pairwise distance; initAndValidate falls back to JukesCantorDistance when unset.
    public final Input<Distance> distanceInput = new Input("distance", "method for calculating distance between two sequences (default Jukes Cantor)");
    // Internal node heights are divided by this parameter's value at the end of initAndValidate.
    public final Input<RealParameter> clockRateInput = new Input<RealParameter>("clock.rate", "the clock rate parameter, used to divide all divergence times by, to convert from substitutions to times. (default 1.0)", new RealParameter(new Double[]{1.0}));
    // When true, merge() treats its two distance arguments as branch lengths (NodeX.setLength)
    // rather than heights (NodeX.setHeight); switched on for the neighbor-joining types.
    protected boolean distanceIsBranchLength = false;
    // Distance implementation resolved in initAndValidate and used by distance(int, int).
    Distance distance;
    // Taxon labels, taken from the alignment or, failing that, from the taxon set.
    List<String> taxaNames;
    // Effective clustering type; initAndValidate copies it from clusterTypeInput and maps upgma to average.
    Type linkType = Type.single;

    /**
     * Builds the initial tree for this state node.
     *
     * <p>Taxon names are read from the alignment input or, if that is absent, from the taxon
     * set. When the {@code beast.resume} system property is true and this tree (or the tree it
     * initialises) is estimated, only a placeholder ladder tree is installed. Otherwise the
     * taxa are clustered with the configured link type, tip heights are reset or adjusted from
     * the time trait set, internal node heights are divided by the clock rate, and the result
     * is pushed to the initialised state node.
     *
     * @throws RuntimeException if neither the alignment nor the taxon set input is specified
     */
    @Override
    public void initAndValidate() {
        RealParameter clockRate = this.clockRateInput.get();

        // Resolve the taxon labels: alignment first, taxon set as the fallback.
        if (this.dataInput.get() == null) {
            if (this.m_taxonset.get() == null) {
                throw new RuntimeException("At least one of taxa and taxonset input needs to be specified");
            }
            this.taxaNames = ((TaxonSet)this.m_taxonset.get()).asStringList();
        } else {
            this.taxaNames = this.dataInput.get().getTaxaNames();
        }

        // On resume, skip the clustering and install a placeholder topology instead.
        if (Boolean.parseBoolean(System.getProperty("beast.resume"))) {
            boolean estimated = (Boolean)this.isEstimatedInput.get();
            if (!estimated && this.m_initial.get() != null) {
                estimated = (Boolean)((Tree)this.m_initial.get()).isEstimatedInput.get();
            }
            if (estimated) {
                this.installPlaceholderLadder();
                super.initAndValidate();
                return;
            }
        }

        // Resolve the distance measure, defaulting to Jukes-Cantor.
        Distance configured = this.distanceInput.get();
        this.distance = configured != null ? configured : new JukesCantorDistance();
        if (this.distance instanceof Distance.Base) {
            // The alignment is handed over as-is, even when it is null.
            ((Distance.Base)this.distance).setPatterns(this.dataInput.get());
        }

        // upgma is just another name for average linkage.
        Type requested = this.clusterTypeInput.get();
        this.linkType = requested == Type.upgma ? Type.average : requested;
        if (this.linkType == Type.neighborjoining || this.linkType == Type.neighborjoining2) {
            this.distanceIsBranchLength = true;
        }

        Node clusterRoot = this.buildClusterer();
        this.setRoot(clusterRoot);
        clusterRoot.labelInternalNodes((this.getNodeCount() + 1) / 2);
        super.initAndValidate();

        if (this.linkType == Type.neighborjoining2) {
            // neighborjoining2 pins every tip at height zero and re-initialises the tree.
            Node[] allNodes = this.getNodesAsArray();
            for (int tip = 0; tip < this.getLeafNodeCount(); ++tip) {
                allNodes[tip].setHeight(0.0);
            }
            super.initAndValidate();
        }

        if (this.m_initial.get() == null) {
            this.processTraits((List)this.m_traitList.get());
        } else {
            this.processTraits(((Tree)this.m_initial.get()).m_traitList.get());
        }

        if (this.timeTraitSet == null) {
            for (int tip = 0; tip < this.getLeafNodeCount(); ++tip) {
                this.getNode(tip).setHeight(0.0);
            }
        } else {
            this.adjustTreeNodeHeights(clusterRoot);
        }

        // Divide internal node heights by the clock rate.
        for (Node internal : this.getInternalNodes()) {
            double unscaled = internal.getHeight();
            internal.setHeight(unscaled / clockRate.getValue());
        }
        this.initStateNodes();
    }

    /**
     * Installs a ladder-shaped placeholder tree over {@link #taxaNames}: tips sit at height 0
     * and the i-th join sits at height i. Also sets the root and the node counters.
     */
    private void installPlaceholderLadder() {
        int tipCount = this.taxaNames.size();
        Node top = this.newNode();
        top.setNr(0);
        top.setID(this.taxaNames.get(0));
        top.setHeight(0.0);
        for (int tip = 1; tip < tipCount; ++tip) {
            Node leaf = this.newNode();
            leaf.setNr(tip);
            leaf.setID(this.taxaNames.get(tip));
            leaf.setHeight(0.0);

            Node join = this.newNode();
            join.setNr(tipCount + tip - 1);
            join.setHeight(tip);

            top.setParent(join);
            join.setLeft(top);
            leaf.setParent(join);
            join.setRight(leaf);
            top = join;
        }
        this.root = top;
        this.leafNodeCount = tipCount;
        this.nodeCount = this.leafNodeCount * 2 - 1;
        this.internalNodeCount = this.leafNodeCount - 1;
    }

    /**
     * Pairwise distance between two taxa as reported by the configured {@link Distance}.
     *
     * @param n  index of the first taxon
     * @param n2 index of the second taxon
     * @return the value of {@code pairwiseDistance(n, n2)} on the resolved distance object
     */
    double distance(int n, int n2) {
        final Distance metric = this.distance;
        return metric.pairwiseDistance(n, n2);
    }

    /**
     * Pattern-weighted absolute difference between two per-pattern vectors, divided by the
     * alignment's site count.
     *
     * @param dArray  first vector, indexed by pattern
     * @param dArray2 second vector, indexed by pattern
     * @return sum over patterns of weight * |dArray[i] - dArray2[i]|, divided by the site count
     */
    double distance(double[] dArray, double[] dArray2) {
        final Alignment alignment = this.dataInput.get();
        final int patternCount = alignment.getPatternCount();
        double weightedSum = 0.0;
        int pattern = 0;
        while (pattern < patternCount) {
            double gap = Math.abs(dArray[pattern] - dArray2[pattern]);
            weightedSum += (double)alignment.getPatternWeight(pattern) * gap;
            ++pattern;
        }
        return weightedSum / (double)alignment.getSiteCount();
    }

    /**
     * Clusters all taxa and converts the resulting hierarchy into a {@link Node} tree.
     *
     * <p>A single taxon yields a lone node with number 0 and height 1.0. Otherwise every taxon
     * starts in its own cluster and the clusters are joined by neighbor joining or by link
     * clustering, depending on {@link #linkType}.
     *
     * @return the root of the clustered tree, or {@code null} if no non-empty cluster is left
     */
    public Node buildClusterer() {
        final int taxonCount = this.taxaNames.size();
        if (taxonCount == 1) {
            Node solitary = this.newNode();
            solitary.setHeight(1.0);
            solitary.setNr(0);
            return solitary;
        }

        // One singleton cluster per taxon; merged-away clusters end up empty.
        @SuppressWarnings("unchecked")
        List<Integer>[] clusters = new ArrayList[taxonCount];
        for (int taxon = 0; taxon < taxonCount; ++taxon) {
            List<Integer> members = new ArrayList<Integer>();
            members.add(taxon);
            clusters[taxon] = members;
        }

        NodeX[] subtrees = new NodeX[taxonCount];
        boolean useNeighborJoining = this.linkType == Type.neighborjoining || this.linkType == Type.neighborjoining2;
        if (useNeighborJoining) {
            this.neighborJoining(taxonCount, clusters, subtrees);
        } else {
            this.doLinkClustering(taxonCount, clusters, subtrees);
        }

        // The first cluster that still has members carries the finished hierarchy.
        for (int slot = 0; slot < taxonCount; ++slot) {
            if (clusters[slot].size() > 0) {
                return subtrees[slot].toNode();
            }
        }
        return null;
    }

    /**
     * Joins the clusters in {@code listArray} pairwise, recording each join in
     * {@code nodeXArray} via {@link #merge}. On each pass the pair (i, j) minimising
     * {@code d(i,j) - r(i) - r(j)} is joined, where {@code r} is a cluster's distance row sum
     * divided by (remaining clusters - 2).
     *
     * @param n          number of clusters at entry; used as the "clusters remaining" counter
     * @param listArray  member lists per cluster index; a merged-away cluster has an empty list
     * @param nodeXArray per-cluster subtree built so far ({@code null} while still a single taxon)
     */
    void neighborJoining(int n, List<Integer>[] listArray, NodeX[] nodeXArray) {
        double d;
        int n2;
        int n3 = this.taxaNames.size();
        // dArray: symmetric matrix of pairwise cluster distances.
        double[][] dArray = new double[n][n];
        for (int i = 0; i < n; ++i) {
            dArray[i][i] = 0.0;
            for (int j = i + 1; j < n; ++j) {
                dArray[i][j] = this.getDistance0(listArray[i], listArray[j]);
                dArray[j][i] = dArray[i][j];
            }
        }
        // dArray2: row sums of dArray; dArray3: row sums divided by (n - 2);
        // nArray: "next index" links used to walk the clusters that are still active.
        double[] dArray2 = new double[n3];
        double[] dArray3 = new double[n3];
        int[] nArray = new int[n3];
        for (n2 = 0; n2 < n3; ++n2) {
            double d2 = 0.0;
            for (int i = 0; i < n3; ++i) {
                d2 += dArray[n2][i];
            }
            dArray2[n2] = d2;
            // NOTE(review): for n == 2 this divides by zero, but dArray3 is only read inside
            // the while loop below, which is skipped in that case.
            dArray3[n2] = d2 / (double)(n - 2);
            nArray[n2] = n2 + 1;
        }
        while (n > 2) {
            // Search the active pairs (n5 < n6) for the smallest d(i,j) - r(i) - r(j).
            n2 = -1;
            int n4 = -1;
            d = Double.MAX_VALUE;
            int n5 = 0;
            while (n5 < n3) {
                double d3 = dArray3[n5];
                double[] dArray4 = dArray[n5];
                int n6 = nArray[n5];
                while (n6 < n3) {
                    double d4 = dArray3[n6];
                    double d5 = dArray4[n6] - d3 - d4;
                    if (d5 < d) {
                        n2 = n5;
                        n4 = n6;
                        d = d5;
                    }
                    n6 = nArray[n6];
                }
                n5 = nArray[n5];
            }
            // Branch lengths from the new join to n2 (d9) and to n4 (d10).
            double d6 = dArray[n2][n4];
            double d7 = dArray3[n2];
            double d8 = dArray3[n4];
            double d9 = 0.5 * d6 + 0.5 * (d7 - d8);
            double d10 = 0.5 * d6 + 0.5 * (d8 - d7);
            if (--n > 2) {
                int n7;
                double d11 = 0.0;
                double d12 = dArray[n2][n4];
                double[] dArray5 = dArray[n2];
                double[] dArray6 = dArray[n4];
                // Row/column n2 is reused for the joined cluster: its distance to every other
                // active cluster k becomes (d(n2,k) + d(n4,k) - d(n2,n4)) / 2.
                for (n7 = 0; n7 < n3; ++n7) {
                    if (n7 == n2 || n7 == n4 || listArray[n7].size() == 0) {
                        dArray5[n7] = 0.0;
                        continue;
                    }
                    double d13 = dArray5[n7];
                    double d14 = dArray6[n7];
                    double d15 = (d13 + d14 - d12) / 2.0;
                    d11 += d15;
                    int n8 = n7;
                    // Adjust k's row sum: drop the two old distances, add the new one.
                    dArray2[n8] = dArray2[n8] + (d15 - d13 - d14);
                    dArray3[n7] = dArray2[n7] / (double)(n - 2);
                    dArray5[n7] = d15;
                    dArray[n7][n2] = d15;
                }
                dArray2[n2] = d11;
                dArray3[n2] = d11 / (double)(n - 2);
                dArray2[n4] = 0.0;
                this.merge(n2, n4, d9, d10, listArray, nodeXArray);
                // Unlink n4: the closest preceding non-empty cluster now points past it.
                n7 = n4;
                while (listArray[n7].size() == 0) {
                    --n7;
                }
                nArray[n7] = nArray[n4];
                continue;
            }
            // Only two clusters remain after this join; no distance update is needed.
            this.merge(n2, n4, d9, d10, listArray, nodeXArray);
            break;
        }
        // Join what is left: each non-empty cluster is merged with the next non-empty one.
        block9: for (n2 = 0; n2 < n3; ++n2) {
            if (listArray[n2].size() <= 0) continue;
            for (int i = n2 + 1; i < n3; ++i) {
                if (listArray[i].size() <= 0) continue;
                d = dArray[n2][i];
                // A single-taxon cluster gets the full distance as its branch; otherwise it
                // is split evenly between the two sides.
                if (listArray[n2].size() == 1) {
                    this.merge(n2, i, d, 0.0, listArray, nodeXArray);
                    continue block9;
                }
                if (listArray[i].size() == 1) {
                    this.merge(n2, i, 0.0, d, listArray, nodeXArray);
                    continue block9;
                }
                this.merge(n2, i, d / 2.0, d / 2.0, listArray, nodeXArray);
                continue block9;
            }
        }
    }

    /**
     * Agglomerative link clustering driven by a priority queue of candidate cluster pairs.
     *
     * <p>Each queue entry remembers the sizes both clusters had when it was created; an entry
     * whose recorded sizes no longer match the current clusters is stale and is skipped.
     *
     * @param n          number of clusters at entry
     * @param listArray  member lists per cluster index; a merged-away cluster has an empty list
     * @param nodeXArray per-cluster subtree built so far, updated through {@link #merge}
     */
    void doLinkClustering(int n, List<Integer>[] listArray, NodeX[] nodeXArray) {
        final int slotCount = this.taxaNames.size();
        PriorityQueue<Tuple> candidates = new PriorityQueue<Tuple>(n * n / 2, new TupleComparator());

        // Seed the matrix and the queue with every pair of initial clusters.
        double[][] pairwise = new double[n][n];
        for (int row = 0; row < n; ++row) {
            for (int col = row + 1; col < n; ++col) {
                double gap = this.getDistance0(listArray[row], listArray[col]);
                pairwise[row][col] = gap;
                pairwise[col][row] = gap;
                candidates.add(new Tuple(gap, row, col, 1, 1));
            }
        }

        for (int remaining = n; remaining > 1; --remaining) {
            // Pop entries until one refers to two clusters that are unchanged since it was queued.
            Tuple best;
            do {
                best = candidates.poll();
            } while (best != null
                    && !(listArray[best.m_iCluster1].size() == best.m_nClusterSize1
                            && listArray[best.m_iCluster2].size() == best.m_nClusterSize2));

            int kept = best.m_iCluster1;
            int absorbed = best.m_iCluster2;
            double half = best.m_fDist / 2.0;
            this.merge(kept, absorbed, half, half, listArray, nodeXArray);

            // Queue the distance from the grown cluster to every other live cluster.
            for (int other = 0; other < slotCount; ++other) {
                if (other != kept && listArray[other].size() != 0) {
                    int low = Math.min(kept, other);
                    int high = Math.max(kept, other);
                    double gap = this.getDistance(pairwise, listArray[low], listArray[high]);
                    candidates.add(new Tuple(gap, low, high, listArray[low].size(), listArray[high].size()));
                }
            }
        }
    }

    /**
     * Merges two clusters and records the join as a new {@link NodeX}.
     *
     * <p>The cluster with the lower index absorbs the other one, whose member list is left
     * empty. The new node takes the existing subtrees of both clusters as children, or the
     * cluster index itself where a cluster has no subtree yet. It replaces the subtree stored
     * for the surviving cluster.
     *
     * @param n          index of the first cluster
     * @param n2         index of the second cluster
     * @param d          height or branch length associated with cluster {@code n}
     * @param d2         height or branch length associated with cluster {@code n2}
     * @param listArray  member lists per cluster index
     * @param nodeXArray per-cluster subtree built so far
     */
    void merge(int n, int n2, double d, double d2, List<Integer>[] listArray, NodeX[] nodeXArray) {
        // Normalise so that n < n2; the distances travel with their clusters.
        if (n > n2) {
            int n3 = n;
            n = n2;
            n2 = n3;
            double d3 = d;
            d = d2;
            d2 = d3;
        }
        listArray[n].addAll(listArray[n2]);
        // clear() empties the list in linear time; removeAll(self) did a quadratic self-scan.
        listArray[n2].clear();
        NodeX nodeX = new NodeX();
        if (nodeXArray[n] == null) {
            nodeX.m_iLeftInstance = n;
        } else {
            nodeX.m_left = nodeXArray[n];
            nodeXArray[n].m_parent = nodeX;
        }
        if (nodeXArray[n2] == null) {
            nodeX.m_iRightInstance = n2;
        } else {
            nodeX.m_right = nodeXArray[n2];
            nodeXArray[n2].m_parent = nodeX;
        }
        // Neighbor joining supplies branch lengths; the link methods supply heights.
        if (this.distanceIsBranchLength) {
            nodeX.setLength(d, d2);
        } else {
            nodeX.setHeight(d, d2);
        }
        nodeXArray[n] = nodeX;
    }

    /**
     * Initial distance between two clusters, used to seed the distance matrix.
     *
     * <p>For every link type except Ward (and any type not listed) this is the pairwise
     * distance between the first member of each cluster. Ward uses the change in size-weighted
     * {@link #calcESS} caused by pooling the two clusters.
     *
     * @param list  members of the first cluster
     * @param list2 members of the second cluster
     * @return the initial distance, or {@code Double.MAX_VALUE} for an unlisted link type
     */
    double getDistance0(List<Integer> list, List<Integer> list2) {
        switch (this.linkType) {
            case single:
            case neighborjoining:
            case neighborjoining2:
            case centroid:
            case complete:
            case adjcomplete:
            case average:
            case mean:
                return this.distance(list.get(0), list2.get(0));
            case ward: {
                double essFirst = this.calcESS(list);
                double essSecond = this.calcESS(list2);
                List<Integer> pooled = new ArrayList<Integer>(list);
                pooled.addAll(list2);
                double essPooled = this.calcESS(pooled);
                return essPooled * (double)pooled.size() - essFirst * (double)list.size() - essSecond * (double)list2.size();
            }
            default:
                // Nothing is computed for other types (e.g. upgma); the sentinel is returned.
                return Double.MAX_VALUE;
        }
    }

    /**
     * Distance between two clusters under the active link type.
     *
     * @param dArray matrix of pairwise distances between individual taxa
     * @param list   members of the first cluster
     * @param list2  members of the second cluster
     * @return the linkage distance, or {@code Double.MAX_VALUE} for link types not handled here
     */
    double getDistance(double[][] dArray, List<Integer> list, List<Integer> list2) {
        switch (this.linkType) {
            case single:
                return this.closestCrossPair(dArray, list, list2);
            case complete:
                return this.farthestCrossPair(dArray, list, list2);
            case adjcomplete: {
                // Complete link, reduced by the largest distance found inside either cluster.
                double between = this.farthestCrossPair(dArray, list, list2);
                double within = this.widestInternalPair(dArray, list, 0.0);
                within = this.widestInternalPair(dArray, list2, within);
                return between - within;
            }
            case average: {
                double total = 0.0;
                for (int first : list) {
                    for (int second : list2) {
                        total += dArray[first][second];
                    }
                }
                return total / (double)(list.size() * list2.size());
            }
            case mean: {
                // Mean distance over all pairs within the pooled cluster.
                List<Integer> pooled = new ArrayList<Integer>(list);
                pooled.addAll(list2);
                int count = pooled.size();
                double total = 0.0;
                for (int i = 0; i < count; ++i) {
                    int first = pooled.get(i);
                    for (int j = i + 1; j < count; ++j) {
                        int second = pooled.get(j);
                        total += dArray[first][second];
                    }
                }
                return total / ((double)count * ((double)count - 1.0) / 2.0);
            }
            case centroid: {
                int patternCount = this.dataInput.get().getPatternCount();
                double[] centreFirst = this.averagePattern(list, patternCount);
                double[] centreSecond = this.averagePattern(list2, patternCount);
                return this.distance(centreFirst, centreSecond);
            }
            case ward: {
                double essFirst = this.calcESS(list);
                double essSecond = this.calcESS(list2);
                List<Integer> pooled = new ArrayList<Integer>(list);
                pooled.addAll(list2);
                double essPooled = this.calcESS(pooled);
                return essPooled * (double)pooled.size() - essFirst * (double)list.size() - essSecond * (double)list2.size();
            }
            default:
                return Double.MAX_VALUE;
        }
    }

    /** Smallest matrix entry over all (member of list, member of list2) pairs. */
    private double closestCrossPair(double[][] dArray, List<Integer> list, List<Integer> list2) {
        double smallest = Double.MAX_VALUE;
        for (int first : list) {
            for (int second : list2) {
                double candidate = dArray[first][second];
                if (smallest > candidate) {
                    smallest = candidate;
                }
            }
        }
        return smallest;
    }

    /** Largest matrix entry over all (member of list, member of list2) pairs, starting from 0. */
    private double farthestCrossPair(double[][] dArray, List<Integer> list, List<Integer> list2) {
        double largest = 0.0;
        for (int first : list) {
            for (int second : list2) {
                double candidate = dArray[first][second];
                if (largest < candidate) {
                    largest = candidate;
                }
            }
        }
        return largest;
    }

    /** Raises {@code floor} to the largest matrix entry between two distinct members of one cluster. */
    private double widestInternalPair(double[][] dArray, List<Integer> members, double floor) {
        double largest = floor;
        for (int i = 0; i < members.size(); ++i) {
            int first = members.get(i);
            for (int j = i + 1; j < members.size(); ++j) {
                int second = members.get(j);
                double candidate = dArray[first][second];
                if (largest < candidate) {
                    largest = candidate;
                }
            }
        }
        return largest;
    }

    /** Per-pattern average of the alignment's pattern values over the given cluster members. */
    private double[] averagePattern(List<Integer> members, int patternCount) {
        double[] centre = new double[patternCount];
        for (int i = 0; i < members.size(); ++i) {
            int taxon = members.get(i);
            for (int pattern = 0; pattern < patternCount; ++pattern) {
                centre[pattern] = centre[pattern] + (double)this.dataInput.get().getPattern(taxon, pattern);
            }
        }
        for (int pattern = 0; pattern < patternCount; ++pattern) {
            centre[pattern] = centre[pattern] / (double)members.size();
        }
        return centre;
    }

    /**
     * Average {@link #distance(double[], double[])} between a cluster's per-pattern centroid
     * and each of its members.
     *
     * @param list indices of the taxa in the cluster
     * @return the summed member-to-centroid distances divided by the cluster size
     */
    double calcESS(List<Integer> list) {
        final int patternCount = this.dataInput.get().getPatternCount();
        final int memberCount = list.size();

        // Centroid: per-pattern average of the members' pattern values.
        double[] centroid = new double[patternCount];
        for (int member = 0; member < memberCount; ++member) {
            int taxon = list.get(member);
            for (int pattern = 0; pattern < patternCount; ++pattern) {
                centroid[pattern] += (double)this.dataInput.get().getPattern(taxon, pattern);
            }
        }
        for (int pattern = 0; pattern < patternCount; ++pattern) {
            centroid[pattern] /= (double)memberCount;
        }

        // Accumulate each member's distance to the centroid.
        double total = 0.0;
        for (int member = 0; member < memberCount; ++member) {
            int taxon = list.get(member);
            double[] profile = new double[patternCount];
            for (int pattern = 0; pattern < patternCount; ++pattern) {
                profile[pattern] = (double)this.dataInput.get().getPattern(taxon, pattern);
            }
            total += this.distance(centroid, profile);
        }
        return total / (double)memberCount;
    }

    /**
     * Copies this tree into the tree supplied through {@code m_initial}, if one is set,
     * using {@code assignFromWithoutID}.
     */
    @Override
    public void initStateNodes() {
        Tree target = (Tree)this.m_initial.get();
        if (target == null) {
            return;
        }
        target.assignFromWithoutID(this);
    }

    /**
     * Adds the tree supplied through {@code m_initial} to the given list, if one is set.
     *
     * @param list receives the state node initialised by this object
     */
    @Override
    public void getInitialisedStateNodes(List<StateNode> list) {
        StateNode initialised = (StateNode)this.m_initial.get();
        if (initialised == null) {
            return;
        }
        list.add(initialised);
    }

    /**
     * Orders {@link Tuple}s by ascending distance so the priority queue yields the closest
     * candidate pair first.
     *
     * <p>Uses {@link Double#compare(double, double)}, which gives a consistent total order
     * even when a distance is NaN (NaN sorts last). A hand-written {@code <}/{@code ==} test
     * reports "greater" in both directions for NaN, which breaks the {@link Comparator}
     * contract.
     */
    class TupleComparator
    implements Comparator<Tuple> {
        TupleComparator() {
        }

        /**
         * @param tuple  first candidate pair
         * @param tuple2 second candidate pair
         * @return negative, zero or positive as the first distance is smaller than, equal to
         *         or larger than the second
         */
        @Override
        public int compare(Tuple tuple, Tuple tuple2) {
            return Double.compare(tuple.m_fDist, tuple2.m_fDist);
        }
    }

    /**
     * A candidate merge for the link-clustering priority queue: two cluster indices, the
     * distance between them, and the size each cluster had when the entry was created. The
     * recorded sizes let the clustering loop recognise entries that have gone stale.
     */
    class Tuple {
        /** Distance between the two clusters. */
        double m_fDist;
        /** Index of the first cluster. */
        int m_iCluster1;
        /** Index of the second cluster. */
        int m_iCluster2;
        /** Size of the first cluster when this entry was created. */
        int m_nClusterSize1;
        /** Size of the second cluster when this entry was created. */
        int m_nClusterSize2;

        /**
         * @param d  distance between the clusters
         * @param n  index of the first cluster
         * @param n2 index of the second cluster
         * @param n3 current size of the first cluster
         * @param n4 current size of the second cluster
         */
        public Tuple(double d, int n, int n2, int n3, int n4) {
            m_fDist = d;
            m_iCluster1 = n;
            m_iCluster2 = n2;
            m_nClusterSize1 = n3;
            m_nClusterSize2 = n4;
        }
    }

    /**
     * Intermediate binary tree node produced while clustering. A child is either another
     * {@code NodeX} ({@code m_left}/{@code m_right}) or, when that reference is {@code null},
     * the taxon whose index is stored in {@code m_iLeftInstance}/{@code m_iRightInstance}.
     */
    class NodeX {
        NodeX m_left;
        NodeX m_right;
        NodeX m_parent;
        int m_iLeftInstance;
        int m_iRightInstance;
        double m_fLeftLength;
        double m_fRightLength;
        double m_fHeight;

        NodeX() {
        }

        /** Raises values below EPSILON to EPSILON; all other values pass through unchanged. */
        private double atLeastEpsilon(double value) {
            return value < ClusterTree.this.EPSILON ? ClusterTree.this.EPSILON : value;
        }

        /**
         * Sets this node's height to {@code d} and derives both branch lengths by subtracting
         * the child subtree's height (a taxon child contributes no height).
         *
         * @param d  height used for this node and for the left branch
         * @param d2 height used for the right branch
         */
        void setHeight(double d, double d2) {
            double leftHeight = this.atLeastEpsilon(d);
            double rightHeight = this.atLeastEpsilon(d2);
            this.m_fHeight = leftHeight;
            if (this.m_left == null) {
                this.m_fLeftLength = leftHeight;
            } else {
                this.m_fLeftLength = leftHeight - this.m_left.m_fHeight;
            }
            if (this.m_right == null) {
                this.m_fRightLength = rightHeight;
            } else {
                this.m_fRightLength = rightHeight - this.m_right.m_fHeight;
            }
        }

        /**
         * Sets both branch lengths directly; the node height becomes the left branch length
         * plus the height of the left subtree, if there is one.
         *
         * @param d  left branch length
         * @param d2 right branch length
         */
        void setLength(double d, double d2) {
            double leftLength = this.atLeastEpsilon(d);
            double rightLength = this.atLeastEpsilon(d2);
            this.m_fLeftLength = leftLength;
            this.m_fRightLength = rightLength;
            this.m_fHeight = this.m_left == null ? leftLength : leftLength + this.m_left.m_fHeight;
        }

        /** Renders the subtree as a parenthesised string with branch lengths after each colon. */
        public String toString() {
            DecimalFormat lengths = new DecimalFormat("#.#####", new DecimalFormatSymbols(Locale.US));
            String leftLabel = this.m_left == null
                    ? ClusterTree.this.taxaNames.get(this.m_iLeftInstance)
                    : this.m_left.toString();
            String leftLength = lengths.format(this.m_fLeftLength);
            String rightLabel = this.m_right == null
                    ? ClusterTree.this.taxaNames.get(this.m_iRightInstance)
                    : this.m_right.toString();
            String rightLength = lengths.format(this.m_fRightLength);
            return "(" + leftLabel + ":" + leftLength + "," + rightLabel + ":" + rightLength + ")";
        }

        /** Numbers and names a freshly attached leaf and places it one branch below this node. */
        private void initLeaf(Node leaf, int taxon, double branchLength) {
            leaf.setNr(taxon);
            leaf.setID(ClusterTree.this.taxaNames.get(taxon));
            leaf.setHeight(this.m_fHeight - branchLength);
        }

        /**
         * Converts this subtree into {@link Node}s. Taxon children become new leaves; the
         * returned node is lifted where necessary so that it sits at least EPSILON above
         * each child.
         *
         * @return the node corresponding to this {@code NodeX}, with both children attached
         */
        Node toNode() {
            Node parent = ClusterTree.this.newNode();
            parent.setHeight(this.m_fHeight);

            // Left child first, then right.
            if (this.m_left != null) {
                parent.setLeft(this.m_left.toNode());
            } else {
                parent.setLeft(ClusterTree.this.newNode());
                this.initLeaf(parent.getLeft(), this.m_iLeftInstance, this.m_fLeftLength);
            }
            if (this.m_right != null) {
                parent.setRight(this.m_right.toNode());
            } else {
                parent.setRight(ClusterTree.this.newNode());
                this.initLeaf(parent.getRight(), this.m_iRightInstance, this.m_fRightLength);
            }

            // Keep the parent strictly above both children.
            double leftFloor = parent.getLeft().getHeight() + ClusterTree.this.EPSILON;
            if (parent.getHeight() < leftFloor) {
                parent.setHeight(leftFloor);
            }
            double rightFloor = parent.getRight().getHeight() + ClusterTree.this.EPSILON;
            if (parent.getHeight() < rightFloor) {
                parent.setHeight(rightFloor);
            }

            parent.getRight().setParent(parent);
            parent.getLeft().setParent(parent);
            return parent;
        }
    }

    /**
     * Clustering algorithms accepted by the {@code clusterType} input. The constant names are
     * the values users write, so their spelling and order are kept as they are.
     */
    public enum Type {
        /** Single link. */
        single,
        /** Average link. */
        average,
        /** Complete link. */
        complete,
        /** Treated as {@link #average} by {@code initAndValidate}. */
        upgma,
        /** Mean link. */
        mean,
        /** Centroid link. */
        centroid,
        /** Ward's method. */
        ward,
        /** Adjusted complete link. */
        adjcomplete,
        /** Neighbor joining. */
        neighborjoining,
        /** Neighbor joining with tip heights reset to zero afterwards. */
        neighborjoining2
    }
}

