/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ml;

import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Deque;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.MaintenanceState;
import org.opensearch.timeseries.caching.CacheProvider;
import org.opensearch.timeseries.caching.TimeSeriesCache;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
import org.opensearch.timeseries.feature.SearchFeatureDao;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.model.IndexableResult;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.model.TaskType;
import org.opensearch.timeseries.model.TimeSeriesTask;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.ratelimit.ColdStartWorker;
import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker;
import org.opensearch.timeseries.ratelimit.RequestPriority;
import org.opensearch.timeseries.ratelimit.SaveResultStrategy;
import org.opensearch.timeseries.stats.Stats;
import org.opensearch.timeseries.task.TaskCacheManager;
import org.opensearch.timeseries.task.TaskManager;
import org.opensearch.timeseries.util.ExpiringValue;
import org.opensearch.timeseries.util.ModelUtil;

public abstract class RealTimeInferencer<RCFModelType extends ThresholdedRandomCutForest, ResultType extends IndexableResult, RCFResultType extends IntermediateResult<ResultType>, IndexType extends Enum<IndexType>, IndexManagementType extends IndexManagement<IndexType>, CheckpointDaoType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>, CheckpointWriterType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType>, ColdStarterType extends ModelColdStart<RCFModelType, IndexType, IndexManagementType, ResultType>, ModelManagerType extends ModelManager<RCFModelType, ResultType, RCFResultType, IndexType, IndexManagementType, CheckpointDaoType, ColdStarterType>, SaveResultStrategyType extends SaveResultStrategy<ResultType, RCFResultType>, CacheType extends TimeSeriesCache<RCFModelType>, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager<TaskCacheManagerType, TaskTypeEnum, TaskClass, IndexType, IndexManagementType>, ColdStartWorkerType extends ColdStartWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType, CheckpointWriterType, ColdStarterType, CacheType, ResultType, RCFResultType, ModelManagerType, SaveResultStrategyType, TaskCacheManagerType, TaskTypeEnum, TaskClass, TaskManagerType>>
implements MaintenanceState {
    /*
     * Runs real-time inference for models backed by a ThresholdedRandomCutForest.
     *
     * Incoming samples are buffered per model id in a set ordered by data end time. When a new
     * sample arrives, any gap between the newest known input and the sample is back-filled from
     * historical features (up to MAX_BACKFILL_INTERVALS intervals); a larger gap triggers a cold
     * start instead. Buffered samples are fed to the model in one processSequentially call under a
     * per-model lock, and the results are handed to the result-save strategy.
     *
     * Thread-safety: a model's sample queue is reached from concurrent process()/processWithTimeout()
     * calls and from the back-fill callback on the thread pool, so every access to it synchronizes on
     * the queue object itself.
     */
    private static final Logger LOG = LogManager.getLogger(RealTimeInferencer.class);
    /** Expiry, in config intervals, given to each per-model lock and sample queue. */
    private static final long EXPIRY_INTERVAL_MULTIPLIER = 60L;
    /** Largest gap, in config intervals, that is back-filled instead of triggering a cold start. */
    private static final int MAX_BACKFILL_INTERVALS = 10000;
    /** Gap, in config intervals, from which missing samples are fetched before inference. */
    private static final long MIN_BACKFILL_GAP_INTERVALS = 2L;

    protected ModelManagerType modelManager;
    protected Stats stats;
    private final String modelCorruptionStat;
    protected CheckpointDaoType checkpointDao;
    protected ColdStartWorkerType coldStartWorker;
    protected SaveResultStrategyType resultWriteWorker;
    private final CacheProvider<RCFModelType, CacheType> cache;
    private final Map<String, ExpiringValue<Lock>> modelLocks;
    private final ThreadPool threadPool;
    private final String threadPoolName;
    private final Map<String, ExpiringValue<TreeSet<Sample>>> sampleQueues;
    private final Comparator<Sample> sampleComparator;
    private final Clock clock;
    private final SearchFeatureDao searchFeatureDao;
    private final AnalysisType analysisContext;

    public RealTimeInferencer(ModelManagerType modelManager, Stats stats, String modelCorruptionStat, CheckpointDaoType checkpointDao, ColdStartWorkerType coldStartWorker, SaveResultStrategyType resultWriteWorker, CacheProvider<RCFModelType, CacheType> cache, ThreadPool threadPool, String threadPoolName, Clock clock, SearchFeatureDao searchFeatureDao, AnalysisType analysisContext) {
        this.modelManager = modelManager;
        this.stats = stats;
        this.modelCorruptionStat = modelCorruptionStat;
        this.checkpointDao = checkpointDao;
        this.coldStartWorker = coldStartWorker;
        this.resultWriteWorker = resultWriteWorker;
        this.cache = cache;
        this.threadPool = threadPool;
        this.threadPoolName = threadPoolName;
        this.modelLocks = new ConcurrentHashMap<String, ExpiringValue<Lock>>();
        this.sampleQueues = new ConcurrentHashMap<String, ExpiringValue<TreeSet<Sample>>>();
        this.sampleComparator = Comparator.comparing(Sample::getDataEndTime);
        this.clock = clock;
        this.searchFeatureDao = searchFeatureDao;
        this.analysisContext = analysisContext;
    }

    /**
     * Buffers {@code sample} for its model and runs inference on everything buffered.
     *
     * <p>Samples carried by {@code modelState} are merged into the model's queue first. If the
     * newest known input (from the model or the queue) is at least {@link #MIN_BACKFILL_GAP_INTERVALS}
     * and at most {@link #MAX_BACKFILL_INTERVALS} intervals behind {@code sample}, the missing
     * samples are fetched and queued before inference. A larger gap triggers a cold start.
     *
     * @param sample the newly arrived sample
     * @param modelState state holding the model to run
     * @param config the analysis configuration
     * @param taskId task id forwarded to result saving and cold start
     * @param listener receives {@code true} if buffered samples were run through the model;
     *     {@code false} if the model is absent, the gap triggered a cold start, the queue was empty,
     *     or the lock could not be acquired in time; or a failure if back-fill or inference failed
     */
    public void process(Sample sample, ModelState<RCFModelType> modelState, Config config, String taskId, ActionListener<Boolean> listener) {
        String modelId = modelState.getModelId();
        TreeSet<Sample> queue = this.sampleQueues
            .computeIfAbsent(modelId, k -> new ExpiringValue<TreeSet<Sample>>(new TreeSet<Sample>(this.sampleComparator), expiryMillis(config), this.clock))
            .getValue();

        Instant lastSampleDataEndTime;
        // Merge, read the newest buffered end time, and add the new sample as one atomic step.
        synchronized (queue) {
            this.addSamples(queue, modelState.getSamples(), config);
            lastSampleDataEndTime = queue.isEmpty() ? Instant.MIN : queue.last().getDataEndTime();
            this.addSample(queue, sample, config);
        }

        Optional<RCFModelType> modelOptional = modelState.getModel();
        if (!modelOptional.isPresent()) {
            LOG.warn("Model not present for config [{}], model [{}]. Skipping.", config.getId(), modelId);
            listener.onResponse(false);
            return;
        }

        long lastInputTimestampSecs = Math.max(ModelUtil.getLastInputTimestampSeconds(modelOptional.get()), lastSampleDataEndTime.getEpochSecond());
        long currentTimeSecs = sample.getDataEndTime().getEpochSecond();
        long diffSecs = currentTimeSecs - lastInputTimestampSecs;
        long intervalSecs = config.getIntervalInSeconds();
        LOG.debug("diffSecs:{} interval:{} maxFrequencyMultiple:{} lastInputTimestampSecs:{} currentTimeSecs:{}", diffSecs, intervalSecs, MAX_BACKFILL_INTERVALS, lastInputTimestampSecs, currentTimeSecs);

        long minGapSecs = MIN_BACKFILL_GAP_INTERVALS * intervalSecs;
        if (diffSecs < minGapSecs) {
            this.processWithTimeout(modelState, config, taskId, sample, listener);
        } else if (diffSecs / intervalSecs <= MAX_BACKFILL_INTERVALS) {
            LOG.info("fetching features between {} and {}", lastInputTimestampSecs, currentTimeSecs);
            ActionListener<List<Sample>> backfillListener = ActionListener.wrap(samples -> {
                LOG.info("samples size: {}", samples.size());
                for (Sample s : samples) {
                    this.addSample(queue, s, config);
                }
                this.processWithTimeout(modelState, config, taskId, sample, listener);
            }, listener::onFailure);
            this.getFeatures(config, modelState.getEntity(), lastInputTimestampSecs * 1000L, sample.getDataStartTime().getEpochSecond() * 1000L, backfillListener);
        } else {
            LOG.warn("Time gap {} is too large for config [{}], model [{}]. Triggering cold start.", diffSecs, config.getId(), modelId);
            this.reColdStart(config, modelId, null, sample, taskId);
            listener.onResponse(false);
        }
    }

    /**
     * Runs the model over its buffered samples if the per-model lock is free. Otherwise it retries
     * every second until the next execution's end time has passed.
     *
     * <p>If inference throws, the model is re-cold-started (while the lock is still held), unless the
     * error message reports an incorrect ordering of time. The listener is notified exactly once,
     * after the lock has been released. An exception thrown by the listener therefore propagates to
     * the caller and is never mistaken for a model failure.
     *
     * @param modelState state holding the model to run
     * @param config the analysis configuration
     * @param taskId task id forwarded to result saving and cold start
     * @param sample the sample that triggered this run; used for timeout computation and cold start
     * @param listener receives whether any samples were processed, or the inference failure
     */
    public void processWithTimeout(ModelState<RCFModelType> modelState, Config config, String taskId, Sample sample, ActionListener<Boolean> listener) {
        String modelId = modelState.getModelId();
        ReentrantLock lock = (ReentrantLock) this.modelLocks
            .computeIfAbsent(modelId, k -> new ExpiringValue<Lock>(new ReentrantLock(), expiryMillis(config), this.clock))
            .getValue();
        LOG.debug("try lock");
        if (!lock.tryLock()) {
            this.retryOrGiveUp(modelState, config, taskId, sample, listener);
            return;
        }
        LOG.debug("lock acquired");

        boolean success = false;
        Exception failure = null;
        try {
            success = this.inferQueuedSamples(modelState, config, taskId);
        } catch (Exception e) {
            LOG.error("Error processing samples", e);
            if (e.getMessage() != null && e.getMessage().contains("incorrect ordering of time")) {
                LOG.warn(String.format(Locale.ROOT, "incorrect ordering of time for config %s model %s at data end time %d", config.getId(), modelState.getModelId(), sample.getDataEndTime().toEpochMilli()));
            } else {
                this.reColdStart(config, modelId, e, sample, taskId);
            }
            failure = e;
        } finally {
            LOG.debug("unlock");
            lock.unlock();
        }

        if (failure != null) {
            listener.onFailure(failure);
        } else {
            listener.onResponse(success);
        }
    }

    /**
     * Drains the model's sample queue and runs the drained samples through the model in time order,
     * then saves one intermediate result per processed sample. The caller must hold the model's lock.
     *
     * @return {@code true} if samples were processed; {@code false} if the queue is empty or has
     *     been purged by {@link #maintenance()}
     */
    private boolean inferQueuedSamples(ModelState<RCFModelType> modelState, Config config, String taskId) {
        String modelId = modelState.getModelId();
        ExpiringValue<TreeSet<Sample>> expiringQueue = this.sampleQueues.get(modelId);
        if (expiringQueue == null) {
            // Purged by maintenance(): nothing is buffered for this model.
            LOG.debug("no sample queue for model [{}]", modelId);
            return false;
        }
        TreeSet<Sample> queue = expiringQueue.getValue();

        List<Sample> samples;
        // Snapshot and clear atomically so concurrent writers cannot interleave with the copy.
        synchronized (queue) {
            LOG.debug("queue size:{}", queue.size());
            if (queue.isEmpty()) {
                return false;
            }
            samples = new ArrayList<>(queue);
            queue.clear();
        }

        int count = samples.size();
        double[][] points = new double[count][];
        long[] timestamps = new long[count];
        List<Instant> dataStarts = new ArrayList<>(count);
        List<Instant> dataEnds = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            Sample queued = samples.get(i);
            points[i] = queued.getValueList();
            dataStarts.add(queued.getDataStartTime());
            Instant dataEnd = queued.getDataEndTime();
            dataEnds.add(dataEnd);
            timestamps[i] = dataEnd.getEpochSecond();
        }

        ThresholdedRandomCutForest model = modelState.getModel().get();
        LOG.debug("Processing sequential points - timestamps: {}, entity: {}", Arrays.toString(timestamps), modelState.getEntity().map(Object::toString).orElse("null"));
        List<AnomalyDescriptor> results = model.processSequentially(points, timestamps, x -> true);

        List<RCFResultType> intermediateResults = new ArrayList<>(results.size());
        for (int i = 0; i < results.size(); i++) {
            AnomalyDescriptor result = results.get(i);
            intermediateResults.add(this.modelManager.toResult(model.getForest(), result, samples.get(i).getValueList(), result.getMissingValues() != null, config));
        }
        this.resultWriteWorker.saveAllResults(intermediateResults, config, dataStarts, dataEnds, modelId, Arrays.asList(points), modelState.getEntity(), taskId);
        return true;
    }

    /**
     * Called when the model lock is busy. It schedules another {@link #processWithTimeout} attempt
     * in one second, or answers {@code false} once the clock has reached the next execution's end
     * time (sample data end + window delay + one interval).
     */
    private void retryOrGiveUp(ModelState<RCFModelType> modelState, Config config, String taskId, Sample sample, ActionListener<Boolean> listener) {
        String modelId = modelState.getModelId();
        long windowDelayMillis = config.getWindowDelay() == null ? 0L : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis();
        long curExecutionEnd = sample.getDataEndTime().toEpochMilli() + windowDelayMillis;
        long nextExecutionEnd = curExecutionEnd + config.getIntervalInMilliseconds();
        if (this.clock.millis() >= nextExecutionEnd) {
            LOG.warn("Timeout reached, not retrying.");
            listener.onResponse(false);
            return;
        }
        try {
            LOG.debug("Scheduling a retry in one second for model [{}], config [{}], taskId [{}], sample data end time [{}], next execution end time [{}], current time [{}], window delay [{}], interval in milliseconds [{}]", modelId, config.getId(), taskId, sample.getDataEndTime().toEpochMilli(), nextExecutionEnd, this.clock.millis(), windowDelayMillis, config.getIntervalInMilliseconds());
            this.threadPool.schedule(() -> this.processWithTimeout(modelState, config, taskId, sample, listener), new TimeValue(1L, TimeUnit.SECONDS), this.threadPoolName);
        } catch (Exception e) {
            LOG.error("Failed to schedule retry", e);
            listener.onFailure(e);
        }
    }

    /**
     * Treats the model as corrupted and restarts it from scratch.
     *
     * <p>The method increments the model-corruption stat and removes the model from the cache. If
     * {@code modelId} is non-null, it also deletes the model's checkpoint asynchronously (outcome is
     * only logged). Finally it enqueues a cold-start request built from {@code sample}.
     *
     * @param config the analysis configuration
     * @param modelId id of the model to restart
     * @param e the triggering exception, or {@code null} when the restart is not caused by one
     * @param sample sample whose values and start time seed the cold-start request
     * @param taskId task id forwarded to the cold-start request
     */
    public void reColdStart(Config config, String modelId, Exception e, Sample sample, String taskId) {
        Message corruptionMessage = new ParameterizedMessage("Likely model corruption for [{}]", modelId);
        if (e != null) {
            LOG.error(corruptionMessage, e);
        } else {
            LOG.warn(corruptionMessage);
        }
        this.stats.getStat(this.modelCorruptionStat).increment();
        this.cache.get().removeModel(config.getId(), modelId);
        if (modelId != null) {
            ActionListener<Void> deleteListener = ActionListener.wrap(
                r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)),
                ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), ex)
            );
            this.checkpointDao.deleteModelCheckpoint(modelId, deleteListener);
        }
        this.coldStartWorker.put(new FeatureRequest(this.clock.millis() + config.getInferredFrequencyInMilliseconds(), config.getId(), RequestPriority.MEDIUM, modelId, sample.getValueList(), sample.getDataStartTime().toEpochMilli(), taskId));
    }

    /**
     * Fetches one sample per config interval between {@code startTimeMs} and {@code endTimeMs}.
     *
     * <p>The listener receives an empty list in these cases: the range is empty or starts at 0,
     * the range is shorter than one interval, the range spans more than
     * {@link #MAX_BACKFILL_INTERVALS} intervals, or no sample ranges were produced. Intervals
     * without features are skipped. If the number of returned feature rows differs from the number
     * of requested ranges, an {@link IllegalArgumentException} is delivered to the listener.
     * The feature callback runs on {@code threadPoolName}.
     */
    private void getFeatures(Config config, Optional<Entity> entity, long startTimeMs, long endTimeMs, ActionListener<List<Sample>> listener) {
        long intervalMs = config.getIntervalInMilliseconds();
        if (startTimeMs == 0L || startTimeMs >= endTimeMs || endTimeMs - startTimeMs < intervalMs) {
            listener.onResponse(new ArrayList<>());
            return;
        }
        // Both operands are positive here, so integer division equals the floor of the quotient.
        long numberOfSamples = (endTimeMs - startTimeMs) / intervalMs;
        if (numberOfSamples > MAX_BACKFILL_INTERVALS) {
            listener.onResponse(new ArrayList<>());
            return;
        }
        List<Map.Entry<Long, Long>> sampleRanges = this.searchFeatureDao.getTrainSampleRanges((IntervalTimeConfiguration) config.getInterval(), startTimeMs, endTimeMs, (int) numberOfSamples);
        if (sampleRanges.isEmpty()) {
            listener.onResponse(new ArrayList<>());
            return;
        }
        ActionListener<List<Optional<double[]>>> featureListener = ActionListener.wrap(featureSamples -> {
            int totalNumSamples = featureSamples.size();
            if (totalNumSamples != sampleRanges.size()) {
                String err = String.format(Locale.ROOT, "length mismatch: totalNumSamples %d != time range length %d", totalNumSamples, sampleRanges.size());
                listener.onFailure(new IllegalArgumentException(err));
                return;
            }
            List<Sample> samples = new ArrayList<>();
            for (int index = 0; index < totalNumSamples; index++) {
                Optional<double[]> features = featureSamples.get(index);
                if (features.isPresent()) {
                    Map.Entry<Long, Long> range = sampleRanges.get(index);
                    samples.add(new Sample(features.get(), Instant.ofEpochMilli(range.getKey()), Instant.ofEpochMilli(range.getValue())));
                }
            }
            listener.onResponse(samples);
        }, listener::onFailure);
        try {
            this.searchFeatureDao.getColdStartSamplesForPeriods(config, sampleRanges, entity, true, this.analysisContext, new ThreadedActionListener<>(LOG, this.threadPool, this.threadPoolName, featureListener, false));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Adds {@code sample} to {@code queue} only if its data start time is at least one interval
     * away from both neighbours, ordered by data end time. Otherwise it is dropped silently.
     * Synchronizes on {@code queue}.
     */
    private void addSample(TreeSet<Sample> queue, Sample sample, Config config) {
        long intervalSeconds = config.getIntervalInSeconds();
        long sampleTime = sample.getDataStartTime().getEpochSecond();
        synchronized (queue) {
            Sample previousSample = queue.floor(sample);
            Sample nextSample = queue.ceiling(sample);
            LOG.debug("lastSample: {} sample: {}", previousSample == null ? "null" : Long.valueOf(previousSample.getDataStartTime().getEpochSecond()), sampleTime);
            boolean previousGapOk = previousSample == null || sampleTime - previousSample.getDataStartTime().getEpochSecond() >= intervalSeconds;
            boolean nextGapOk = nextSample == null || nextSample.getDataStartTime().getEpochSecond() - sampleTime >= intervalSeconds;
            if (previousGapOk && nextGapOk) {
                queue.add(sample);
            }
        }
    }

    /** Applies {@link #addSample} to each element of {@code samples}; a {@code null} deque is ignored. */
    private void addSamples(TreeSet<Sample> queue, Deque<Sample> samples, Config config) {
        if (samples == null) {
            return;
        }
        for (Sample sample : samples) {
            this.addSample(queue, sample, config);
        }
    }

    /** Expiry in milliseconds for per-model locks and sample queues: {@link #EXPIRY_INTERVAL_MULTIPLIER} intervals. */
    private static long expiryMillis(Config config) {
        return config.getIntervalDuration().multipliedBy(EXPIRY_INTERVAL_MULTIPLIER).toMillis();
    }

    /**
     * Removes expired per-model locks and sample queues.
     *
     * @throws TimeSeriesException wrapping any exception raised while purging
     */
    @Override
    public void maintenance() {
        try {
            this.modelLocks.entrySet().removeIf(entry -> entry.getValue().isExpired());
            this.sampleQueues.entrySet().removeIf(entry -> entry.getValue().isExpired());
        } catch (Exception e) {
            throw new TimeSeriesException("Fail to maintain RealTimeInferencer", e);
        }
    }

    /** Returns the live (mutable) map of per-model locks. */
    public Map<String, ExpiringValue<Lock>> getModelLocks() {
        return this.modelLocks;
    }

    /** Returns the live (mutable) map of per-model sample queues. */
    public Map<String, ExpiringValue<TreeSet<Sample>>> getSampleQueues() {
        return this.sampleQueues;
    }
}

