/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer.memory;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.action.memorycontainer.memory.MemoryInfo;
import org.opensearch.ml.common.memorycontainer.MLMemory;
import org.opensearch.ml.common.memorycontainer.MemoryDecision;
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
import org.opensearch.ml.common.memorycontainer.MemoryType;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryEvent;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryResult;
import org.opensearch.ml.helper.MemoryEmbeddingHelper;
import org.opensearch.transport.client.Client;

public class MemoryOperationsService {
    @Generated
    private static final Logger log = LogManager.getLogger(MemoryOperationsService.class);
    private final Client client;
    private final MemoryEmbeddingHelper memoryEmbeddingHelper;

    public MemoryOperationsService(Client client, MemoryEmbeddingHelper memoryEmbeddingHelper) {
        this.client = client;
        this.memoryEmbeddingHelper = memoryEmbeddingHelper;
    }

    public void executeMemoryOperations(List<MemoryDecision> decisions, String indexName, String sessionId, User user, MLAddMemoriesInput input, MemoryStorageConfig storageConfig, ActionListener<List<MemoryResult>> listener) {
        ArrayList<MemoryResult> results = new ArrayList<MemoryResult>();
        ArrayList<IndexRequest> addRequests = new ArrayList<IndexRequest>();
        ArrayList<UpdateRequest> updateRequests = new ArrayList<UpdateRequest>();
        ArrayList<DeleteRequest> deleteRequests = new ArrayList<DeleteRequest>();
        Instant now = Instant.now();
        for (MemoryDecision decision : decisions) {
            switch (decision.getEvent()) {
                case ADD: {
                    MLMemory mLMemory = MLMemory.builder().sessionId(sessionId).memory(decision.getText()).memoryType(MemoryType.FACT).userId(user != null ? user.getName() : null).agentId(input.getAgentId()).tags(input.getTags()).createdTime(now).lastUpdatedTime(now).build();
                    IndexRequest addRequest = new IndexRequest(indexName).source(mLMemory.toIndexMap());
                    addRequests.add(addRequest);
                    results.add(MemoryResult.builder().memoryId(null).memory(decision.getText()).event(MemoryEvent.ADD).oldMemory(null).build());
                    break;
                }
                case UPDATE: {
                    HashMap<String, Object> updateDoc = new HashMap<String, Object>();
                    updateDoc.put("memory", decision.getText());
                    updateDoc.put("last_updated_time", now.toEpochMilli());
                    UpdateRequest updateRequest = new UpdateRequest(indexName, decision.getId()).doc(updateDoc);
                    updateRequests.add(updateRequest);
                    results.add(MemoryResult.builder().memoryId(decision.getId()).memory(decision.getText()).event(MemoryEvent.UPDATE).oldMemory(decision.getOldMemory()).build());
                    break;
                }
                case DELETE: {
                    DeleteRequest deleteRequest = new DeleteRequest(indexName, decision.getId());
                    deleteRequests.add(deleteRequest);
                    results.add(MemoryResult.builder().memoryId(decision.getId()).memory(decision.getText()).event(MemoryEvent.DELETE).oldMemory(null).build());
                    break;
                }
            }
        }
        BulkRequest bulkRequest = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        for (IndexRequest indexRequest : addRequests) {
            bulkRequest.add(indexRequest);
        }
        for (UpdateRequest updateRequest : updateRequests) {
            bulkRequest.add(updateRequest);
        }
        for (DeleteRequest deleteRequest : deleteRequests) {
            bulkRequest.add(deleteRequest);
        }
        if (bulkRequest.requests().isEmpty()) {
            log.debug("No memory operations to execute");
            listener.onResponse(results);
            return;
        }
        this.client.bulk(bulkRequest, ActionListener.wrap(bulkResponse -> {
            if (bulkResponse.hasFailures()) {
                log.error("Bulk memory operations had failures: {}", (Object)bulkResponse.buildFailureMessage());
            }
            log.debug("Executed {} memory operations successfully", (Object)bulkResponse.getItems().length);
            BulkItemResponse[] items = bulkResponse.getItems();
            int itemIndex = 0;
            for (int i = 0; i < results.size(); ++i) {
                MemoryResult result = (MemoryResult)results.get(i);
                if (result.getEvent() != MemoryEvent.ADD || itemIndex >= items.length) continue;
                while (itemIndex < items.length && items[itemIndex].getOpType() != DocWriteRequest.OpType.INDEX) {
                    ++itemIndex;
                }
                if (itemIndex < items.length && !items[itemIndex].isFailed()) {
                    String actualId = items[itemIndex].getId();
                    results.set(i, MemoryResult.builder().memoryId(actualId).memory(result.getMemory()).event(MemoryEvent.ADD).oldMemory(null).build());
                }
                ++itemIndex;
            }
            if (storageConfig != null && storageConfig.isSemanticStorageEnabled()) {
                this.updateEmbeddingsForOperations(results, indexName, storageConfig, listener);
            } else {
                listener.onResponse((Object)results);
            }
        }, e -> {
            log.error("Failed to execute memory operations", (Throwable)e);
            listener.onFailure(e);
        }));
    }

