/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.StatementBlock;

/**
 * LOP-level rewrite that inserts an asynchronous {@code BROADCAST} operator on top of every
 * CP-executed matrix LOP that at least one of its consumers reads as its broadcast input.
 *
 * <p>The rewrite mutates the LOP DAG of the given statement block in place; the statement block
 * object itself is returned unchanged.
 */
public class RewriteAddBroadcastLop
extends LopRewriteRule {
    /**
     * Inserts broadcast LOPs into the LOP DAG of a single statement block.
     *
     * <p>Does nothing when broadcasting is disabled in the configuration or when
     * {@link OperatorOrderingUtils#getLopList(StatementBlock)} yields {@code null}. Otherwise, for
     * every listed LOP for which {@link #isBroadcastNeeded(Lop)} holds, a {@link UnaryCP} with
     * {@code OpOp1.BROADCAST} (CP exec type, marked asynchronous) is created on top of it, and every
     * existing output of that LOP is rewired to consume the broadcast LOP instead.
     *
     * @param sb the statement block whose LOP DAG is rewritten
     * @return an unmodifiable single-element list containing {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (!ConfigurationManager.isBroadcastEnabled()) {
            return List.of(sb);
        }
        ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
        if (lops == null) {
            return List.of(sb);
        }
        for (Lop l : lops) {
            if (!isBroadcastNeeded(l)) {
                continue;
            }
            // Snapshot the outputs before creating the broadcast LOP: the loop below removes
            // entries from l's output list (and the UnaryCP constructor may register bc as an
            // output of l — NOTE(review): confirm; the snapshot must stay before the constructor).
            List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
            UnaryCP bc = new UnaryCP(l, Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP);
            bc.setAsynchronous(true);
            // NOTE(review): every consumer is rewired, including ones that do not read l as their
            // broadcast input; restricting this to broadcast consumers may be desirable.
            for (Lop outCP : oldOuts) {
                // Rewire l -> outCP into l -> bc -> outCP.
                bc.addOutput(outCP);
                outCP.replaceInput(l, bc);
                l.removeOutput(outCP);
            }
        }
        // The new LOPs live in the DAG itself; the statement block is returned as-is.
        return List.of(sb);
    }

    /**
     * Multi-block variant of this rule; this rule performs no cross-block rewrites.
     *
     * @param sbs the statement blocks
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Decides whether {@code lop} should be followed by a broadcast LOP.
     *
     * @param lop the candidate LOP
     * @return {@code true} iff {@code lop} executes in CP, produces a matrix, and at least one of
     *     its outputs reports {@code lop} as its broadcast input
     */
    private static boolean isBroadcastNeeded(Lop lop) {
        // Evaluate the cheap type predicates first so the output scan only runs for CP matrices.
        if (lop.getExecType() != Types.ExecType.CP || lop.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        return lop.getOutputs().stream().anyMatch(out -> out.getBroadcastInput() == lop);
    }
}

