package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLScriptException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.class */
/**
 * Control-program instruction that invokes a DML function ({@link FunctionProgramBlock})
 * identified by namespace and function name.
 *
 * <p>Execution proceeds as follows:
 * <ol>
 *   <li>The caller's arguments are bound into a fresh {@link LocalVariableMap}. A formal
 *       parameter falls back to its default value when no argument was supplied at its
 *       position or when the supplied non-literal variable does not exist.</li>
 *   <li>The function body runs in a new {@link ExecutionContext}.</li>
 *   <li>Every function-local variable that is not a declared output is removed. Matrix
 *       objects among them are cleaned up.</li>
 *   <li>The function outputs are bound to the caller's output variable names.</li>
 * </ol>
 */
public class FunctionCallCPInstruction extends CPInstruction {
    private String _functionName;
    private final String _namespace;
    private final ArrayList<CPOperand> _boundInputParamOperands;
    private final ArrayList<String> _boundInputParamNames;
    private final ArrayList<String> _boundOutputParamNames;

    public String getFunctionName() {
        return _functionName;
    }

    public String getNamespace() {
        return _namespace;
    }

    /**
     * Creates a function call instruction.
     *
     * @param namespace            namespace of the called function
     * @param functName            name of the called function (also used as the opcode)
     * @param boundInParamOperands caller-side operands, positionally matched to formal inputs
     * @param boundInParamNames    names of {@code boundInParamOperands}, in the same order
     * @param boundOutParamNames   caller-side variable names receiving the function outputs
     * @param istr                 the serialized instruction string
     */
    public FunctionCallCPInstruction(String namespace, String functName,
            ArrayList<CPOperand> boundInParamOperands, ArrayList<String> boundInParamNames,
            ArrayList<String> boundOutParamNames, String istr) {
        super(null, functName, istr);
        this._cptype = CPInstruction.CPINSTRUCTION_TYPE.External;
        this._functionName = functName;
        this._namespace = namespace;
        this._boundInputParamOperands = boundInParamOperands;
        this._boundInputParamNames = boundInParamNames;
        this._boundOutputParamNames = boundOutParamNames;
    }

    /**
     * Parses a serialized function call instruction.
     *
     * <p>After value-type parsing, the expected layout is:
     * <ul>
     *   <li>{@code parts[1]}: namespace</li>
     *   <li>{@code parts[2]}: function name</li>
     *   <li>{@code parts[3]}: number of inputs {@code n}</li>
     *   <li>{@code parts[4]}: number of outputs {@code m}</li>
     *   <li>then {@code n} input operand strings, followed by {@code m} output variable names</li>
     * </ul>
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the instruction has fewer fields than its header announces
     */
    public static FunctionCallCPInstruction parseInstruction(String str)
            throws DMLRuntimeException, DMLUnsupportedOperationException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        if (parts.length < 5) {
            throw new DMLRuntimeException("Invalid function call instruction: " + str);
        }
        String namespace = parts[1];
        String functionName = parts[2];
        int numInputs = Integer.parseInt(parts[3]);
        int numOutputs = Integer.parseInt(parts[4]);
        // Fail with a descriptive error instead of an ArrayIndexOutOfBoundsException.
        if (numInputs < 0 || numOutputs < 0 || parts.length < 5 + numInputs + numOutputs) {
            throw new DMLRuntimeException("Invalid function call instruction: " + str);
        }

