/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.sparse_encoding;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

/**
 * Translator for sparse-encoding models.
 *
 * <p>Input handling is inherited from {@link SentenceTransformerTranslator}; the only addition is
 * that an optional {@code sparse_embedding_format} request value is carried over to output
 * processing via a translator-context attachment. Each output {@link NDArray} is converted into a
 * map containing only its non-zero entries, keyed either by decoded token text ({@code WORD},
 * the default) or by the numeric token id ({@code TOKEN_ID}).
 */
public class SparseEncodingTranslator
extends SentenceTransformerTranslator {
    /**
     * Request field name, reused as the translator-context attachment key, that selects how
     * sparse weights are keyed in the output. Shared by processInput (writer) and
     * processOutput (reader) so the two cannot drift apart.
     */
    private static final String SPARSE_EMBEDDING_FORMAT_FIELD = "sparse_embedding_format";

    /** Key under which each converted weight map is published in a tensor's data map. */
    private static final String RESPONSE_KEY = "response";

    /**
     * Records the requested sparse embedding format (when the request supplies one) on the
     * translator context, then delegates the actual input processing to the parent translator.
     *
     * @param ctx translator context that carries the format attachment to {@link #processOutput}
     * @param input model input; its {@code sparse_embedding_format} value is optional
     * @return the NDList produced by {@link SentenceTransformerTranslator#processInput}
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) {
        String embeddingFormat = input.getAsString(SPARSE_EMBEDDING_FORMAT_FIELD);
        if (embeddingFormat != null) {
            ctx.setAttachment(SPARSE_EMBEDDING_FORMAT_FIELD, embeddingFormat);
        }
        return super.processInput(ctx, input);
    }

    /**
     * Converts every model output array into a named {@link ModelTensor} whose data map holds a
     * single-element list under {@code "response"}. That element is the token-to-weight map of
     * the array's non-zero entries. All tensors are serialized into one {@code 200 OK} output.
     *
     * @param ctx translator context; its format attachment (if present) selects the key format
     * @param list model outputs, one tensor is emitted per array, named after that array
     * @return an output whose body is the serialized {@link ModelTensors}
     * @throws IllegalArgumentException if the format attachment does not name a
     *     {@link SparseEmbeddingFormat} constant (raised by {@code Enum.valueOf})
     */
    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) {
        SparseEmbeddingFormat embeddingFormat = resolveEmbeddingFormat(ctx);
        List<ModelTensor> outputs = new ArrayList<>(list.size());
        for (NDArray ndArray : list) {
            List<Object> response = Collections.singletonList(convertOutput(ndArray, embeddingFormat));
            Map<String, List<Object>> wrappedMap = Map.of(RESPONSE_KEY, response);
            outputs.add(ModelTensor.builder().name(ndArray.getName()).dataAsMap(wrappedMap).build());
        }
        Output output = new Output(200, "OK");
        output.add(new ModelTensors(outputs).toBytes());
        return output;
    }

    /**
     * Reads the format attachment stored by {@link #processInput}, falling back to
     * {@link SparseEmbeddingFormat#WORD} when no format was requested.
     */
    private static SparseEmbeddingFormat resolveEmbeddingFormat(TranslatorContext ctx) {
        Object attachment = ctx.getAttachment(SPARSE_EMBEDDING_FORMAT_FIELD);
        return attachment == null
            ? SparseEmbeddingFormat.WORD
            : SparseEmbeddingFormat.valueOf(attachment.toString());
    }

    /**
     * Builds a key-to-weight map from the non-zero entries of {@code array}.
     *
     * <p>With {@code TOKEN_ID} the key is the decimal index of the entry. Otherwise the index is
     * decoded with the tokenizer, and entries that decode to an empty string are dropped. If two
     * indices decode to the same text, the one visited later overwrites the earlier weight.
     *
     * <p>NOTE(review): assumes {@code array} is 1-D (the index is used directly as the element
     * position) and FLOAT32 (required by {@code toFloatArray}); both assumptions are inherited
     * from the previous per-element {@code getFloat} implementation — confirm against the model.
     */
    private Map<String, Float> convertOutput(NDArray array, SparseEmbeddingFormat embeddingFormat) {
        long[] indices = array.nonzero().squeeze().toLongArray();
        // Copy the weights out once instead of materializing a scalar NDArray per non-zero entry.
        float[] weights = array.toFloatArray();
        boolean keyByTokenId = embeddingFormat == SparseEmbeddingFormat.TOKEN_ID;
        Map<String, Float> tokenWeights = new HashMap<>();
        for (long index : indices) {
            // The `true` flag presumably skips special tokens during decoding — confirm in tokenizer docs.
            String key = keyByTokenId
                ? String.valueOf(index)
                : this.tokenizer.decode(new long[] { index }, true);
            if (key.isEmpty()) {
                // Only reachable in WORD mode: tokens that decode to nothing carry no usable key.
                continue;
            }
            tokenWeights.put(key, weights[(int) index]);
        }
        return tokenWeights;
    }
}

