/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.query.HybridCollectorResultsUtilParams;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.rescore.RescoreContext;

/**
 * Builds the hybrid-query {@link TopDocs} for one {@link HybridSearchCollector} and folds it into a
 * {@link QuerySearchResult}.
 * <p>
 * The per-sub-query results are flattened into a single doc array with the layout: a start/stop
 * marker, then for each sub-query a delimiter marker followed by that sub-query's hits, then a
 * closing start/stop marker. Every marker element uses the doc id of the first hit found in any
 * sub-query. When collapse results are built, a collapse-value array is kept aligned with the doc
 * array: one value per marker, plus each sub-query's own collapse values.
 */
public class HybridSearchCollectorResultUtil {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridSearchCollectorResultUtil.class);
    private final HybridCollectorResultsUtilParams hybridSearchCollectorResultsDTO;
    private final HybridSearchCollector hybridSearchCollector;

    /**
     * Stores {@code topDocsAndMaxScore} on {@code result}, merging it with any top docs the result already holds.
     *
     * @param result shard-level query result to update
     * @param topDocsAndMaxScore hybrid top docs produced by this collector
     */
    public void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore topDocsAndMaxScore) {
        if (result.hasConsumedTopDocs()) {
            // NOTE(review): assumed to mean the result holds no top docs yet, so there is nothing to merge with
            result.topDocs(topDocsAndMaxScore, this.hybridSearchCollectorResultsDTO.getDocValueFormats());
            return;
        }
        if (topDocsAndMaxScore.topDocs.totalHits.value() == 0L) {
            // nothing new to contribute; leave the existing top docs untouched
            return;
        }
        TopDocsAndMaxScore existingTopDocs = result.topDocs();
        TopDocsAndMaxScore mergedTopDocs = this.hybridSearchCollectorResultsDTO.getTopDocsMerger()
            .merge(existingTopDocs, topDocsAndMaxScore);
        result.topDocs(mergedTopDocs, this.hybridSearchCollectorResultsDTO.getDocValueFormats());
    }

    /**
     * Collects the per-sub-query top docs from the collector and converts them into a single hybrid result.
     *
     * @return the flattened hybrid top docs together with the max score
     * @throws IOException if the collector fails to produce its top docs
     */
    public TopDocsAndMaxScore getTopDocsAndMaxScore() throws IOException {
        // Read through a wildcard: the element type depends on the collector. The sorted/collapse path below
        // expects TopFieldDocs (its entries are cast to FieldDoc / CollapseTopFieldDocs), the plain path TopDocs.
        List<?> collectedTopDocs = this.hybridSearchCollector.topDocs();
        if (this.hybridSearchCollectorResultsDTO.isSortEnabled() || this.hybridSearchCollectorResultsDTO.isCollapseEnabled()) {
            @SuppressWarnings("unchecked")
            List<TopFieldDocs> topFieldDocs = (List<TopFieldDocs>) collectedTopDocs;
            return this.getSortedTopDocsAndMaxScore(topFieldDocs);
        }
        @SuppressWarnings("unchecked")
        List<TopDocs> topDocs = (List<TopDocs>) collectedTopDocs;
        return this.getTopDocsAndMaxScore(topDocs);
    }

    /**
     * Builds the sorted (or collapsed) hybrid result. Rescoring is not applied on this path.
     */
    private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List<TopFieldDocs> topDocs) {
        SortField[] sortFields = this.hybridSearchCollectorResultsDTO.getSortFields();
        TopDocs sortedTopDocs = this.getNewTopFieldDocs(this.getTotalHits(topDocs), topDocs, sortFields);
        return new TopDocsAndMaxScore(sortedTopDocs, this.hybridSearchCollector.getMaxScore());
    }

    /**
     * Flattens per-sub-query {@link TopFieldDocs} into one {@link TopFieldDocs} (or {@link CollapseTopFieldDocs}
     * when collapse is enabled) using the start/stop + delimiter layout.
     *
     * @param totalHits total hits to report on the result
     * @param topFieldDocs per-sub-query results; the list and its entries may be null
     * @param sortFields sort fields of the request; may be null
     * @return the flattened result; empty when no sub-query has hits
     */
    private TopDocs getNewTopFieldDocs(TotalHits totalHits, List<TopFieldDocs> topFieldDocs, SortField[] sortFields) {
        if (Objects.isNull(topFieldDocs)) {
            return new TopFieldDocs(totalHits, new FieldDoc[0], sortFields);
        }
        int delimiterDocId = findDelimiterDocId(topFieldDocs);
        if (delimiterDocId == -1) {
            // no sub-query returned hits; fall back to a score sort when the request carried no sort fields
            SortField[] effectiveSortFields = sortFields == null
                ? new SortField[] { new SortField(null, SortField.Type.SCORE) }
                : sortFields;
            return new TopFieldDocs(totalHits, new FieldDoc[0], effectiveSortFields);
        }
        if (this.hybridSearchCollectorResultsDTO.isCollapseEnabled()) {
            return this.buildSortedCollapseTopDocs(totalHits, topFieldDocs, delimiterDocId);
        }
        Object[] sortFieldsForDelimiterResults = HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults(sortFields);
        List<FieldDoc> result = new ArrayList<>();
        result.add(
            HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)
        );
        for (TopFieldDocs topFieldDoc : topFieldDocs) {
            // every sub-query gets a delimiter, even when it produced nothing, so positions stay stable
            result.add(
                HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)
            );
            if (Objects.isNull(topFieldDoc) || Objects.isNull(topFieldDoc.scoreDocs)) {
                continue;
            }
            for (ScoreDoc scoreDoc : topFieldDoc.scoreDocs) {
                result.add((FieldDoc) scoreDoc);
            }
        }
        result.add(
            HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults)
        );
        return new TopFieldDocs(totalHits, result.toArray(new FieldDoc[0]), sortFields);
    }

    /**
     * Builds a {@link CollapseTopFieldDocs} for the sorted path, keeping the collapse values aligned with the docs.
     *
     * @param delimiterDocId doc id used for marker elements; != -1, so at least one entry is non-null
     */
    private TopDocs buildSortedCollapseTopDocs(TotalHits totalHits, List<TopFieldDocs> topFieldDocs, int delimiterDocId) {
        // take sort fields from the first non-null sub-query result (one exists because delimiterDocId != -1)
        SortField[] collapseSortFields = topFieldDocs.stream().filter(Objects::nonNull).findFirst().orElseThrow().fields;
        Object[] fields = HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults(collapseSortFields);
        List<FieldDoc> fieldDocs = new ArrayList<>();
        List<Object> collapseValues = new ArrayList<>();
        String collapseField = "";

        fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, fields));
        collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueStartStopElementForHybridSearchResults()));
        for (TopFieldDocs topFieldDoc : topFieldDocs) {
            // delimiter doc and delimiter collapse value are always added together so the arrays stay aligned
            fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, fields));
            collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueDelimiterElementForHybridSearchResults()));
            if (Objects.isNull(topFieldDoc)) {
                continue;
            }
            CollapseTopFieldDocs collapseTopFieldDoc = (CollapseTopFieldDocs) topFieldDoc;
            collapseField = collapseTopFieldDoc.field;
            if (Objects.isNull(collapseTopFieldDoc.scoreDocs)) {
                continue;
            }
            for (ScoreDoc scoreDoc : collapseTopFieldDoc.scoreDocs) {
                fieldDocs.add((FieldDoc) scoreDoc);
            }
            collapseValues.addAll(Arrays.asList(collapseTopFieldDoc.collapseValues));
        }
        fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, fields));
        collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueStartStopElementForHybridSearchResults()));
        return new CollapseTopFieldDocs(
            collapseField,
            totalHits,
            fieldDocs.toArray(new FieldDoc[0]),
            collapseSortFields,
            collapseValues.toArray(new Object[0])
        );
    }

    /**
     * Returns the total hits to report. The relation is GREATER_THAN_OR_EQUAL_TO when track-total-hits-up-to is -1
     * (presumably the "tracking disabled" sentinel — confirm), otherwise EQUAL_TO.
     */
    private TotalHits getTotalHits(List<?> topDocs) {
        TotalHits.Relation relation = this.hybridSearchCollectorResultsDTO.getTrackTotalHitsUpTo() == -1
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0L, relation);
        }
        return new TotalHits((long) this.hybridSearchCollector.getTotalHits(), relation);
    }

    /**
     * Builds the unsorted hybrid result, applying the configured rescorers first when present.
     */
    private TopDocsAndMaxScore getTopDocsAndMaxScore(List<TopDocs> topDocs) {
        if (this.shouldRescore()) {
            topDocs = this.rescore(topDocs);
        }
        float maxScore = this.calculateMaxScore(topDocs, this.hybridSearchCollector.getMaxScore());
        TopDocs finalTopDocs = this.getNewTopDocs(this.getTotalHits(topDocs), topDocs);
        return new TopDocsAndMaxScore(finalTopDocs, maxScore);
    }

    /**
     * Flattens per-sub-query {@link TopDocs} into a single {@link TopDocs} using the start/stop + delimiter layout.
     * If the first entry is a {@link CollapseTopFieldDocs}, a {@link CollapseTopFieldDocs} is produced instead.
     *
     * @param totalHits total hits to report on the result
     * @param topDocs per-sub-query results; the list and its entries may be null
     * @return the flattened result; empty when no sub-query has hits
     */
    private TopDocs getNewTopDocs(TotalHits totalHits, List<TopDocs> topDocs) {
        if (Objects.isNull(topDocs)) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }
        int delimiterDocId = findDelimiterDocId(topDocs);
        if (delimiterDocId == -1) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }
        // delimiterDocId != -1 implies the list is non-empty
        if (topDocs.get(0) instanceof CollapseTopFieldDocs) {
            return this.buildCollapseTopDocs(totalHits, topDocs, delimiterDocId);
        }
        List<ScoreDoc> result = new ArrayList<>();
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
        for (TopDocs topDoc : topDocs) {
            result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
            if (Objects.nonNull(topDoc) && Objects.nonNull(topDoc.scoreDocs)) {
                result.addAll(Arrays.asList(topDoc.scoreDocs));
            }
        }
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
        // copy into fresh ScoreDoc instances rather than sharing the sub-query objects
        ScoreDoc[] scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        return new TopDocs(totalHits, scoreDocs);
    }

    /**
     * Builds a {@link CollapseTopFieldDocs} for the unsorted path. Markers use an empty sort-value array and the
     * collapse value {@code 0}; hits are re-wrapped as {@link FieldDoc}s without sort values.
     */
    private TopDocs buildCollapseTopDocs(TotalHits totalHits, List<TopDocs> topDocs, int delimiterDocId) {
        Object[] markerFields = new Object[0];
        List<FieldDoc> fieldDocs = new ArrayList<>();
        List<Object> collapseValues = new ArrayList<>();
        List<SortField> sortFields = new ArrayList<>();
        String collapseField = "";

        fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, markerFields));
        collapseValues.add(0);
        for (TopDocs topDoc : topDocs) {
            // delimiter doc and delimiter collapse value are always added together so the arrays stay aligned
            fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, markerFields));
            collapseValues.add(0);
            if (Objects.isNull(topDoc)) {
                continue;
            }
            CollapseTopFieldDocs collapseTopFieldDoc = (CollapseTopFieldDocs) topDoc;
            collapseField = collapseTopFieldDoc.field;
            // NOTE(review): sort fields of every sub-query are concatenated, so the array repeats once per
            // sub-query — confirm this is what downstream consumers expect
            sortFields.addAll(Arrays.asList(collapseTopFieldDoc.fields));
            if (Objects.isNull(collapseTopFieldDoc.scoreDocs)) {
                continue;
            }
            for (ScoreDoc scoreDoc : collapseTopFieldDoc.scoreDocs) {
                fieldDocs.add(new FieldDoc(scoreDoc.doc, scoreDoc.score, new Object[0]));
            }
            collapseValues.addAll(Arrays.asList(collapseTopFieldDoc.collapseValues));
        }
        fieldDocs.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, markerFields));
        collapseValues.add(0);
        return new CollapseTopFieldDocs(
            collapseField,
            totalHits,
            fieldDocs.toArray(new FieldDoc[0]),
            sortFields.toArray(new SortField[0]),
            collapseValues.toArray(new Object[0])
        );
    }

    /**
     * Returns the doc id of the first hit of the first sub-query that has any hits, or -1 when none do.
     */
    private static int findDelimiterDocId(List<? extends TopDocs> topDocs) {
        for (TopDocs topDoc : topDocs) {
            if (topDoc != null && topDoc.scoreDocs != null && topDoc.scoreDocs.length > 0) {
                return topDoc.scoreDocs[0].doc;
            }
        }
        return -1;
    }

    /**
     * When rescorers are configured, raises {@code initialMaxScore} to the best first-hit score across sub-queries
     * (only the first hit of each sub-query is inspected); otherwise returns {@code initialMaxScore} unchanged.
     */
    private float calculateMaxScore(List<TopDocs> topDocsList, float initialMaxScore) {
        List<RescoreContext> rescoreContexts = this.hybridSearchCollectorResultsDTO.getRescoreContexts();
        if (Objects.isNull(rescoreContexts) || rescoreContexts.isEmpty()) {
            return initialMaxScore;
        }
        float maxScore = initialMaxScore;
        for (TopDocs topDocs : topDocsList) {
            if (topDocs == null || topDocs.scoreDocs == null || topDocs.scoreDocs.length == 0) {
                continue;
            }
            maxScore = Math.max(maxScore, topDocs.scoreDocs[0].score);
        }
        return maxScore;
    }

    /** Returns true when at least one rescore context is configured. */
    private boolean shouldRescore() {
        List<RescoreContext> rescoreContexts = this.hybridSearchCollectorResultsDTO.getRescoreContexts();
        return !CollectionUtils.isEmpty(rescoreContexts);
    }

    /** Applies every configured rescorer in order, feeding each one's output into the next. */
    private List<TopDocs> rescore(List<TopDocs> topDocs) {
        List<TopDocs> rescoredTopDocs = topDocs;
        for (RescoreContext ctx : this.hybridSearchCollectorResultsDTO.getRescoreContexts()) {
            rescoredTopDocs = this.rescoredTopDocs(ctx, rescoredTopDocs);
        }
        return rescoredTopDocs;
    }

    /**
     * Rescores each sub-query result with {@code ctx}. Null entries are kept as null so that downstream
     * flattening still emits their delimiter.
     *
     * @throws HybridSearchRescoreQueryException wrapping the IOException raised by the rescorer
     */
    private List<TopDocs> rescoredTopDocs(RescoreContext ctx, List<TopDocs> topDocs) {
        List<TopDocs> result = new ArrayList<>(topDocs.size());
        for (TopDocs topDoc : topDocs) {
            if (topDoc == null) {
                result.add(null);
                continue;
            }
            try {
                result.add(ctx.rescorer().rescore(topDoc, this.hybridSearchCollectorResultsDTO.getSearchContext().searcher(), ctx));
            } catch (IOException exception) {
                log.error("rescore failed for hybrid query in collector_manager.reduce call", exception);
                throw new HybridSearchRescoreQueryException(exception);
            }
        }
        return result;
    }

    @Generated
    public HybridSearchCollectorResultUtil(HybridCollectorResultsUtilParams hybridSearchCollectorResultsDTO, HybridSearchCollector hybridSearchCollector) {
        this.hybridSearchCollectorResultsDTO = hybridSearchCollectorResultsDTO;
        this.hybridSearchCollector = hybridSearchCollector;
    }
}

