/*
 * Decompiled with CFR 0.152.
 */
package jdk.graal.compiler.phases.common;

import java.util.Optional;
import jdk.graal.compiler.core.common.NumUtil;
import jdk.graal.compiler.core.common.type.IntegerStamp;
import jdk.graal.compiler.core.common.type.Stamp;
import jdk.graal.compiler.debug.Assertions;
import jdk.graal.compiler.debug.DebugCloseable;
import jdk.graal.compiler.debug.GraalError;
import jdk.graal.compiler.graph.Graph;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.nodes.ConstantNode;
import jdk.graal.compiler.nodes.FixedNode;
import jdk.graal.compiler.nodes.FixedWithNextNode;
import jdk.graal.compiler.nodes.GraphState;
import jdk.graal.compiler.nodes.NodeView;
import jdk.graal.compiler.nodes.PiNode;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.nodes.ValueNode;
import jdk.graal.compiler.nodes.calc.BinaryArithmeticNode;
import jdk.graal.compiler.nodes.calc.BinaryNode;
import jdk.graal.compiler.nodes.calc.FixedBinaryNode;
import jdk.graal.compiler.nodes.calc.IntegerDivRemNode;
import jdk.graal.compiler.nodes.calc.IntegerMulHighNode;
import jdk.graal.compiler.nodes.calc.MulNode;
import jdk.graal.compiler.nodes.calc.NarrowNode;
import jdk.graal.compiler.nodes.calc.RightShiftNode;
import jdk.graal.compiler.nodes.calc.SignExtendNode;
import jdk.graal.compiler.nodes.calc.SignedDivNode;
import jdk.graal.compiler.nodes.calc.SignedFloatingIntegerDivNode;
import jdk.graal.compiler.nodes.calc.SignedFloatingIntegerRemNode;
import jdk.graal.compiler.nodes.calc.SignedRemNode;
import jdk.graal.compiler.nodes.calc.UnsignedRightShiftNode;
import jdk.graal.compiler.nodes.spi.Canonicalizable;
import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.phases.BasePhase;
import jdk.graal.compiler.phases.common.CanonicalizerPhase;
import jdk.graal.compiler.phases.common.util.EconomicSetNodeEventListener;
import jdk.vm.ci.code.CodeUtil;
import org.graalvm.collections.Pair;

/**
 * Strength-reduces signed integer division and remainder by a non-zero constant.
 * <p>
 * First, every signed remainder {@code x % c} is rewritten as {@code x - (x / c) * c}. An
 * adjacent matching division is reused when one exists; otherwise a new one is created. Then
 * every signed division {@code x / c}, including divisions just introduced for remainders, is
 * replaced by a multiplication with a "magic number" followed by shifts and a sign fix-up. This
 * is the method from Hacker's Delight, chapter 10. Nodes touched by the rewrites are
 * canonicalized incrementally afterwards.
 */
public class OptimizeDivPhase extends BasePhase<CoreProviders> {
    protected final CanonicalizerPhase canonicalizer;

    public OptimizeDivPhase(CanonicalizerPhase canonicalizer) {
        this.canonicalizer = canonicalizer;
    }

