/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import java.security.PrivilegedActionException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.memory.Message;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.common.utils.ToolUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.algorithms.agent.MLAgentRunner;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.client.Client;

/**
 * Agent runner that executes an agent's tools sequentially as a "flow", optionally backed by
 * conversation-index memory.
 *
 * <p>Each tool's filtered output is written back into the caller's {@code params} map under the key
 * {@code <toolName>.output}, so later tools in the flow can reference earlier results. When the agent
 * has both an app type and a memory spec, previous interactions of the conversation are loaded and
 * exposed to the tools under {@value #CHAT_HISTORY}; each tool step is then saved as a trace of the
 * parent interaction unless the {@code disable_trace} parameter is {@code true}.
 */
public class MLConversationalFlowAgentRunner
implements MLAgentRunner {
    @Generated
    private static final Logger log = LogManager.getLogger(MLConversationalFlowAgentRunner.class);
    public static final String CHAT_HISTORY = "chat_history";
    private Client client;
    private Settings settings;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    private Map<String, Tool.Factory> toolFactories;
    private Map<String, Memory.Factory> memoryFactoryMap;
    private SdkClient sdkClient;
    private Encryptor encryptor;

    public MLConversationalFlowAgentRunner(Client client, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories, Map<String, Memory.Factory> memoryFactoryMap, SdkClient sdkClient, Encryptor encryptor) {
        this.client = client;
        this.settings = settings;
        this.clusterService = clusterService;
        this.xContentRegistry = xContentRegistry;
        this.toolFactories = toolFactories;
        this.memoryFactoryMap = memoryFactoryMap;
        this.sdkClient = sdkClient;
        this.encryptor = encryptor;
    }

    /**
     * Runs the flow agent.
     *
     * <p>If the agent has no app type or no memory spec, the tools run without memory. Otherwise a
     * conversation-index memory is created/loaded, up to the configured history limit of previous
     * interactions is rendered into {@code params[CHAT_HISTORY]}, and the tools run against that memory.
     *
     * @param mlAgent  agent definition (tools, memory spec, app type, tenant)
     * @param params   request parameters; this map is mutated (chat history and per-tool outputs are added)
     * @param listener receives the list of output {@link ModelTensor}s, or the failure
     * @param channel  unused by this runner
     */
    @Override
    public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, TransportChannel channel) {
        String appType = mlAgent.getAppType();
        String memoryId = params.get("memory_id");
        String parentInteractionId = params.get("parent_interaction_id");
        if (appType == null || mlAgent.getMemory() == null) {
            // No conversation context to load: run the tool chain directly.
            this.runAgent(mlAgent, params, listener, null, memoryId, parentInteractionId);
            return;
        }
        String memoryType = mlAgent.getMemory().getType();
        Memory.Factory memoryFactory = this.memoryFactoryMap.get(memoryType);
        if (!(memoryFactory instanceof ConversationIndexMemory.Factory)) {
            // Previously this surfaced as a synchronous NPE/ClassCastException; report it via the listener instead.
            listener.onFailure(new IllegalArgumentException("Unsupported memory type for conversational flow agent: " + memoryType));
            return;
        }
        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactory;
        String title = params.get("question");
        int messageHistoryLimit = AgentUtils.getMessageHistoryLimit(params);
        conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(
            memory -> memory.getMessages(ActionListener.<List<Interaction>>wrap(interactions -> {
                String chatHistory = buildChatHistory(memory.getConversationId(), interactions);
                if (!chatHistory.isEmpty()) {
                    params.put(CHAT_HISTORY, chatHistory);
                }
                this.runAgent(mlAgent, params, listener, memory, memory.getConversationId(), parentInteractionId);
            }, e -> {
                log.error("Failed to get chat history", e);
                listener.onFailure(e);
            }), messageHistoryLimit),
            listener::onFailure
        ));
    }

    /**
     * Renders previous interactions that have a non-empty response as a chat-history block.
     *
     * @param conversationId id of the conversation the interactions belong to
     * @param interactions   previous interactions; {@code null} is treated as empty
     * @return the formatted history, or an empty string when there is nothing to include
     */
    private static String buildChatHistory(String conversationId, List<Interaction> interactions) {
        List<ConversationIndexMessage> messageList = new ArrayList<>();
        if (interactions != null) {
            for (Interaction interaction : interactions) {
                String response = interaction.getResponse();
                if (Strings.isNullOrEmpty(response)) {
                    // Interactions without an answer carry no usable history.
                    continue;
                }
                messageList.add(ConversationIndexMessage.conversationIndexMessageBuilder()
                    .sessionId(conversationId)
                    .question(interaction.getInput())
                    .response(response)
                    .build());
            }
        }
        if (messageList.isEmpty()) {
            return "";
        }
        StringBuilder chatHistoryBuilder = new StringBuilder();
        chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n");
        for (Message message : messageList) {
            chatHistoryBuilder.append(message.toString()).append("\n");
        }
        return chatHistoryBuilder.toString();
    }

    /**
     * Builds and starts the sequential tool chain.
     *
     * <p>Tool {@code i} completes into a {@link StepListener} whose handler processes that output and
     * then starts tool {@code i + 1}; the handler for the last tool completes {@code listener}.
     *
     * @param memory   conversation memory, or {@code null} when the agent runs without memory
     * @param memoryId id reported back as {@code memory_id} when memory is present
     */
    private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, ConversationIndexMemory memory, String memoryId, String parentInteractionId) {
        List<MLToolSpec> toolSpecs = AgentUtils.getMlToolSpecs(mlAgent, params);
        if (toolSpecs == null || toolSpecs.isEmpty()) {
            listener.onFailure(new IllegalArgumentException("no tool configured"));
            return;
        }
        String tenantId = mlAgent.getTenantId();
        MLMemorySpec memorySpec = mlAgent.getMemory();
        List<ModelTensor> flowAgentOutput = new ArrayList<>();
        Map<String, Object> additionalInfo = new ConcurrentHashMap<>();
        AtomicInteger traceNumber = new AtomicInteger(0);
        if (memory != null) {
            flowAgentOutput.add(ModelTensor.builder().name("memory_id").result(memoryId).build());
            flowAgentOutput.add(ModelTensor.builder().name("parent_message_id").result(parentInteractionId).build());
        }

        MLToolSpec firstToolSpec = toolSpecs.get(0);
        Map<String, String> firstToolExecuteParams = ToolUtils.buildToolParameters(params, firstToolSpec, tenantId);
        Tool firstTool = AgentUtils.createTool(this.toolFactories, firstToolExecuteParams, firstToolSpec);

        if (toolSpecs.size() == 1) {
            // Single tool: its output is also the final output, no chaining needed.
            firstTool.run(firstToolExecuteParams, ActionListener.<Object>wrap(
                output -> this.processOutput(params, listener, memory, memoryId, parentInteractionId, toolSpecs, flowAgentOutput, additionalInfo, traceNumber, memorySpec, firstToolSpec, 1, output, tenantId, null),
                listener::onFailure
            ));
            return;
        }

        StepListener<Object> firstStepListener = new StepListener<>();
        StepListener<Object> previousStepListener = firstStepListener;
        // stepIndex is the 1-based number of the tool whose output the handler receives; when it equals
        // toolSpecs.size() the handler finishes the flow, otherwise it starts tool toolSpecs.get(stepIndex).
        for (int i = 1; i <= toolSpecs.size(); ++i) {
            MLToolSpec previousToolSpec = toolSpecs.get(i - 1);
            StepListener<Object> nextStepListener = new StepListener<>();
            int stepIndex = i;
            previousStepListener.whenComplete(
                output -> this.processOutput(params, listener, memory, memoryId, parentInteractionId, toolSpecs, flowAgentOutput, additionalInfo, traceNumber, memorySpec, previousToolSpec, stepIndex, output, tenantId, nextStepListener),
                e -> {
                    log.error("Failed to run flow agent", e);
                    listener.onFailure(e);
                }
            );
            previousStepListener = nextStepListener;
        }
        firstTool.run(firstToolExecuteParams, firstStepListener);
    }

    /**
     * Handles the output of one tool step.
     *
     * <p>Stores the filtered output in {@code params} as {@code <toolName>.output}, adds it to the agent
     * response when the tool asks for it (or it is the last step), then either saves a trace and starts
     * the next tool, or — on the last step — persists the final response/additional info and completes
     * {@code listener}.
     *
     * @param finalI           1-based index of the tool that produced {@code output}
     * @param nextStepListener listener for the next tool; unused on the last step (may be {@code null})
     */
    private void processOutput(Map<String, String> params, ActionListener<Object> listener, ConversationIndexMemory memory, String memoryId, String parentInteractionId, List<MLToolSpec> toolSpecs, List<ModelTensor> flowAgentOutput, Map<String, Object> additionalInfo, AtomicInteger traceNumber, MLMemorySpec memorySpec, MLToolSpec previousToolSpec, int finalI, Object output, String tenantId, StepListener<Object> nextStepListener) throws PrivilegedActionException {
        String toolName = ToolUtils.getToolName(previousToolSpec);
        String outputKey = toolName + ".output";
        Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, previousToolSpec, tenantId);
        String filteredOutput = ToolUtils.parseResponse(ToolUtils.filterToolOutput(toolParameters, output));
        // Expose this tool's output to the tools that run after it.
        params.put(outputKey, StringUtils.prepareJsonValue(filteredOutput));
        boolean traceDisabled = params.containsKey("disable_trace") && Boolean.parseBoolean(params.get("disable_trace"));
        boolean isLastStep = finalI == toolSpecs.size();

        if (previousToolSpec.isIncludeOutputInAgentResponse() || isLastStep) {
            if (toolParameters.containsKey("output_filter")) {
                flowAgentOutput.add(ModelTensor.builder().name(outputKey).result(filteredOutput).build());
            } else if (output instanceof ModelTensorOutput) {
                ModelTensors firstOutput = (ModelTensors) ((ModelTensorOutput) output).getMlModelOutputs().get(0);
                flowAgentOutput.addAll(firstOutput.getMlModelTensors());
            } else if (Boolean.parseBoolean(toolParameters.getOrDefault("return_data_as_map", "false"))) {
                flowAgentOutput.add(ToolUtils.convertOutputToModelTensor(output, outputKey));
            } else {
                flowAgentOutput.add(ModelTensor.builder().name(toolName).result(StringUtils.toJson(output)).build());
            }
            if (memory == null) {
                // Without a memory object, included outputs are persisted later as the interaction's additional info.
                additionalInfo.put(outputKey, filteredOutput);
            }
        }

        if (!isLastStep) {
            if (memory == null) {
                this.runNextStep(params, toolSpecs, finalI, tenantId, nextStepListener);
            } else {
                this.saveMessage(params, memory, filteredOutput, memoryId, parentInteractionId, toolName, traceNumber, traceDisabled, ActionListener.wrap(
                    r -> this.runNextStep(params, toolSpecs, finalI, tenantId, nextStepListener),
                    e -> {
                        log.error("Failed to update root interaction ", e);
                        listener.onFailure(e);
                    }
                ));
            }
            return;
        }

        // Last step: the agent response is returned whether or not the interaction update succeeds.
        ActionListener<UpdateResponse> updateListener = ActionListener.wrap(r -> {
            log.info("Updated additional info for interaction {} of flow agent.", r.getId());
            listener.onResponse(flowAgentOutput);
        }, e -> {
            log.error("Failed to update root interaction", e);
            listener.onResponse(flowAgentOutput);
        });
        if (memory == null) {
            if (memoryId == null || parentInteractionId == null || memorySpec == null || memorySpec.getType() == null) {
                listener.onResponse(flowAgentOutput);
            } else {
                this.updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener);
            }
        } else {
            this.saveMessage(params, memory, filteredOutput, memoryId, parentInteractionId, toolName, traceNumber, traceDisabled, ActionListener.wrap(r -> {
                log.info("saved last trace for interaction {} of flow agent", parentInteractionId);
                Map<String, Object> updateContent = Map.of("response", filteredOutput, "additional_info", additionalInfo);
                memory.update(parentInteractionId, updateContent, updateListener);
            }, e -> {
                log.error("Failed to update root interaction ", e);
                listener.onFailure(e);
            }));
        }
    }

    /**
     * Creates and runs the tool at {@code finalI} (0-based here, i.e. the tool after the one that just
     * finished), completing {@code nextStepListener}. Does nothing if there is no such tool.
     */
    private void runNextStep(Map<String, String> params, List<MLToolSpec> toolSpecs, int finalI, String tenantId, StepListener<Object> nextStepListener) {
        // Guard before indexing: the old check ran after toolSpecs.get(finalI) and could never prevent it from throwing.
        if (finalI >= toolSpecs.size()) {
            return;
        }
        MLToolSpec toolSpec = toolSpecs.get(finalI);
        Map<String, String> toolExecutionParameters = ToolUtils.buildToolParameters(params, toolSpec, tenantId);
        Tool tool = AgentUtils.createTool(this.toolFactories, toolExecutionParameters, toolSpec);
        tool.run(toolExecutionParameters, nextStepListener);
    }

    /**
     * Saves one tool step as a trace of the parent interaction, or immediately responds with
     * {@code true} when tracing is disabled.
     *
     * @param traceNumber shared counter; incremented once per saved trace
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void saveMessage(Map<String, String> params, ConversationIndexMemory memory, String outputResponse, String memoryId, String parentInteractionId, String toolName, AtomicInteger traceNumber, boolean traceDisabled, ActionListener listener) {
        if (traceDisabled) {
            listener.onResponse(true);
            return;
        }
        ConversationIndexMessage finalMessage = ConversationIndexMessage.conversationIndexMessageBuilder()
            .type(memory.getType())
            .question(params.get("question"))
            .response(outputResponse)
            .finalAnswer(true)
            .sessionId(memoryId)
            .build();
        memory.save(finalMessage, parentInteractionId, traceNumber.incrementAndGet(), toolName, listener);
    }

    /**
     * Writes {@code additionalInfo} onto interaction {@code interactionId} of memory {@code memoryId}.
     *
     * <p>Returns without touching {@code listener} when any identifier or the memory type is missing
     * (callers check this beforehand). Every other outcome, including failure to resolve or create
     * the memory, completes {@code listener}.
     *
     * @param listener expected to be an {@code ActionListener<UpdateResponse>}
     */
    @VisibleForTesting
    @SuppressWarnings({"rawtypes", "unchecked"})
    void updateMemoryWithListener(Map<String, Object> additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId, ActionListener listener) {
        if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) {
            return;
        }
        Memory.Factory memoryFactory = this.memoryFactoryMap.get(memorySpec.getType());
        if (!(memoryFactory instanceof ConversationIndexMemory.Factory)) {
            log.error("No conversation index memory factory for memory type: " + memorySpec.getType());
            listener.onFailure(new IllegalArgumentException("Unsupported memory type for conversational flow agent: " + memorySpec.getType()));
            return;
        }
        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactory;
        conversationIndexMemoryFactory.create(memoryId, ActionListener.<ConversationIndexMemory>wrap(
            memory -> memory.update(interactionId, Map.of("additional_info", additionalInfo), (ActionListener<UpdateResponse>) listener),
            e -> {
                log.error("Failed create memory from id: " + memoryId, e);
                // Previously the listener was never notified here, leaving the agent response hanging.
                listener.onFailure(e);
            }
        ));
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public Settings getSettings() {
        return this.settings;
    }

    @Generated
    public ClusterService getClusterService() {
        return this.clusterService;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Map<String, Tool.Factory> getToolFactories() {
        return this.toolFactories;
    }

    @Generated
    public Map<String, Memory.Factory> getMemoryFactoryMap() {
        return this.memoryFactoryMap;
    }

    @Generated
    public SdkClient getSdkClient() {
        return this.sdkClient;
    }

    @Generated
    public Encryptor getEncryptor() {
        return this.encryptor;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setSettings(Settings settings) {
        this.settings = settings;
    }

    @Generated
    public void setClusterService(ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    @Generated
    public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    @Generated
    public void setToolFactories(Map<String, Tool.Factory> toolFactories) {
        this.toolFactories = toolFactories;
    }

    @Generated
    public void setMemoryFactoryMap(Map<String, Memory.Factory> memoryFactoryMap) {
        this.memoryFactoryMap = memoryFactoryMap;
    }

    @Generated
    public void setSdkClient(SdkClient sdkClient) {
        this.sdkClient = sdkClient;
    }

    @Generated
    public void setEncryptor(Encryptor encryptor) {
        this.encryptor = encryptor;
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLConversationalFlowAgentRunner)) {
            return false;
        }
        MLConversationalFlowAgentRunner other = (MLConversationalFlowAgentRunner)o;
        if (!other.canEqual(this)) {
            return false;
        }
        Client this$client = this.getClient();
        Client other$client = other.getClient();
        if (this$client == null ? other$client != null : !this$client.equals(other$client)) {
            return false;
        }
        Settings this$settings = this.getSettings();
        Settings other$settings = other.getSettings();
        if (this$settings == null ? other$settings != null : !this$settings.equals(other$settings)) {
            return false;
        }
        ClusterService this$clusterService = this.getClusterService();
        ClusterService other$clusterService = other.getClusterService();
        if (this$clusterService == null ? other$clusterService != null : !this$clusterService.equals(other$clusterService)) {
            return false;
        }
        NamedXContentRegistry this$xContentRegistry = this.getXContentRegistry();
        NamedXContentRegistry other$xContentRegistry = other.getXContentRegistry();
        if (this$xContentRegistry == null ? other$xContentRegistry != null : !this$xContentRegistry.equals(other$xContentRegistry)) {
            return false;
        }
        Map<String, Tool.Factory> this$toolFactories = this.getToolFactories();
        Map<String, Tool.Factory> other$toolFactories = other.getToolFactories();
        if (this$toolFactories == null ? other$toolFactories != null : !((Object)this$toolFactories).equals(other$toolFactories)) {
            return false;
        }
        Map<String, Memory.Factory> this$memoryFactoryMap = this.getMemoryFactoryMap();
        Map<String, Memory.Factory> other$memoryFactoryMap = other.getMemoryFactoryMap();
        if (this$memoryFactoryMap == null ? other$memoryFactoryMap != null : !((Object)this$memoryFactoryMap).equals(other$memoryFactoryMap)) {
            return false;
        }
        SdkClient this$sdkClient = this.getSdkClient();
        SdkClient other$sdkClient = other.getSdkClient();
        if (this$sdkClient == null ? other$sdkClient != null : !this$sdkClient.equals(other$sdkClient)) {
            return false;
        }
        Encryptor this$encryptor = this.getEncryptor();
        Encryptor other$encryptor = other.getEncryptor();
        return !(this$encryptor == null ? other$encryptor != null : !this$encryptor.equals(other$encryptor));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLConversationalFlowAgentRunner;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        Client $client = this.getClient();
        result = result * 59 + ($client == null ? 43 : $client.hashCode());
        Settings $settings = this.getSettings();
        result = result * 59 + ($settings == null ? 43 : $settings.hashCode());
        ClusterService $clusterService = this.getClusterService();
        result = result * 59 + ($clusterService == null ? 43 : $clusterService.hashCode());
        NamedXContentRegistry $xContentRegistry = this.getXContentRegistry();
        result = result * 59 + ($xContentRegistry == null ? 43 : $xContentRegistry.hashCode());
        Map<String, Tool.Factory> $toolFactories = this.getToolFactories();
        result = result * 59 + ($toolFactories == null ? 43 : ((Object)$toolFactories).hashCode());
        Map<String, Memory.Factory> $memoryFactoryMap = this.getMemoryFactoryMap();
        result = result * 59 + ($memoryFactoryMap == null ? 43 : ((Object)$memoryFactoryMap).hashCode());
        SdkClient $sdkClient = this.getSdkClient();
        result = result * 59 + ($sdkClient == null ? 43 : $sdkClient.hashCode());
        Encryptor $encryptor = this.getEncryptor();
        result = result * 59 + ($encryptor == null ? 43 : $encryptor.hashCode());
        return result;
    }

    @Generated
    public String toString() {
        return "MLConversationalFlowAgentRunner(client=" + String.valueOf(this.getClient()) + ", settings=" + String.valueOf(this.getSettings()) + ", clusterService=" + String.valueOf(this.getClusterService()) + ", xContentRegistry=" + String.valueOf(this.getXContentRegistry()) + ", toolFactories=" + String.valueOf(this.getToolFactories()) + ", memoryFactoryMap=" + String.valueOf(this.getMemoryFactoryMap()) + ", sdkClient=" + String.valueOf(this.getSdkClient()) + ", encryptor=" + String.valueOf(this.getEncryptor()) + ")";
    }

    @Generated
    public MLConversationalFlowAgentRunner() {
    }
}