        ArrayList<CPOperand> boundInParamOperands = new ArrayList<CPOperand>(numInputs);
        ArrayList<String> boundInParamNames = new ArrayList<String>(numInputs);
        ArrayList<String> boundOutParamNames = new ArrayList<String>(numOutputs);
        for (int i = 0; i < numInputs; i++) {
            CPOperand operand = new CPOperand(parts[5 + i]);
            boundInParamOperands.add(operand);
            boundInParamNames.add(operand.getName());
        }
        for (int i = 0; i < numOutputs; i++) {
            boundOutParamNames.add(parts[5 + numInputs + i]);
        }
        return new FunctionCallCPInstruction(namespace, functionName,
                boundInParamOperands, boundInParamNames, boundOutParamNames, str);
    }

    @Override
    public Instruction preprocessInstruction(ExecutionContext ec)
            throws DMLRuntimeException, DMLUnsupportedOperationException {
        Instruction tmp = super.preprocessInstruction(ec);
        if (DMLScript.ENABLE_DEBUG_MODE) {
            ec.handleDebugFunctionEntry((FunctionCallCPInstruction) tmp);
        }
        return tmp;
    }

    /**
     * Runs the called function.
     *
     * <p>Exceptions other than {@link DMLScriptException} are wrapped in a
     * {@link DMLRuntimeException} that names the function.
     */
    @Override
    public void processInstruction(ExecutionContext ec)
            throws DMLRuntimeException, DMLUnsupportedOperationException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Executing instruction : " + toString());
        }

        FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
        LocalVariableMap functionVariables = bindInputParameters(ec, fpb);

        // Pin the caller's bound inputs for the duration of the call.
        // The returned status is handed back to unpinVariables below.
        HashMap<String, Boolean> pinStatus = ec.pinVariables(_boundInputParamNames);

        ExecutionContext fnEc = ExecutionContextFactory.createContext(false, ec.getProgram());
        fnEc.setVariables(functionVariables);
        try {
            fpb.execute(fnEc);
            LocalVariableMap retVars = fnEc.getVariables();

            // Remove every function-local variable that is not a declared output.
            HashSet<String> outputNames = new HashSet<String>();
            for (DataIdentifier outParam : fpb.getOutputParams()) {
                outputNames.add(outParam.getName());
            }
            // Iterate over a snapshot of the keys because removeVariable mutates the map.
            for (String varName : new ArrayList<String>(retVars.keySet())) {
                if (outputNames.contains(varName)) {
                    continue;
                }
                Data removed = fnEc.removeVariable(varName);
                if (removed instanceof MatrixObject) {
                    fnEc.cleanupMatrixObject((MatrixObject) removed);
                }
            }

            // Unpin before binding outputs. Binding may clean up the caller's previous value
            // of an output variable, which can be one of the pinned inputs.
            ec.unpinVariables(_boundInputParamNames, pinStatus);

            for (int i = 0; i < fpb.getOutputParams().size(); i++) {
                String boundVarName = _boundOutputParamNames.get(i);
                Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
                if (boundValue == null) {
                    throw new DMLUnsupportedOperationException(boundVarName + " was not assigned a return value");
                }
                // Clean up the caller's previous matrix under this name, unless it is the very
                // object being returned.
                Data previous = ec.removeVariable(boundVarName);
                if (previous instanceof MatrixObject && previous != boundValue) {
                    ec.cleanupMatrixObject((MatrixObject) previous);
                }
                if (boundValue instanceof MatrixObject) {
                    ((MatrixObject) boundValue).setVarName(boundVarName);
                }
                ec.setVariable(boundVarName, boundValue);
            }
        } catch (DMLScriptException e) {
            throw e;
        } catch (Exception e) {
            String fname = _namespace + Program.KEY_DELIM + _functionName;
            throw new DMLRuntimeException("error executing function " + fname, e);
        }
    }

    /**
     * Builds the callee's symbol table by binding each formal input parameter of {@code fpb}.
     *
     * <p>The parameter at position {@code i} is bound to its default value (read as a scalar)
     * in two cases: no argument was supplied at position {@code i}, or the supplied argument is
     * a non-literal whose variable is absent from {@code ec}. Otherwise it is bound to the
     * caller's operand: scalars are read via {@code getScalarInput}, and other data types are
     * fetched directly from {@code ec}.
     */
    private LocalVariableMap bindInputParameters(ExecutionContext ec, FunctionProgramBlock fpb)
            throws DMLRuntimeException, DMLUnsupportedOperationException {
        LocalVariableMap functionVariables = new LocalVariableMap();
        for (int i = 0; i < fpb.getInputParams().size(); i++) {
            DataIdentifier formal = fpb.getInputParams().get(i);
            Expression.ValueType valueType = formal.getValueType();
            Data value;
            // '>=' (not '>'): at i == size() there is no bound argument, and get(i) would throw.
            if (i >= _boundInputParamNames.size()
                    || (!_boundInputParamOperands.get(i).isLiteral()
                        && ec.getVariable(_boundInputParamNames.get(i)) == null)) {
                value = ec.getScalarInput(formal.getDefaultValue(), valueType, false);
            } else {
                CPOperand operand = _boundInputParamOperands.get(i);
                if (operand.getDataType() == Expression.DataType.SCALAR) {
                    value = ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral());
                } else {
                    value = ec.getVariable(operand.getName());
                }
            }
            functionVariables.put(formal.getName(), value);
        }
        return functionVariables;
    }

    @Override
    public void postprocessInstruction(ExecutionContext ec) throws DMLRuntimeException {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            ec.handleDebugFunctionExit(this);
        }
        super.postprocessInstruction(ec);
    }

    @Override
    public void printMe() {
        LOG.debug("ExternalBuiltInFunction: " + toString());
    }

    @Override
    public String getGraphString() {
        return "ExtBuiltinFunc: " + _functionName;
    }

    public ArrayList<String> getBoundInputParamNames() {
        return _boundInputParamNames;
    }

    public ArrayList<String> getBoundOutputParamNames() {
        return _boundOutputParamNames;
    }

    /**
     * Renames the called function.
     *
     * <p>Updates the serialized instruction string, the function name and the opcode.
     */
    public void setFunctionName(String fname) {
        instString = updateInstStringFunctionName(_functionName, fname);
        _functionName = fname;
        instOpcode = fname;
    }

    /**
     * Returns a copy of the instruction string in which field 3 (splitting on {@code '°'}) is
     * replaced by {@code replace}.
     *
     * <p>The replacement only happens if field 3 equals {@code pattern}. This method assumes
     * field 3 holds the function name. If the string has fewer than four fields, it is returned
     * unchanged.
     *
     * @param pattern expected current value of field 3
     * @param replace new value for field 3
     * @return the (possibly) updated instruction string
     */
    public String updateInstStringFunctionName(String pattern, String replace) {
        String delim = "°";
        // A limit of -1 keeps trailing empty fields, so an unchanged string round-trips exactly.
        String[] parts = instString.split(delim, -1);
        if (parts.length > 3 && parts[3].equals(pattern)) {
            parts[3] = replace;
        }
        StringBuilder sb = new StringBuilder(instString.length() + replace.length());
        for (int i = 0; i < parts.length; i++) {
            if (i > 0) {
                sb.append(delim);
            }
            sb.append(parts[i]);
        }
        return sb.toString();
    }
}
