package org.apache.sysml.yarn.ropt;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;

/* loaded from: input_file:org/apache/sysml/yarn/ropt/GridEnumerationMemory.class */
/**
 * Grid enumeration over a memory range whose candidate points are derived from the memory
 * estimates of the hops in a program.
 *
 * <p>The range {@code [_min, _max]} is divided into bins of width
 * {@code (_max - _min) / (_nsteps - 1)} (integer division, at least 1). Every hop estimate inside
 * the range contributes the grid points directly below and above it. Estimates outside the range
 * are clamped to the nearest bound.
 */
public class GridEnumerationMemory extends GridEnumeration {
    /** Default number of grid steps used to derive the bin width. */
    public static final int DEFAULT_NSTEPS = 20;

    /**
     * Constant added to every hop memory estimate before scaling. The value is 2^20, presumably
     * 1 MB if estimates are in bytes; TODO confirm the unit of {@code Hop#getMemEstimate()}.
     */
    public static final int DEFAULT_MEM_ADD = 1024 * 1024;

    /** Number of grid steps; see {@link #setNumSteps(int)}. */
    private int _nsteps;

    /**
     * Creates an enumerator for the given program and memory range, using
     * {@link #DEFAULT_NSTEPS} grid steps.
     *
     * @param prog program blocks whose hop memory estimates seed the grid
     * @param min lower bound of the enumerated range
     * @param max upper bound of the enumerated range
     * @throws DMLRuntimeException propagated from the {@link GridEnumeration} constructor
     */
    public GridEnumerationMemory(ArrayList<ProgramBlock> prog, long min, long max)
            throws DMLRuntimeException {
        super(prog, min, max);
        _nsteps = DEFAULT_NSTEPS;
    }

    /**
     * Sets the number of grid steps. The bin width is {@code (_max - _min) / (steps - 1)}.
     * {@link #enumerateGridPoints()} treats values below 2 as a single interval.
     *
     * @param steps number of grid steps
     */
    public void setNumSteps(int steps) {
        _nsteps = steps;
    }

    /**
     * Enumerates candidate memory values by snapping every hop memory estimate of the program
     * to its neighbouring grid points.
     *
     * @return ascending list of distinct grid points, each within {@code [_min, _max]}
     *         (assuming {@code _min <= _max})
     * @throws DMLRuntimeException declared by the overridden base-class method
     * @throws HopsException propagated from the hop traversal that collects estimates
     */
    @Override
    public ArrayList<Long> enumerateGridPoints() throws DMLRuntimeException, HopsException {
        // Guard the bin width. With _nsteps <= 1 the old code divided by zero or by a negative
        // number. A range narrower than _nsteps - 1 gave gap == 0, and the binning division
        // below then threw ArithmeticException.
        int intervals = Math.max(_nsteps - 1, 1);
        long gap = Math.max((_max - _min) / intervals, 1L);

        ArrayList<Long> estimates = new ArrayList<>();
        getMemoryEstimates(_prog, estimates);

        // Snap each estimate to the surrounding grid points; the set drops duplicates.
        HashSet<Long> points = new HashSet<>();
        for (long est : estimates) {
            if (est < _min) {
                points.add(_min);
            } else if (est > _max) {
                points.add(_max);
            } else {
                long lower = Math.max((est - _min) / gap, 0L);
                points.add(filterMax(_min + lower * gap));
                // The upper neighbour may overshoot the last grid point; clamp it to _max.
                points.add(filterMax(_min + (lower + 1) * gap));
            }
        }

        ArrayList<Long> ret = new ArrayList<>(points);
        Collections.sort(ret);
        return ret;
    }

    /** Returns {@code value}, clamped so it is at most {@code _max}. */
    private long filterMax(long value) {
        return Math.min(value, _max);
    }

    /**
     * Collects hop memory estimates from every block in {@code blocks} (recursively) into
     * {@code mem}. A {@code null} list is treated as empty.
     */
    private void getMemoryEstimates(ArrayList<ProgramBlock> blocks, ArrayList<Long> mem)
            throws HopsException {
        if (blocks == null) {
            return; // defensive: tolerate a missing child-block list
        }
        for (ProgramBlock pb : blocks) {
            getMemoryEstimates(pb, mem);
        }
    }

    /**
     * Collects hop memory estimates for a single program block.
     *
     * <p>Control-flow blocks (function, while, if, for) are descended into via their child
     * blocks only; their own predicate hops are not visited here. For any other block, the hop
     * DAG of its statement block (if present) is traversed.
     */
    private void getMemoryEstimates(ProgramBlock pb, ArrayList<Long> mem) throws HopsException {
        if (pb instanceof FunctionProgramBlock) {
            getMemoryEstimates(((FunctionProgramBlock) pb).getChildBlocks(), mem);
        } else if (pb instanceof WhileProgramBlock) {
            getMemoryEstimates(((WhileProgramBlock) pb).getChildBlocks(), mem);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            getMemoryEstimates(ipb.getChildBlocksIfBody(), mem);
            getMemoryEstimates(ipb.getChildBlocksElseBody(), mem);
        } else if (pb instanceof ForProgramBlock) {
            // instanceof also matches any subclass of ForProgramBlock.
            getMemoryEstimates(((ForProgramBlock) pb).getChildBlocks(), mem);
        } else {
            StatementBlock sb = pb.getStatementBlock();
            if (sb != null && sb.getHops() != null) {
                // Clear visit flags so shared sub-DAGs are counted exactly once per block.
                Hop.resetVisitStatus(sb.getHops());
                for (Hop hop : sb.getHops()) {
                    getMemoryEstimates(hop, mem);
                }
            }
        }
    }

    /**
     * Post-order traversal of a hop DAG. Adds
     * {@code (memEstimate + DEFAULT_MEM_ADD) / OptimizerUtils.MEM_UTIL_FACTOR} for each
     * not-yet-visited hop, after its inputs have been processed.
     */
    private void getMemoryEstimates(Hop hop, ArrayList<Long> mem) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            getMemoryEstimates(input, mem);
        }
        // NOTE(review): MEM_UTIL_FACTOR presumably is the usable fraction of the heap; dividing
        // scales the operator estimate up to a total budget -- confirm against OptimizerUtils.
        mem.add((long) ((hop.getMemEstimate() + DEFAULT_MEM_ADD) / OptimizerUtils.MEM_UTIL_FACTOR));
        hop.setVisited();
    }
}
