/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.evaluation;

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricTarget;

public class ClusteringMetric
implements EvaluationMetric<ClusterID, Context> {
    private final MetricTarget<ClusterID> target;
    private final String name;
    private final BiFunction<MetricTarget<ClusterID>, Context, Double> impl;

    public ClusteringMetric(MetricTarget<ClusterID> target, String name, BiFunction<MetricTarget<ClusterID>, Context, Double> impl) {
        this.target = target;
        this.name = name;
        this.impl = impl;
    }

    public double compute(Context context) {
        return this.impl.apply(this.target, context);
    }

    public MetricTarget<ClusterID> getTarget() {
        return this.target;
    }

    public String getName() {
        return this.name;
    }

    public Context createContext(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
        return ClusteringMetric.buildContext(model, predictions);
    }

    public String toString() {
        return "ClusteringMetric(target=" + this.target + ",name='" + this.name + '\'' + ')';
    }

    static Context buildContext(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
        return new Context(model, predictions);
    }

    static final class Context
    extends MetricContext<ClusterID> {
        private final ArrayList<Integer> predictedIDs = new ArrayList();
        private final ArrayList<Integer> trueIDs = new ArrayList();

        Context(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
            super(model, predictions);
            int i = 0;
            for (Prediction<ClusterID> pred : predictions) {
                if (((ClusterID)pred.getOutput()).equals(ClusteringFactory.UNASSIGNED_CLUSTER_ID)) {
                    throw new IllegalArgumentException("The sentinel unassigned cluster id was used as a ground truth output at prediction number " + i);
                }
                if (((ClusterID)pred.getExample().getOutput()).equals(ClusteringFactory.UNASSIGNED_CLUSTER_ID)) {
                    throw new IllegalArgumentException("The sentinel unassigned cluster id was predicted by the model at prediction number " + i);
                }
                this.predictedIDs.add(((ClusterID)pred.getOutput()).getID());
                this.trueIDs.add(((ClusterID)pred.getExample().getOutput()).getID());
                ++i;
            }
        }

        public ArrayList<Integer> getPredictedIDs() {
            return this.predictedIDs;
        }

        public ArrayList<Integer> getTrueIDs() {
            return this.trueIDs;
        }
    }
}

