package org.apache.sysml.hops.codegen.cplan;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/hops/codegen/cplan/CNodeNary.class */
/**
 * Code-generation plan node for n-ary operations, i.e. operations over an
 * arbitrary number of input nodes.
 *
 * <p>The only operation currently defined is {@link NaryType#VECT_CBIND}, whose
 * template allocates a temporary vector and emits one
 * {@code LibSpoofPrimitives.vectWrite} call per input, each writing at a running
 * column offset.
 */
public class CNodeNary extends CNode {
    private final NaryType _type;

    /* loaded from: input_file:org/apache/sysml/hops/codegen/cplan/CNodeNary$NaryType.class */
    /** Supported n-ary operation types together with their code templates. */
    public enum NaryType {
        VECT_CBIND;

        /**
         * Checks whether the given string is the name of one of the n-ary types.
         *
         * @param value candidate type name; {@code null} yields {@code false}
         * @return {@code true} if some constant's {@code name()} equals {@code value}
         */
        public static boolean contains(String value) {
            for (NaryType type : values()) {
                if (type.name().equals(value)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Builds the code template for this operation. The returned text contains
         * the placeholder {@code %TMP%} for the output variable, which the caller
         * is expected to substitute.
         *
         * @param sparse whether sparse-aware code is requested; only then may an
         *               input be written via the {@code (vals, ix)} form of
         *               {@code vectWrite}
         * @param len    length passed to {@code LibSpoofPrimitives.allocVector}
         * @param inputs input nodes, written in list order
         * @return the generated template text
         * @throws RuntimeException if this type has no template
         */
        public String getTemplate(boolean sparse, long len, ArrayList<CNode> inputs) {
            switch (this) {
                case VECT_CBIND:
                    StringBuilder sb = new StringBuilder();
                    sb.append("    double[] %TMP% = LibSpoofPrimitives.allocVector(")
                        .append(len)
                        .append(", true); //nary cbind\n");
                    // Running output offset: each input is written directly
                    // behind the previous one.
                    int offset = 0;
                    for (CNode input : inputs) {
                        appendVectWrite(sb, sparse, input, offset);
                        offset = (int) (offset + input._cols);
                    }
                    return sb.toString();
                default:
                    throw new RuntimeException("Invalid nary type: " + toString());
            }
        }

        /**
         * Reports whether this operation is a vector primitive.
         *
         * @return {@code true} for {@link #VECT_CBIND}
         */
        public boolean isVectorPrimitive() {
            return this == VECT_CBIND;
        }

        /** Appends the single {@code vectWrite} call that copies {@code input} to {@code offset}. */
        private static void appendVectWrite(StringBuilder sb, boolean sparse, CNode input, int offset) {
            String varname = input.getVarname();
            // NOTE(review): a GPU timer constant is used here as a variable-name
            // prefix; this looks like a decompiler constant substitution for a
            // plain string literal. Confirm the intended prefix before changing it.
            boolean sparseInput = sparse && (input instanceof CNodeData)
                && varname.startsWith(GPUInstruction.MISC_TIMER_ALLOCATE);
            String pos = positionExpr(input, varname);

            sb.append("    LibSpoofPrimitives.vectWrite(");
            if (sparseInput) {
                // values array plus index array
                sb.append(varname).append("vals, %TMP%, ").append(varname).append("ix, ");
            } else {
                // "b"-prefixed variables are accessed through values(rix)
                sb.append(varname.startsWith("b") ? varname + ".values(rix)" : varname)
                    .append(", %TMP%, ");
            }
            sb.append(pos).append(", ")
                .append(offset).append(", ")
                .append(input._cols).append(");\n");
        }

        /** Returns the source-position expression used when reading {@code input}. */
        private static String positionExpr(CNode input, String varname) {
            if (!(input instanceof CNodeData) || !input.getDataType().isMatrix()) {
                return "0";
            }
            if (!varname.startsWith("b")) {
                return varname + "i";
            }
            return TemplateUtils.isMatrix(input) ? varname + ".pos(rix)" : "0";
        }
    }

    /**
     * Creates an n-ary node over the given inputs and derives its output
     * dimensions from them.
     *
     * @param inputs input nodes, kept in the given order
     * @param type   the n-ary operation
     */
    public CNodeNary(CNode[] inputs, NaryType type) {
        for (CNode input : inputs) {
            _inputs.add(input);
        }
        _type = type;
        setOutputDims();
    }

    /**
     * @return the n-ary operation type of this node
     */
    public NaryType getType() {
        return _type;
    }

    /**
     * Generates the code of all inputs followed by the code of this node.
     *
     * @param sparse forwarded to the inputs and to the operation template
     * @return the generated code, or an empty string if this node was already generated
     */
    @Override // org.apache.sysml.hops.codegen.cplan.CNode
    public String codegen(boolean sparse) {
        if (isGenerated()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (CNode input : _inputs) {
            sb.append(input.codegen(sparse));
        }
        // Build the template first, then bind the output variable name.
        String template = _type.getTemplate(sparse, _cols, _inputs);
        sb.append(template.replace("%TMP%", createVarname()));
        _generated = true;
        return sb.toString();
    }

    /**
     * @return a short label for this node, {@code "n(cbind)"} for {@link NaryType#VECT_CBIND}
     */
    @Override
    public String toString() {
        switch (_type) {
            case VECT_CBIND:
                return "n(cbind)";
            default:
                return "m(" + _type.name().toLowerCase() + ")";
        }
    }

    /**
     * Derives the output dimensions from the inputs: for
     * {@link NaryType#VECT_CBIND} the row count of the first input and the sum
     * of all input column counts, with a matrix data type.
     */
    @Override // org.apache.sysml.hops.codegen.cplan.CNode
    public void setOutputDims() {
        switch (_type) {
            case VECT_CBIND:
                _rows = _inputs.get(0)._rows;
                _cols = 0L;
                for (CNode input : _inputs) {
                    _cols += input._cols;
                }
                _dataType = Expression.DataType.MATRIX;
                return;
            default:
                return;
        }
    }

    /**
     * Combines the superclass hash with the operation type. The result is cached
     * in {@code _hash}; a cached value of 0 is treated as "not yet computed".
     */
    @Override // org.apache.sysml.hops.codegen.cplan.CNode
    public int hashCode() {
        if (_hash == 0) {
            _hash = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
        }
        return _hash;
    }

    /**
     * Two n-ary nodes are equal if the superclass considers them equal and they
     * have the same operation type.
     */
    @Override // org.apache.sysml.hops.codegen.cplan.CNode
    public boolean equals(Object obj) {
        if (!(obj instanceof CNodeNary)) {
            return false;
        }
        CNodeNary that = (CNodeNary) obj;
        return super.equals(that) && _type == that._type;
    }
}
