/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysml.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

public class BuiltinNarySPInstruction
extends SPInstruction {
    private CPOperand[] inputs;
    private CPOperand output;

    protected BuiltinNarySPInstruction(CPOperand[] in, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.BuiltinNary, opcode, istr);
        this.inputs = in;
        this.output = out;
    }

    public static BuiltinNarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand output = new CPOperand(parts[parts.length - 1]);
        CPOperand[] inputs = null;
        inputs = new CPOperand[parts.length - 2];
        for (int i = 1; i < parts.length - 1; ++i) {
            inputs[i - 1] = new CPOperand(parts[i]);
        }
        return new BuiltinNarySPInstruction(inputs, output, opcode, str);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        JavaPairRDD out = null;
        MatrixCharacteristics mcOut = null;
        if (this.getOpcode().equals("cbind") || this.getOpcode().equals("rbind")) {
            boolean cbind = this.getOpcode().equals("cbind");
            mcOut = BuiltinNarySPInstruction.computeAppendOutputMatrixCharacteristics(sec, this.inputs, cbind);
            MatrixCharacteristics off = new MatrixCharacteristics(0L, 0L, mcOut.getRowsPerBlock(), mcOut.getColsPerBlock(), 0L);
            for (CPOperand input : this.inputs) {
                MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName());
                JavaPairRDD in = sec.getBinaryBlockRDDHandleForVariable(input.getName()).flatMapToPair((PairFlatMapFunction)new AppendGSPInstruction.ShiftMatrix(off, mcIn, cbind)).mapToPair((PairFunction)new PadBlocksFunction(mcOut));
                out = out != null ? out.union(in) : in;
                BuiltinNarySPInstruction.updateAppendMatrixCharacteristics(mcIn, off, cbind);
            }
            int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut);
            out = RDDAggregateUtils.mergeByKey(out, numPartOut, false);
        } else if (this.getOpcode().equals("nmin") || this.getOpcode().equals("nmax")) {
            mcOut = BuiltinNarySPInstruction.computeMinMaxOutputMatrixCharacteristics(sec, this.inputs);
            List<ScalarObject> scalars = sec.getScalarInputs(this.inputs);
            JavaPairRDD in = null;
            for (CPOperand input : this.inputs) {
                if (!input.getDataType().isMatrix()) continue;
                JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec.getBinaryBlockRDDHandleForVariable(input.getName());
                in = in == null ? tmp.mapValues((Function)new MapInputSignature()) : in.join(tmp).mapValues((Function)new MapJoinSignature());
            }
            out = in.mapValues((Function)new MinMaxFunction(this.getOpcode(), scalars));
        }
        sec.getMatrixCharacteristics(this.output.getName()).set(mcOut);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        for (CPOperand input : this.inputs) {
            if (input.isScalar()) continue;
            sec.addLineageRDD(this.output.getName(), input.getName());
        }
    }

    private static MatrixCharacteristics computeAppendOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) {
        MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(inputs[0].getName());
        MatrixCharacteristics mcOut = new MatrixCharacteristics(0L, 0L, mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock(), 0L);
        for (CPOperand input : inputs) {
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName());
            BuiltinNarySPInstruction.updateAppendMatrixCharacteristics(mcIn, mcOut, cbind);
        }
        return mcOut;
    }

    private static void updateAppendMatrixCharacteristics(MatrixCharacteristics in, MatrixCharacteristics out, boolean cbind) {
        out.setDimension(cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows() + in.getRows(), cbind ? out.getCols() + in.getCols() : Math.max(out.getCols(), in.getCols()));
        out.setNonZeros(out.getNonZeros() != -1L && in.dimsKnown(true) ? out.getNonZeros() + in.getNonZeros() : -1L);
    }

    private static MatrixCharacteristics computeMinMaxOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs) {
        MatrixCharacteristics mcOut = new MatrixCharacteristics();
        for (CPOperand input : inputs) {
            if (!input.getDataType().isMatrix()) continue;
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName());
            mcOut.setRows(Math.max(mcOut.getRows(), mcIn.getRows()));
            mcOut.setCols(Math.max(mcOut.getCols(), mcIn.getCols()));
            mcOut.setRowsPerBlock(mcIn.getRowsPerBlock());
            mcOut.setColsPerBlock(mcIn.getColsPerBlock());
        }
        return mcOut;
    }

    private static class MinMaxFunction
    implements Function<MatrixBlock[], MatrixBlock> {
        private static final long serialVersionUID = -4227447915387484397L;
        private final SimpleOperator _op;
        private final ScalarObject[] _scalars;

        public MinMaxFunction(String opcode, List<ScalarObject> scalars) {
            this._scalars = scalars.toArray(new ScalarObject[0]);
            this._op = new SimpleOperator(Builtin.getBuiltinFnObject(opcode.substring(1)));
        }

        public MatrixBlock call(MatrixBlock[] v1) throws Exception {
            return MatrixBlock.naryOperations(this._op, v1, this._scalars, new MatrixBlock());
        }
    }

    public static class PadBlocksFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1291358959908299855L;
        private final MatrixCharacteristics _mcOut;

        public PadBlocksFunction(MatrixCharacteristics mcOut) {
            this._mcOut = mcOut;
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = (MatrixIndexes)arg0._1();
            MatrixBlock mb = (MatrixBlock)arg0._2();
            int brlen = UtilFunctions.computeBlockSize(this._mcOut.getRows(), ix.getRowIndex(), this._mcOut.getRowsPerBlock());
            int bclen = UtilFunctions.computeBlockSize(this._mcOut.getCols(), ix.getColumnIndex(), this._mcOut.getColsPerBlock());
            if (brlen == mb.getNumRows() && bclen == mb.getNumColumns()) {
                return arg0;
            }
            if (brlen > mb.getNumRows()) {
                mb = mb.append(new MatrixBlock(brlen - mb.getNumRows(), bclen, true), new MatrixBlock(), false);
            } else if (bclen > mb.getNumColumns()) {
                mb = mb.append(new MatrixBlock(brlen, bclen - mb.getNumColumns(), true), new MatrixBlock(), true);
            }
            return new Tuple2((Object)ix, (Object)mb);
        }
    }
}