    public void bulkIndexMemoriesWithResults(List<IndexRequest> indexRequests, List<MemoryInfo> memoryInfos, String sessionId, String indexName, ActionListener<MLAddMemoriesResponse> actionListener) {
        if (indexRequests.isEmpty()) {
            log.warn("No memories to index");
            actionListener.onFailure((Exception)new IllegalStateException("No memories to index"));
            return;
        }
        this.indexMemoriesSequentiallyWithResults(indexRequests, memoryInfos, 0, sessionId, indexName, new ArrayList<MemoryResult>(), actionListener);
    }

    public void createFactMemoriesFromList(List<String> facts, MLAddMemoriesInput input, String indexName, String sessionId, User user, Instant now, List<IndexRequest> indexRequests, List<MemoryInfo> memoryInfos) {
        for (String fact : facts) {
            MLMemory factMemory = MLMemory.builder().sessionId(sessionId).memory(fact).memoryType(MemoryType.FACT).userId(user != null ? user.getName() : null).agentId(input.getAgentId()).role("assistant").tags(input.getTags()).createdTime(now).lastUpdatedTime(now).build();
            IndexRequest request = new IndexRequest(indexName).source(factMemory.toIndexMap());
            indexRequests.add(request);
            memoryInfos.add(new MemoryInfo(null, factMemory.getMemory(), factMemory.getMemoryType(), true));
        }
    }

    private void indexMemoriesSequentiallyWithResults(List<IndexRequest> indexRequests, List<MemoryInfo> memoryInfos, int currentIndex, String sessionId, String indexName, List<MemoryResult> results, ActionListener<MLAddMemoriesResponse> actionListener) {
        if (currentIndex >= indexRequests.size()) {
            log.debug("Successfully indexed {} memories in index {}", (Object)indexRequests.size(), (Object)indexName);
            MLAddMemoriesResponse response = MLAddMemoriesResponse.builder().results(results).sessionId(sessionId).build();
            actionListener.onResponse((Object)response);
            return;
        }
        IndexRequest currentRequest = indexRequests.get(currentIndex);
        this.client.index(currentRequest, ActionListener.wrap(indexResponse -> {
            String memoryId = indexResponse.getId();
            MemoryInfo info = (MemoryInfo)memoryInfos.get(currentIndex);
            info.setMemoryId(memoryId);
            if (info.isIncludeInResponse()) {
                results.add(MemoryResult.builder().memoryId(memoryId).memory(info.getContent()).event(MemoryEvent.ADD).oldMemory(null).build());
            }
            this.indexMemoriesSequentiallyWithResults(indexRequests, memoryInfos, currentIndex + 1, sessionId, indexName, results, actionListener);
        }, arg_0 -> actionListener.onFailure(arg_0)));
    }

    private void updateEmbeddingsForOperations(List<MemoryResult> results, String indexName, MemoryStorageConfig storageConfig, ActionListener<List<MemoryResult>> listener) {
        ArrayList<String> textsToEmbed = new ArrayList<String>();
        ArrayList<String> memoryIdsToUpdate = new ArrayList<String>();
        for (MemoryResult result : results) {
            if (result.getEvent() != MemoryEvent.ADD && result.getEvent() != MemoryEvent.UPDATE || result.getMemoryId() == null) continue;
            textsToEmbed.add(result.getMemory());
            memoryIdsToUpdate.add(result.getMemoryId());
        }
        if (!textsToEmbed.isEmpty()) {
            this.memoryEmbeddingHelper.generateEmbeddingsForMultipleTexts(textsToEmbed, storageConfig, (ActionListener<List<Object>>)ActionListener.wrap(embeddings -> {
                ArrayList<UpdateRequest> embeddingUpdates = new ArrayList<UpdateRequest>();
                for (int i = 0; i < memoryIdsToUpdate.size() && i < embeddings.size(); ++i) {
                    HashMap embeddingUpdate = new HashMap();
                    embeddingUpdate.put("memory_embedding", embeddings.get(i));
                    UpdateRequest updateRequest = new UpdateRequest(indexName, (String)memoryIdsToUpdate.get(i)).doc(embeddingUpdate);
                    embeddingUpdates.add(updateRequest);
                }
                if (!embeddingUpdates.isEmpty()) {
                    BulkRequest embeddingBulk = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                    for (UpdateRequest request : embeddingUpdates) {
                        embeddingBulk.add(request);
                    }
                    this.client.bulk(embeddingBulk, ActionListener.wrap(embeddingResponse -> {
                        if (embeddingResponse.hasFailures()) {
                            log.error("Failed to update embeddings: {}", (Object)embeddingResponse.buildFailureMessage());
                        }
                        listener.onResponse((Object)results);
                    }, e -> {
                        log.error("Failed to update embeddings", (Throwable)e);
                        listener.onResponse((Object)results);
                    }));
                } else {
                    listener.onResponse((Object)results);
                }
            }, e -> {
                log.error("Failed to generate embeddings for memory operations", (Throwable)e);
                listener.onResponse((Object)results);
            }));
        } else {
            listener.onResponse(results);
        }
    }
}

