package org.apache.sysml.runtime.matrix.mapred;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;
import org.apache.sysml.runtime.matrix.mapred.CachedMapElement;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/mapred/CachedMap.class */
/**
 * A multimap from a byte index to cached elements, backed by a pooled element list.
 *
 * <p>The first {@code numValid} entries of {@code cache} are live; entries beyond that are kept
 * after {@link #reset()} / {@link #remove(byte)} so that later {@link #add(Byte, CachedMapElement)}
 * calls can overwrite them with {@code set} instead of allocating via {@code duplicate}. For each
 * byte index, {@code map} holds the positions in {@code cache} that belong to it.
 */
public class CachedMap<T extends CachedMapElement> {
    protected HashMap<Byte, ArrayList<Integer>> map = new HashMap<>();
    protected ArrayList<T> cache = new ArrayList<>();
    protected int numValid = 0;
    // Reused by get(byte) to avoid allocating a result list per call.
    protected ArrayList<T> returnListCache = new ArrayList<>(4);

    /**
     * Stores the value of {@code t} under index {@code b}.
     *
     * <p>If a pooled element is available it is overwritten with {@code set(t)}; otherwise
     * {@code t.duplicate()} is appended to the pool. The caller's {@code t} itself is not stored.
     *
     * @param b the index to file the element under
     * @param t the value to store
     * @return the element held by this map (not {@code t})
     */
    public T add(Byte b, T t) {
        if (this.numValid < this.cache.size()) {
            // Reuse a slot left over from an earlier reset/remove.
            this.cache.get(this.numValid).set(t);
        } else {
            this.cache.add(t.duplicate());
        }
        ArrayList<Integer> arrayList = this.map.get(b);
        if (arrayList == null) {
            arrayList = new ArrayList<>(4);
            this.map.put(b, arrayList);
        }
        arrayList.add(Integer.valueOf(this.numValid));
        this.numValid++;
        return this.cache.get(this.numValid - 1);
    }

    /** Drops all entries; pooled elements are retained for reuse by later adds. */
    public void reset() {
        this.numValid = 0;
        this.map.clear();
    }

    /**
     * Removes every element filed under index {@code b}, compacting the live region of the pool.
     *
     * <p>Each freed slot is filled by moving the last live element into it, and the owner of that
     * moved element has its recorded position updated. Removed elements end up past
     * {@code numValid}, where {@link #add(Byte, CachedMapElement)} can reuse them.
     *
     * @param b the index whose elements should be removed; a no-op if it has none
     */
    public void remove(byte b) {
        ArrayList<Integer> removed = this.map.remove(Byte.valueOf(b));
        if (removed == null) {
            return;
        }
        // Handle slots from highest to lowest. Every still-pending removed slot is then below the
        // one being handled, so the last live slot is either that slot itself or owned by a kept
        // key. This avoids moving another to-be-removed element into the freed position.
        Collections.sort(removed, Collections.reverseOrder());
        for (Integer slot : removed) {
            int freed = slot.intValue();
            int last = this.numValid - 1;
            if (freed != last) {
                // Swap the last live element into the freed slot; the removed one goes to the tail.
                T moved = this.cache.get(last);
                this.cache.set(last, this.cache.get(freed));
                this.cache.set(freed, moved);
                redirectIndex(last, freed);
            }
            this.numValid--;
        }
    }

    /**
     * Replaces the recorded cache position {@code from} with {@code to} in whichever key's list
     * holds it. Each live position is owned by exactly one key, so the search stops at the first hit.
     */
    private void redirectIndex(int from, int to) {
        Integer target = Integer.valueOf(from);
        for (ArrayList<Integer> indexes : this.map.values()) {
            int pos = indexes.indexOf(target);
            if (pos >= 0) {
                indexes.set(pos, Integer.valueOf(to));
                return;
            }
        }
    }

    /**
     * Returns the elements filed under index {@code b}.
     *
     * <p>The returned list is an internal buffer that is cleared and refilled on every call, so
     * callers must consume or copy it before calling {@code get} again.
     *
     * @param b the index to look up
     * @return the shared result list, or {@code null} if {@code b} has no entry
     */
    public ArrayList<T> get(byte b) {
        ArrayList<Integer> arrayList = this.map.get(Byte.valueOf(b));
        if (arrayList == null) {
            return null;
        }
        this.returnListCache.clear();
        Iterator<Integer> it = arrayList.iterator();
        while (it.hasNext()) {
            this.returnListCache.add(this.cache.get(it.next().intValue()));
        }
        return this.returnListCache;
    }

    /**
     * Returns the first element filed under index {@code b}.
     *
     * @param b the index to look up
     * @return the first element, or {@code null} if {@code b} has no elements
     */
    public T getFirst(byte b) {
        ArrayList<Integer> arrayList = this.map.get(Byte.valueOf(b));
        if (arrayList == null || arrayList.isEmpty()) {
            return null;
        }
        return this.cache.get(arrayList.get(0).intValue());
    }

    /**
     * Returns the indexes that currently have entries.
     *
     * @return a live view of the internal key set; it changes as entries are added or removed
     */
    public Set<Byte> getIndexesOfAll() {
        return this.map.keySet();
    }

    /** Debug representation: live count, index-to-slot map, and the whole pool (including spare slots). */
    @Override
    public String toString() {
        return "numValid: " + this.numValid + "\n" + this.map.toString() + "\n" + this.cache.toString();
    }
}