    @Override
    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return this.canonicalizer.notApplicableTo(graphState);
    }

    @Override
    @SuppressWarnings("try")
    protected void run(StructuredGraph graph, CoreProviders context) {
        EconomicSetNodeEventListener ec = new EconomicSetNodeEventListener();
        try (Graph.NodeEventScope nes = graph.trackNodeEvents(ec)) {
            // Remainders go first: each rewrite may introduce a new signed division, which the
            // division loops below then strength-reduce as well.
            // try-with-resources (unlike a `continue` inside `finally`) never swallows an
            // exception thrown by the rewrite, and it simply skips closing a null position.
            for (IntegerDivRemNode rem : graph.getNodes(IntegerDivRemNode.TYPE)) {
                if (rem instanceof SignedRemNode && isDivByNonZeroConstantNonOverflowingAbs(rem)) {
                    try (DebugCloseable position = rem.withNodeSourcePosition()) {
                        optimizeRem(rem);
                    }
                }
            }
            for (SignedFloatingIntegerRemNode rem : graph.getNodes(SignedFloatingIntegerRemNode.TYPE)) {
                if (isDivByNonZeroConstantNonOverflowingAbs(rem)) {
                    try (DebugCloseable position = rem.withNodeSourcePosition()) {
                        optimizeRem(rem);
                    }
                }
            }
            for (IntegerDivRemNode div : graph.getNodes(IntegerDivRemNode.TYPE)) {
                if (div instanceof SignedDivNode && isDivByNonZeroConstantNonOverflowingAbs(div)) {
                    try (DebugCloseable position = div.withNodeSourcePosition()) {
                        optimizeSignedDiv(div);
                    }
                }
            }
            for (SignedFloatingIntegerDivNode div : graph.getNodes(SignedFloatingIntegerDivNode.TYPE)) {
                if (isDivByNonZeroConstantNonOverflowingAbs(div)) {
                    try (DebugCloseable position = div.withNodeSourcePosition()) {
                        optimizeSignedDiv(div);
                    }
                }
            }
        }
        if (!ec.getNodes().isEmpty()) {
            this.canonicalizer.applyIncremental(graph, context, ec.getNodes());
        }
    }

    @Override
    public float codeSizeIncrease() {
        return 5.0f;
    }

    /**
     * Tests whether the divisor of {@code divRemNode} is a non-zero constant whose absolute value
     * is representable in the divisor's bit width. This excludes the minimum value of that width.
     *
     * @param divRemNode a division or remainder node
     * @return {@code true} if the node is a candidate for constant-divisor strength reduction
     */
    protected static boolean isDivByNonZeroConstantNonOverflowingAbs(Canonicalizable.Binary<ValueNode> divRemNode) {
        ValueNode divisor = divRemNode.getY();
        if (!divisor.isConstant()) {
            return false;
        }
        long constantVal = divisor.asJavaConstant().asLong();
        return constantVal != 0L && !NumUtil.absOverflows(constantVal, IntegerStamp.getBits(divisor.stamp(NodeView.DEFAULT)));
    }

    /**
     * Rewrites the remainder {@code x % y} as {@code x - (x / y) * y} and replaces {@code rem}
     * with the result, keeping the original stamp.
     *
     * @param rem an {@link IntegerDivRemNode} or {@link SignedFloatingIntegerRemNode}
     */
    protected final void optimizeRem(Canonicalizable.Binary<ValueNode> rem) {
        assert rem instanceof IntegerDivRemNode || rem instanceof SignedFloatingIntegerRemNode : Assertions.errorMessageContext("rem", rem);
        ValueNode remNode = (ValueNode) rem;
        StructuredGraph graph = remNode.graph();
        ValueNode div = findDivForRem(remNode);
        ValueNode mul = BinaryArithmeticNode.mul(graph, div, rem.getY(), NodeView.DEFAULT);
        ValueNode result = BinaryArithmeticNode.sub(graph, rem.getX(), mul, NodeView.DEFAULT);
        replacePreserveOriginalStamp(graph, remNode, result);
    }

    /**
     * Returns a division node with the same operands as the remainder {@code val}.
     * <p>
     * For a fixed remainder, a matching division scheduled immediately after or before it is
     * reused. Otherwise a new division is created, and if it is a newly created fixed node it is
     * inserted directly after the remainder.
     */
    private ValueNode findDivForRem(ValueNode val) {
        if (val instanceof IntegerDivRemNode) {
            IntegerDivRemNode rem = (IntegerDivRemNode) val;
            IntegerDivRemNode existing = matchingDiv(rem, rem.next());
            if (existing == null) {
                existing = matchingDiv(rem, rem.predecessor());
            }
            if (existing != null) {
                return existing;
            }
            StructuredGraph graph = rem.graph();
            ValueNode div = graph.addOrUniqueWithInputs(createDiv(rem));
            // createDiv may yield a node that is already part of the graph (addOrUniqueWithInputs
            // then returns it as is). Only a fixed node that is not yet linked into control flow
            // (no predecessor) may be spliced in after the remainder.
            if (div instanceof FixedNode && div.predecessor() == null) {
                graph.addAfterFixed(rem, (FixedNode) div);
            }
            return div;
        }
        if (val instanceof SignedFloatingIntegerRemNode) {
            return val.graph().addOrUniqueWithInputs(createDiv(val));
        }
        throw GraalError.shouldNotReachHereUnexpectedValue(val);
    }

    /**
     * Returns {@code candidate} if it is a division of the same type and with the same operands
     * as {@code rem}, otherwise {@code null}.
     */
    private static IntegerDivRemNode matchingDiv(IntegerDivRemNode rem, Node candidate) {
        if (candidate instanceof IntegerDivRemNode) {
            IntegerDivRemNode div = (IntegerDivRemNode) candidate;
            if (div.getOp() == IntegerDivRemNode.Op.DIV && div.getType() == rem.getType() && div.getX() == rem.getX() && div.getY() == rem.getY()) {
                return div;
            }
        }
        return null;
    }

    /**
     * Creates (but does not add to the graph) the division corresponding to the remainder
     * {@code val}. For a {@link SignedRemNode} this reuses its zero guard. For a
     * {@link SignedFloatingIntegerRemNode} it reuses its guard and its JVMS-compliance flag.
     */
    protected ValueNode createDiv(ValueNode val) {
        if (val instanceof SignedRemNode) {
            SignedRemNode rem = (SignedRemNode) val;
            return SignedDivNode.create(rem.getX(), rem.getY(), rem.getZeroGuard(), NodeView.DEFAULT);
        }
        SignedFloatingIntegerRemNode rem = (SignedFloatingIntegerRemNode) val;
        return SignedFloatingIntegerDivNode.create(rem.getX(), rem.getY(), NodeView.DEFAULT, rem.getGuard(), rem.divisionOverflowIsJVMSCompliant());
    }

    /**
     * Replaces a signed division by a constant {@code c} with a multiply by a magic number, an
     * optional add/subtract correction, an arithmetic shift and a sign fix-up. Divisors
     * {@code 0}, {@code 1} and {@code -1} are left untouched.
     *
     * @param div a {@link SignedDivNode} or {@link SignedFloatingIntegerDivNode} whose divisor
     *            is a constant
     */
    protected static void optimizeSignedDiv(Canonicalizable.Binary<ValueNode> div) {
        ValueNode forX = div.getX();
        long c = div.getY().asJavaConstant().asLong();
        if (c == 1L || c == -1L || c == 0L) {
            return;
        }
        IntegerStamp dividendStamp = (IntegerStamp) forX.stamp(NodeView.DEFAULT);
        int bitSize = dividendStamp.getBits();
        Pair<Long, Integer> nums = magicDivideConstants(c, bitSize);
        long magicNum = nums.getLeft();
        int shiftNum = nums.getRight();
        assert NumUtil.assertNonNegativeInt(shiftNum);
        ConstantNode m = ConstantNode.forLong(magicNum);
        ValueNode value;
        if (bitSize == 32) {
            // Multiply in 64 bits and use the high 32 bits of the product.
            value = new MulNode(new SignExtendNode(forX, 64), m);
            if ((c > 0L && magicNum < 0L) || (c < 0L && magicNum > 0L)) {
                // The magic number's sign disagrees with the divisor's. Correct the high half by
                // adding (c > 0) or subtracting (c < 0) the dividend before the post-shift.
                value = NarrowNode.create(new RightShiftNode(value, ConstantNode.forInt(32)), 32, NodeView.DEFAULT);
                value = c > 0L ? BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT) : BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
                if (shiftNum > 0) {
                    value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
                }
            } else {
                // No correction needed: fold the post-shift into the high-half extraction.
                value = new RightShiftNode(value, ConstantNode.forInt(32 + shiftNum));
                value = new NarrowNode(value, 32);
            }
        } else {
            assert bitSize == 64 : bitSize;
            value = new IntegerMulHighNode(forX, m);
            if (c > 0L && magicNum < 0L) {
                value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
            } else if (c < 0L && magicNum > 0L) {
                value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
            }
            if (shiftNum > 0) {
                value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
            }
        }
        if (c < 0L || dividendStamp.canBeNegative()) {
            // Sign fix-up (Hacker's Delight): add 1 when the quotient is negative. For a negative
            // divisor the estimate's sign bit decides; for a positive divisor it is the dividend's.
            ValueNode signSource = c < 0L ? value : forX;
            ValueNode sign = UnsignedRightShiftNode.create(signSource, ConstantNode.forInt(bitSize - 1), NodeView.DEFAULT);
            value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
        }
        ValueNode divNode = (ValueNode) div;
        StructuredGraph graph = divNode.graph();
        assert div instanceof SignedDivNode || div instanceof SignedFloatingIntegerDivNode : "Unknown or invalid div:" + String.valueOf(div);
        replacePreserveOriginalStamp(graph, divNode, value);
    }

    /**
     * Replaces {@code originalNode} with {@code replacement} wrapped in a {@link PiNode} carrying
     * the original node's stamp. This keeps users of the original node on its stamp. Fixed nodes
     * are removed from control flow via {@code replaceFixed}, and floating nodes are replaced and
     * deleted. The rewrite is reported to the optimization log.
     */
    private static void replacePreserveOriginalStamp(StructuredGraph graph, ValueNode originalNode, ValueNode replacement) {
        Stamp oldStamp = originalNode.stamp(NodeView.DEFAULT);
        ValueNode replacementWrapped = graph.addOrUniqueWithInputs(PiNode.create(replacement, oldStamp));
        if (originalNode instanceof FixedNode) {
            graph.replaceFixed((FixedWithNextNode) originalNode, replacementWrapped);
        } else {
            originalNode.replaceAndDelete(replacementWrapped);
        }
        graph.getOptimizationLog().report(OptimizeDivPhase.class, "DivOptimization", originalNode);
    }

    /**
     * Computes the magic multiplier and post-shift for signed division of {@code size}-bit values
     * by {@code divisor}, following the signed "magic" algorithm of Hacker's Delight, chapter 10.
     * Arithmetic on the intermediate quotients and remainders is unsigned.
     *
     * @param divisor the constant divisor. It must not be 0, and its absolute value must not
     *            overflow (both are ensured by {@link #isDivByNonZeroConstantNonOverflowingAbs})
     * @param size the operand bit width (32 or 64 at the only call site)
     * @return a pair of the magic multiplier and the shift amount. The multiplier is sign-extended
     *         from {@code size} bits and negated for negative divisors.
     */
    private static Pair<Long, Integer> magicDivideConstants(long divisor, int size) {
        long twoW = 1L << (size - 1);
        long t = twoW + (divisor >>> 63);
        long ad = NumUtil.safeAbs(divisor, 64);
        // Absolute value of the largest dividend that is one less than a multiple of the divisor.
        long anc = t - 1L - Long.remainderUnsigned(t, ad);
        long q1 = Long.divideUnsigned(twoW, anc);
        long r1 = Long.remainderUnsigned(twoW, anc);
        long q2 = Long.divideUnsigned(twoW, ad);
        long r2 = Long.remainderUnsigned(twoW, ad);
        int p = size - 1;
        long delta;
        do {
            p++;
            q1 = 2L * q1;
            r1 = 2L * r1;
            if (Long.compareUnsigned(r1, anc) >= 0) {
                q1++;
                r1 -= anc;
            }
            q2 = 2L * q2;
            r2 = 2L * r2;
            if (Long.compareUnsigned(r2, ad) >= 0) {
                q2++;
                r2 -= ad;
            }
            delta = ad - r2;
        } while (Long.compareUnsigned(q1, delta) < 0 || (q1 == delta && r1 == 0L));
        long magic = CodeUtil.signExtend(q2 + 1L, size);
        if (divisor < 0L) {
            magic = -magic;
        }
        return Pair.create(magic, p - size);
    }
}

