/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;
import org.opensearch.transport.client.Client;

/**
 * {@link Llm} implementation that sends a chat-completion request to a remote model through
 * {@link MachineLearningInternalClient} and converts the first tensor of the first model output
 * into a {@link ChatCompletionOutput}.
 */
public class DefaultLlmImpl
implements Llm {
    @Generated
    private static final Logger log = LogManager.getLogger(DefaultLlmImpl.class);
    private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model";
    private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages";
    private static final String CONNECTOR_INPUT_PARAMETER_INPUTS = "inputs";
    private static final String CONNECTOR_OUTPUT_CHOICES = "choices";
    private static final String CONNECTOR_OUTPUT_MESSAGE = "message";
    private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
    private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
    private static final String CONNECTOR_OUTPUT_ERROR = "error";
    private static final String BEDROCK_COMPLETION_FIELD = "completion";
    private static final String COHERE_TEXT_FIELD = "text";
    private static final String BEDROCK_CONVERSE_OUTPUT = "output";
    private static final String BEDROCK_CONVERSE_TEXT = "text";
    private static final String UNKNOWN_ERROR_MESSAGE = "Unknown error or response.";
    private final String openSearchModelId;
    private MachineLearningInternalClient mlClient;

    /**
     * Creates an instance bound to one model.
     *
     * @param openSearchModelId id passed to every predict call; must not be {@code null}
     * @param client client wrapped in a new {@link MachineLearningInternalClient}
     * @throws NullPointerException if {@code openSearchModelId} is {@code null}
     */
    public DefaultLlmImpl(String openSearchModelId, Client client) {
        Preconditions.checkNotNull(openSearchModelId);
        this.openSearchModelId = openSearchModelId;
        this.mlClient = new MachineLearningInternalClient(client);
    }

    /**
     * Replaces the ML client used for predict calls.
     *
     * @param mlClient the client to use from now on
     */
    @VisibleForTesting
    protected void setMlClient(MachineLearningInternalClient mlClient) {
        this.mlClient = mlClient;
    }

    /**
     * Builds the connector parameters for {@code chatCompletionInput}, issues a predict call
     * against the configured model and reports the parsed result to {@code listener}.
     *
     * <p>Failures while unpacking or parsing the model output are delivered through
     * {@link ActionListener#onFailure(Exception)} so that the caller is always notified.
     *
     * @param chatCompletionInput request description (provider, prompts, question, contexts, ...)
     * @param listener receives the parsed output, or the failure
     * @throws IllegalArgumentException synchronously, if the input names no supported provider
     *         and has no llm response field (see {@link #getInputParameters})
     */
    @Override
    public void doChatCompletion(final ChatCompletionInput chatCompletionInput, final ActionListener<ChatCompletionOutput> listener) {
        RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder()
            .parameters(this.getInputParameters(chatCompletionInput))
            .build();
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
        this.mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<MLOutput>() {

            @Override
            public void onResponse(MLOutput mlOutput) {
                ChatCompletionOutput output;
                try {
                    // Only the first tensor of the first model output is consulted.
                    ModelTensors firstOutput = ((ModelTensorOutput) mlOutput).getMlModelOutputs().get(0);
                    ModelTensor firstTensor = firstOutput.getMlModelTensors().get(0);
                    Map<String, ?> dataAsMap = firstTensor.getDataAsMap();
                    output = DefaultLlmImpl.this.buildChatCompletionOutput(
                        chatCompletionInput.getModelProvider(),
                        dataAsMap,
                        chatCompletionInput.getLlmResponseField()
                    );
                } catch (Exception e) {
                    // Callback boundary: a malformed response must fail the request rather than
                    // escape this callback and leave the caller's listener without any signal.
                    listener.onFailure(e);
                    return;
                }
                // Deliberately outside the try block so that an exception thrown by the caller's
                // own onResponse does not additionally trigger its onFailure.
                listener.onResponse(output);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }
        });
    }

    /**
     * Translates the request into the parameter map sent to the connector.
     *
     * <ul>
     *   <li>OPENAI: {@code model} plus a {@code messages} prompt built by {@link PromptUtil}.</li>
     *   <li>BEDROCK, COHERE, or any input with a non-null llm response field: a single
     *       {@code inputs} string prompt.</li>
     *   <li>BEDROCK_CONVERSE (without an llm response field): a {@code messages} prompt built
     *       with a {@code null} system prompt.</li>
     * </ul>
     *
     * @param chatCompletionInput the request to translate
     * @return a new mutable map of connector parameters
     * @throws IllegalArgumentException if none of the cases above applies
     */
    protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
        Map<String, String> inputParameters = new HashMap<>();
        Llm.ModelProvider provider = chatCompletionInput.getModelProvider();
        if (provider == Llm.ModelProvider.OPENAI) {
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
            String messages = PromptUtil.getChatCompletionPrompt(
                provider,
                chatCompletionInput.getSystemPrompt(),
                chatCompletionInput.getUserInstructions(),
                chatCompletionInput.getQuestion(),
                chatCompletionInput.getChatHistory(),
                chatCompletionInput.getContexts(),
                chatCompletionInput.getLlmMessages()
            );
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
        } else if (provider == Llm.ModelProvider.BEDROCK
            || provider == Llm.ModelProvider.COHERE
            || chatCompletionInput.getLlmResponseField() != null) {
            // This check precedes BEDROCK_CONVERSE on purpose: an explicit response field always
            // selects the single-string prompt, whatever the provider is.
            inputParameters.put(
                CONNECTOR_INPUT_PARAMETER_INPUTS,
                PromptUtil.buildSingleStringPrompt(
                    chatCompletionInput.getSystemPrompt(),
                    chatCompletionInput.getUserInstructions(),
                    chatCompletionInput.getQuestion(),
                    chatCompletionInput.getChatHistory(),
                    chatCompletionInput.getContexts()
                )
            );
        } else if (provider == Llm.ModelProvider.BEDROCK_CONVERSE) {
            String messages = PromptUtil.getChatCompletionPrompt(
                provider,
                null,
                chatCompletionInput.getUserInstructions(),
                chatCompletionInput.getQuestion(),
                chatCompletionInput.getChatHistory(),
                chatCompletionInput.getContexts(),
                chatCompletionInput.getLlmMessages()
            );
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
        } else {
            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider + ".  You must provide a valid model provider or llm_response_field.");
        }
        return inputParameters;
    }

    /**
     * Extracts answers or error messages from the response map of a remote model.
     *
     * <p>If {@code llmResponseField} is non-null it names the answer field and the provider is
     * ignored; otherwise the field layout is chosen by {@code provider}.
     *
     * @param provider model provider; consulted only when {@code llmResponseField} is {@code null}
     * @param dataAsMap response map taken from the model tensor
     * @param llmResponseField name of the answer field, or {@code null}
     * @return output holding the collected answers and errors
     * @throws IllegalArgumentException if {@code llmResponseField} is {@code null} and the provider
     *         is not one of OPENAI, BEDROCK, COHERE, BEDROCK_CONVERSE
     * @throws RuntimeException if a BEDROCK_CONVERSE response carries neither a message nor an error
     */
    protected ChatCompletionOutput buildChatCompletionOutput(Llm.ModelProvider provider, Map<String, ?> dataAsMap, String llmResponseField) {
        List<Object> answers = new ArrayList<>();
        List<String> errors = new ArrayList<>();
        if (llmResponseField != null) {
            this.fillAnswersOrErrors(dataAsMap, answers, errors, llmResponseField, CONNECTOR_OUTPUT_ERROR, CONNECTOR_OUTPUT_MESSAGE);
        } else if (provider == Llm.ModelProvider.OPENAI) {
            this.fillOpenAiAnswersOrErrors(dataAsMap, answers, errors);
        } else if (provider == Llm.ModelProvider.BEDROCK) {
            this.fillAnswersOrErrors(dataAsMap, answers, errors, BEDROCK_COMPLETION_FIELD, CONNECTOR_OUTPUT_ERROR, CONNECTOR_OUTPUT_MESSAGE);
        } else if (provider == Llm.ModelProvider.COHERE) {
            this.fillAnswersOrErrors(dataAsMap, answers, errors, COHERE_TEXT_FIELD, CONNECTOR_OUTPUT_ERROR, CONNECTOR_OUTPUT_MESSAGE);
        } else if (provider == Llm.ModelProvider.BEDROCK_CONVERSE) {
            this.fillBedrockConverseAnswersOrErrors(dataAsMap, answers, errors);
        } else {
            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider + ".  You must provide a valid model provider or llm_response_field.");
        }
        return new ChatCompletionOutput(answers, errors);
    }

    /**
     * Reads {@code choices[0].message.content} as the answer. When there are no choices, or the
     * first choice has no message content, the {@code error.message} entry (or a generic
     * message) is recorded as an error instead of failing with a NullPointerException.
     */
    private void fillOpenAiAnswersOrErrors(Map<String, ?> dataAsMap, List<Object> answers, List<String> errors) {
        List<?> choices = (List<?>) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
        if (choices != null && !choices.isEmpty()) {
            Map<?, ?> firstChoice = (Map<?, ?>) choices.get(0);
            Map<?, ?> message = firstChoice == null ? null : (Map<?, ?>) firstChoice.get(CONNECTOR_OUTPUT_MESSAGE);
            Object content = message == null ? null : message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT);
            if (content != null) {
                answers.add(content);
                return;
            }
        }
        this.addErrorOrUnknown((Map<?, ?>) dataAsMap.get(CONNECTOR_OUTPUT_ERROR), CONNECTOR_OUTPUT_MESSAGE, errors, dataAsMap);
    }

    /**
     * Reads {@code output.message.content[0].text} as the answer, or {@code output.error.message}
     * as the error when {@code output} has no message. If {@code output} itself is absent, a
     * top-level {@code error.message} is used when available.
     *
     * @throws RuntimeException if no message and no error can be found
     */
    private void fillBedrockConverseAnswersOrErrors(Map<String, ?> dataAsMap, List<Object> answers, List<String> errors) {
        Map<?, ?> output = (Map<?, ?>) dataAsMap.get(BEDROCK_CONVERSE_OUTPUT);
        if (output == null) {
            // Without this guard a response lacking "output" surfaced as a bare NullPointerException.
            Map<?, ?> topLevelError = (Map<?, ?>) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
            Object topLevelMessage = topLevelError == null ? null : topLevelError.get(CONNECTOR_OUTPUT_MESSAGE);
            if (topLevelMessage == null) {
                throw new RuntimeException("Unexpected output: " + dataAsMap);
            }
            errors.add((String) topLevelMessage);
            return;
        }
        Map<?, ?> message = (Map<?, ?>) output.get(CONNECTOR_OUTPUT_MESSAGE);
        if (message != null) {
            List<?> content = (List<?>) message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT);
            String answer = (String) ((Map<?, ?>) content.get(0)).get(BEDROCK_CONVERSE_TEXT);
            answers.add(answer);
            return;
        }
        Map<?, ?> error = (Map<?, ?>) output.get(CONNECTOR_OUTPUT_ERROR);
        if (error == null) {
            throw new RuntimeException("Unexpected output: " + output);
        }
        errors.add((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
    }

    /**
     * Adds the value of {@code answerField} to {@code answers} when present; otherwise records
     * the message found under {@code errorField}, or a generic message when there is none.
     */
    private void fillAnswersOrErrors(Map<String, ?> dataAsMap, List<Object> answers, List<String> errors, String answerField, String errorField, String defaultErrorMessageField) {
        String response = (String) dataAsMap.get(answerField);
        if (response != null) {
            answers.add(response);
            return;
        }
        this.addErrorOrUnknown((Map<?, ?>) dataAsMap.get(errorField), defaultErrorMessageField, errors, dataAsMap);
    }

    /**
     * Appends {@code error[messageField]} to {@code errors}; if the error map or its message is
     * missing, appends a generic message and logs the whole response map.
     */
    private void addErrorOrUnknown(Map<?, ?> error, String messageField, List<String> errors, Map<String, ?> dataAsMap) {
        Object message = error == null ? null : error.get(messageField);
        if (message != null) {
            errors.add((String) message);
        } else {
            errors.add(UNKNOWN_ERROR_MESSAGE);
            log.error("{}", dataAsMap);
        }
    }
}

