From a21a1c99cdd3602acd0cf18fa5ed29273b65e4d6 Mon Sep 17 00:00:00 2001 From: ringabout <43030857+ringabout@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:25:56 +0800 Subject: [PATCH] fixes #19983; implements bitmasked bitshifting for all backends (#25390) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit replaces https://github.com/nim-lang/Nim/pull/11555 fixes https://github.com/nim-lang/Nim/issues/19983 fixes https://github.com/nim-lang/Nim/issues/13566 - [x] JS backend --------- Co-authored-by: Arne Döring (cherry picked from commit f1b97caf92dab122063a598b680a464b438e74bc) --- changelog.md | 2 ++ compiler/ccgexprs.nim | 6 ++-- compiler/jsgen.nim | 29 +++++++++-------- compiler/semfold.nim | 36 +++++++++++---------- compiler/vmgen.nim | 45 ++++++++++++++++++++++---- lib/system/arithmetics.nim | 15 ++++++--- tests/int/tarithm.nim | 65 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 154 insertions(+), 44 deletions(-) diff --git a/changelog.md b/changelog.md index 38f87d1328..2ed76f8eb2 100644 --- a/changelog.md +++ b/changelog.md @@ -31,6 +31,8 @@ errors. - The second parameter of `succ`, `pred`, `inc`, and `dec` in `system` now accepts `SomeInteger` (previously `Ordinal`). +- Bitshift operators (`shl`, `shr`, `ashr`) now apply bitmasking to the right operand in the C/C++/VM/JS backends. + ## Standard library additions and changes [//]: # "Additions:" diff --git a/compiler/ccgexprs.nim b/compiler/ccgexprs.nim index e7dfc9b22a..274477fc8b 100644 --- a/compiler/ccgexprs.nim +++ b/compiler/ccgexprs.nim @@ -660,9 +660,9 @@ proc binaryArith(p: BProc, e: PNode, d: var TLoc, op: TMagic) = of mSubF64: applyFormat("(($4)($1) - ($4)($2))") of mMulF64: applyFormat("(($4)($1) * ($4)($2))") of mDivF64: applyFormat("(($4)($1) / ($4)($2))") - of mShrI: applyFormat("($4)((NU$5)($1) >> (NU$3)($2))") - of mShlI: applyFormat("($4)((NU$3)($1) << (NU$3)($2))") - of mAshrI: applyFormat("($4)((NI$3)($1) >> (NU$3)($2))") + of mShrI: applyFormat("($4)((NU$5)($1) >> (NU$3)($2 & ($5 - 1)))") + of mShlI: applyFormat("($4)((NU$3)($1) << (NU$3)($2 & ($5 - 1)))") + of mAshrI: applyFormat("($4)((NI$3)($1) >> (NU$3)($2 & ($5 - 1)))") of mBitandI: applyFormat("($4)($1 & $2)") of mBitorI: applyFormat("($4)($1 | $2)") of mBitxorI: applyFormat("($4)($1 ^ $2)") diff --git a/compiler/jsgen.nim b/compiler/jsgen.nim index bfc15a6120..8f966b7dd1 100644 --- a/compiler/jsgen.nim +++ b/compiler/jsgen.nim @@ -728,44 +728,47 @@ proc arithAux(p: PProc, n: PNode, r: var TCompRes, op: TMagic) = of mShrI: let typ = n[1].typ.skipTypes(abstractVarRange) if typ.kind == tyInt64 and optJsBigInt64 in p.config.globalOptions: - applyFormat("BigInt.asIntN(64, BigInt.asUintN(64, $1) >> BigInt($2))") + applyFormat("BigInt.asIntN(64, BigInt.asUintN(64, $1) >> (BigInt($2) & 63n))") elif typ.kind == tyUInt64 and optJsBigInt64 in p.config.globalOptions: - applyFormat("($1 >> BigInt($2))") + applyFormat("($1 >> (BigInt($2) & 63n))") else: + let bitmask = typ.size * 8 - 1 if typ.kind in {tyInt..tyInt32}: let trimmerU = unsignedTrimmer(typ.size) let trimmerS = signedTrimmer(typ.size) - r.res = "((($1 $2) >>> $3) $4)" % [xLoc, trimmerU, yLoc, trimmerS] + r.res = "((($1 $2) >>> ($3 & $5)) $4)" % [xLoc, trimmerU, yLoc, trimmerS, $bitmask] else: - applyFormat("($1 >>> $2)") + r.res = "($1 >>> ($2 & $3))" % [xLoc, yLoc, $bitmask] of mShlI: let typ = n[1].typ.skipTypes(abstractVarRange) if typ.size == 8: if typ.kind == tyInt64 and optJsBigInt64 in p.config.globalOptions: - applyFormat("BigInt.asIntN(64, $1 << 
BigInt($2))") + applyFormat("BigInt.asIntN(64, $1 << (BigInt($2) & 63n))") elif typ.kind == tyUInt64 and optJsBigInt64 in p.config.globalOptions: - applyFormat("BigInt.asUintN(64, $1 << BigInt($2))") + applyFormat("BigInt.asUintN(64, $1 << (BigInt($2) & 63n))") else: - applyFormat("($1 * Math.pow(2, $2))") + applyFormat("($1 * Math.pow(2, ($2 & 63)))") else: + let bitmask = typ.size * 8 - 1 if typ.kind in {tyUInt..tyUInt32}: let trimmer = unsignedTrimmer(typ.size) - r.res = "(($1 << $2) $3)" % [xLoc, yLoc, trimmer] + r.res = "(($1 << ($2 & $4)) $3)" % [xLoc, yLoc, trimmer, $bitmask] else: let trimmer = signedTrimmer(typ.size) - r.res = "(($1 << $2) $3)" % [xLoc, yLoc, trimmer] + r.res = "(($1 << ($2 & $4)) $3)" % [xLoc, yLoc, trimmer, $bitmask] of mAshrI: let typ = n[1].typ.skipTypes(abstractVarRange) if typ.size == 8: if optJsBigInt64 in p.config.globalOptions: - applyFormat("($1 >> BigInt($2))") + applyFormat("($1 >> (BigInt($2) & 63n))") else: - applyFormat("Math.floor($1 / Math.pow(2, $2))") + applyFormat("Math.floor($1 / Math.pow(2, ($2 & 63)))") else: + let bitmask = typ.size * 8 - 1 if typ.kind in {tyUInt..tyUInt32}: - applyFormat("($1 >>> $2)") + r.res = "($1 >>> ($2 & $3))" % [xLoc, yLoc, $bitmask] else: - applyFormat("($1 >> $2)") + r.res = "($1 >> ($2 & $3))" % [xLoc, yLoc, $bitmask] of mBitandI: bitwiseExpr("&") of mBitorI: bitwiseExpr("|") of mBitxorI: bitwiseExpr("^") diff --git a/compiler/semfold.nim b/compiler/semfold.nim index 634b645f68..f9bffd8ace 100644 --- a/compiler/semfold.nim +++ b/compiler/semfold.nim @@ -179,29 +179,30 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): P let argB = getInt(b) result = newIntNodeT(if argA > argB: argA else: argB, n, idgen, g) of mShlI: + let valueB = toInt64(getInt(b)) and (n.typ.size * 8 - 1) case skipTypes(n.typ, abstractRange).kind - of tyInt8: result = newIntNodeT(toInt128(toInt8(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyInt16: result = newIntNodeT(toInt128(toInt16(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyInt32: result = newIntNodeT(toInt128(toInt32(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyInt64: result = newIntNodeT(toInt128(toInt64(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) + of tyInt8: result = newIntNodeT(toInt128(toInt8(getInt(a)) shl valueB), n, idgen, g) + of tyInt16: result = newIntNodeT(toInt128(toInt16(getInt(a)) shl valueB), n, idgen, g) + of tyInt32: result = newIntNodeT(toInt128(toInt32(getInt(a)) shl valueB), n, idgen, g) + of tyInt64: result = newIntNodeT(toInt128(toInt64(getInt(a)) shl valueB), n, idgen, g) of tyInt: if g.config.target.intSize == 4: - result = newIntNodeT(toInt128(toInt32(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) + result = newIntNodeT(toInt128(toInt32(getInt(a)) shl valueB), n, idgen, g) else: - result = newIntNodeT(toInt128(toInt64(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyUInt8: result = newIntNodeT(toInt128(toUInt8(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyUInt16: result = newIntNodeT(toInt128(toUInt16(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyUInt32: result = newIntNodeT(toInt128(toUInt32(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) - of tyUInt64: result = newIntNodeT(toInt128(toUInt64(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) + result = newIntNodeT(toInt128(toInt64(getInt(a)) shl valueB), n, idgen, g) + of tyUInt8: result = newIntNodeT(toInt128(toUInt8(getInt(a)) shl valueB), n, idgen, g) + of tyUInt16: result =
newIntNodeT(toInt128(toUInt16(getInt(a)) shl valueB), n, idgen, g) + of tyUInt32: result = newIntNodeT(toInt128(toUInt32(getInt(a)) shl valueB), n, idgen, g) + of tyUInt64: result = newIntNodeT(toInt128(toUInt64(getInt(a)) shl valueB), n, idgen, g) of tyUInt: if g.config.target.intSize == 4: - result = newIntNodeT(toInt128(toUInt32(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) + result = newIntNodeT(toInt128(toUInt32(getInt(a)) shl valueB), n, idgen, g) else: - result = newIntNodeT(toInt128(toUInt64(getInt(a)) shl toInt64(getInt(b))), n, idgen, g) + result = newIntNodeT(toInt128(toUInt64(getInt(a)) shl valueB), n, idgen, g) else: internalError(g.config, n.info, "constant folding for shl") of mShrI: var a = cast[uint64](getInt(a)) - let b = cast[uint64](getInt(b)) + let b = cast[uint64](getInt(b)) and cast[uint64](n.typ.size * 8 - 1) # To support the ``-d:nimOldShiftRight`` flag, we need to mask the # signed integers to cut off the extended sign bit in the internal # representation. @@ -220,12 +221,13 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): P let c = cast[BiggestInt](a shr b) result = newIntNodeT(toInt128(c), n, idgen, g) of mAshrI: + let valueB = toInt64(getInt(b)) and (n.typ.size * 8 - 1) case skipTypes(n.typ, abstractRange).kind - of tyInt8: result = newIntNodeT(toInt128(ashr(toInt8(getInt(a)), toInt8(getInt(b)))), n, idgen, g) - of tyInt16: result = newIntNodeT(toInt128(ashr(toInt16(getInt(a)), toInt16(getInt(b)))), n, idgen, g) - of tyInt32: result = newIntNodeT(toInt128(ashr(toInt32(getInt(a)), toInt32(getInt(b)))), n, idgen, g) + of tyInt8: result = newIntNodeT(toInt128(ashr(toInt8(getInt(a)), valueB)), n, idgen, g) + of tyInt16: result = newIntNodeT(toInt128(ashr(toInt16(getInt(a)), valueB)), n, idgen, g) + of tyInt32: result = newIntNodeT(toInt128(ashr(toInt32(getInt(a)), valueB)), n, idgen, g) of tyInt64, tyInt: - result = newIntNodeT(toInt128(ashr(toInt64(getInt(a)), toInt64(getInt(b)))), n, idgen, g) + result = newIntNodeT(toInt128(ashr(toInt64(getInt(a)), valueB)), n, idgen, g) else: internalError(g.config, n.info, "constant folding for ashr") of mDivI: let argA = getInt(a) diff --git a/compiler/vmgen.nim b/compiler/vmgen.nim index 2a37c3fd7e..8deb3048e4 100644 --- a/compiler/vmgen.nim +++ b/compiler/vmgen.nim @@ -1074,6 +1074,19 @@ proc whichAsgnOpc(n: PNode; requiresCopy = true): TOpcode = else: (if requiresCopy: opcAsgnComplex else: opcFastAsgnComplex) +proc sizeLog2(typeSize: BiggestInt): TRegister = + case typeSize: + of 8: + result = 3 + of 16: + result = 4 + of 32: + result = 5 + of 64: + result = 6 + else: + raiseAssert $(typeSize) + proc genMagic(c: PCtx; n: PNode; dest: var TDest; flags: TGenFlags = {}, m: TMagic) = case m of mAnd: c.genAndOr(n, opcFJmp, dest) @@ -1159,24 +1172,42 @@ proc genMagic(c: PCtx; n: PNode; dest: var TDest; flags: TGenFlags = {}, m: TMag of mDivF64: genBinaryABC(c, n, dest, opcDivFloat) of mShrI: # modified: genBinaryABC(c, n, dest, opcShrInt) - # narrowU is applied to the left operandthe idea here is to narrow the left operand + # narrowU is applied to the left operand the idea here is to narrow the left operand + let typ = skipTypes(n.typ, abstractVar-{tyTypeDesc}) + let size = getSize(c.config, typ) let tmp = c.genx(n[1]) c.genNarrowU(n, tmp) let tmp2 = c.genx(n[2]) if dest < 0: dest = c.getTemp(n.typ) + c.gABC(n, opcNarrowU, tmp2, sizeLog2(size * 8)) c.gABC(n, opcShrInt, dest, tmp, tmp2) c.freeTemp(tmp) c.freeTemp(tmp2) of mShlI: - genBinaryABC(c, n, dest, opcShlInt) + let typ = 
skipTypes(n.typ, abstractVar-{tyTypeDesc}) + let size = getSize(c.config, typ) + let tmp1 = c.genx(n[1]) + let tmp2 = c.genx(n[2]) + if dest < 0: dest = c.getTemp(n.typ) + c.gABC(n, opcNarrowU, tmp2, sizeLog2(size * 8)) + c.gABC(n, opcShlInt, dest, tmp1, tmp2) + c.freeTemp(tmp1) + c.freeTemp(tmp2) # genNarrowU modified - let t = skipTypes(n.typ, abstractVar-{tyTypeDesc}) - let size = getSize(c.config, t) - if t.kind in {tyUInt8..tyUInt32} or (t.kind == tyUInt and size < 8): + if typ.kind in {tyUInt8..tyUInt32} or (typ.kind == tyUInt and size < 8): c.gABC(n, opcNarrowU, dest, TRegister(size*8)) - elif t.kind in {tyInt8..tyInt32} or (t.kind == tyInt and size < 8): + elif typ.kind in {tyInt8..tyInt32} or (typ.kind == tyInt and size < 8): c.gABC(n, opcSignExtend, dest, TRegister(size*8)) - of mAshrI: genBinaryABC(c, n, dest, opcAshrInt) + of mAshrI: + let typ = skipTypes(n.typ, abstractVar-{tyTypeDesc}) + let size = getSize(c.config, typ) + let tmp1 = c.genx(n[1]) + let tmp2 = c.genx(n[2]) + if dest < 0: dest = c.getTemp(n.typ) + c.gABC(n, opcNarrowU, tmp2, sizeLog2(size * 8)) + c.gABC(n, opcAshrInt, dest, tmp1, tmp2) + c.freeTemp(tmp1) + c.freeTemp(tmp2) of mBitandI: genBinaryABC(c, n, dest, opcBitandInt) of mBitorI: genBinaryABC(c, n, dest, opcBitorInt) of mBitxorI: genBinaryABC(c, n, dest, opcBitxorInt) diff --git a/lib/system/arithmetics.nim b/lib/system/arithmetics.nim index 1cfb9886e5..c06ea675a4 100644 --- a/lib/system/arithmetics.nim +++ b/lib/system/arithmetics.nim @@ -134,7 +134,10 @@ when defined(nimOldShiftRight): else: proc `shr`*(x: int, y: SomeInteger): int {.magic: "AshrI", noSideEffect.} = ## Computes the `shift right` operation of `x` and `y`, filling - ## vacant bit positions with the sign bit. + ## vacant bit positions with the sign bit. `y` (the number of + ## positions to shift) is reduced modulo `sizeof(x) * 8`, so the effective + ## shift amount always lies in the range `0 ..< 8 * sizeof(x)`. + ## That is, `15'i32 shr 35` is equivalent to `15'i32 shr 3`. ## ## **Note**: `Operator precedence `_ ## is different than in *C*. @@ -156,7 +159,9 @@ else: proc `shl`*(x: int, y: SomeInteger): int {.magic: "ShlI", noSideEffect.} = - ## Computes the `shift left` operation of `x` and `y`. + ## Computes the `shift left` operation of `x` and `y`. `y` (the number of + ## positions to shift) is reduced modulo `sizeof(x) * 8`. + ## That is, `15'i32 shl 35` is equivalent to `15'i32 shl 3`. ## ## **Note**: `Operator precedence `_ ## is different than in *C*. @@ -170,7 +175,9 @@ proc `shl`*(x: int64, y: SomeInteger): int64 {.magic: "ShlI", noSideEffect.} proc ashr*(x: int, y: SomeInteger): int {.magic: "AshrI", noSideEffect.} = ## Shifts right by pushing copies of the leftmost bit in from the left, - ## and let the rightmost bits fall off. + ## and lets the rightmost bits fall off. `y` (the number of + ## positions to shift) is reduced modulo `sizeof(x) * 8`. + ## That is, `ashr(15'i32, 35)` is equivalent to `ashr(15'i32, 3)`. ## ## Note that `ashr` is not an operator so use the normal function ## call syntax for it.
@@ -179,7 +186,7 @@ proc ashr*(x: int, y: SomeInteger): int {.magic: "AshrI", noSideEffect.} = ## * `shr func<#shr,int,SomeInteger>`_ runnableExamples: assert ashr(0b0001_0000'i8, 2) == 0b0000_0100'i8 - assert ashr(0b1000_0000'i8, 8) == 0b1111_1111'i8 + assert ashr(0b1000_0000'i8, 8) == 0b1000_0000'i8 assert ashr(0b1000_0000'i8, 1) == 0b1100_0000'i8 proc ashr*(x: int8, y: SomeInteger): int8 {.magic: "AshrI", noSideEffect.} proc ashr*(x: int16, y: SomeInteger): int16 {.magic: "AshrI", noSideEffect.} diff --git a/tests/int/tarithm.nim b/tests/int/tarithm.nim index d0943d225d..ff770e54f3 100644 --- a/tests/int/tarithm.nim +++ b/tests/int/tarithm.nim @@ -14,6 +14,7 @@ int32 0 tUnsignedOps OK ''' +targets: "c cpp js" nimout: "tUnsignedOps OK" """ @@ -185,3 +186,67 @@ block tUnsignedOps: testUnsignedOps() static: testUnsignedOps() + +block tshl: + # Signed types + block: + const t0: int8 = 1'i8 shl 8 + const t1: int16 = 1'i16 shl 16 + const t2: int32 = 1'i32 shl 32 + const t3: int64 = 1'i64 shl 64 + doAssert t0 == 1 + doAssert t1 == 1 + doAssert t2 == 1 + doAssert t3 == 1 + + # Unsigned types + block: + const t0: uint8 = 1'u8 shl 8 + const t1: uint16 = 1'u16 shl 16 + const t2: uint32 = 1'u32 shl 32 + const t3: uint64 = 1'u64 shl 64 + doAssert t0 == 1 + doAssert t1 == 1 + doAssert t2 == 1 + doAssert t3 == 1 + +block bitmasking: + + # test semfold (single expression) + doAssert (0x10'i8 shr 2) == (0x10'i8 shr 0b1010_1010) + doAssert (0x10'u8 shr 2) == (0x10'u8 shr 0b0101_1010) + doAssert (0x10'i16 shr 2) == (0x10'i16 shr 0b1011_0010) + doAssert (0x10'u16 shr 2) == (0x10'u16 shr 0b0101_0010) + doAssert (0x10'i32 shr 2) == (0x10'i32 shr 0b1010_0010) + doAssert (0x10'u32 shr 2) == (0x10'u32 shr 0b0110_0010) + doAssert (0x10'i64 shr 2) == (0x10'i64 shr 0b1100_0010) + doAssert (0x10'u64 shr 2) == (0x10'u64 shr 0b0100_0010) + + doAssert (0x10'i8 shl 2) == (0x10'i8 shl 0b1010_1010) + doAssert (0x10'u8 shl 2) == (0x10'u8 shl 0b0101_1010) + doAssert (0x10'i16 shl 2) == (0x10'i16 shl 0b1011_0010) + doAssert (0x10'u16 shl 2) == (0x10'u16 shl 0b0101_0010) + doAssert (0x10'i32 shl 2) == (0x10'i32 shl 0b1010_0010) + doAssert (0x10'u32 shl 2) == (0x10'u32 shl 0b0110_0010) + doAssert (0x10'i64 shl 2) == (0x10'i64 shl 0b1100_0010) + doAssert (0x10'u64 shl 2) == (0x10'u64 shl 0b0100_0010) + + proc testVmAndBackend[T: SomeInteger](a: T, b1, b2: int) {.sideeffect.} = + # the sideeffect pragma marks this proc as having side effects and therefore + # ensures it isn't evaluated at compile time when it should not be. + doAssert((a shr b1) == (a shr b2)) + doAssert((a shl b1) == (a shl b2)) + + proc callTestVmAndBackend() = + testVmAndBackend(0x10'i8, 2, 0b1010_1010) + testVmAndBackend(0x10'u8, 2, 0b0101_1010) + testVmAndBackend(0x10'i16, 2, 0b1011_0010) + testVmAndBackend(0x10'u16, 2, 0b0101_0010) + testVmAndBackend(0x10'i32, 2, 0b1010_0010) + testVmAndBackend(0x10'u32, 2, 0b0110_0010) + testVmAndBackend(0x10'i64, 2, 0b1100_0010) + testVmAndBackend(0x10'u64, 2, 0b0100_0010) + + callTestVmAndBackend() # test at runtime + static: + callTestVmAndBackend() # test at compiletime
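For reference, a minimal Nim sketch of the masked-shift semantics all backends now share. The `maskedShl`/`maskedAshr` helpers are illustration-only names (they are not part of this patch or the standard library); they mirror the `y and (8 * sizeof(T) - 1)` mask that the generated C/C++/JS code and the VM now apply to the right operand:

```nim
# Illustration only: hypothetical helpers mirroring the bitmask that the
# backends now apply to the shift count; not APIs added by this patch.
func maskedShl[T: SomeInteger](x: T; y: int): T =
  let mask = 8 * sizeof(T) - 1    # 7 for 8-bit types, ..., 63 for 64-bit types
  result = x shl (y and mask)     # shift count reduced modulo the bit width

func maskedAshr[T: SomeSignedInt](x: T; y: int): T =
  let mask = 8 * sizeof(T) - 1
  result = ashr(x, y and mask)    # arithmetic shift keeps the sign bit

doAssert maskedShl(1'u8, 8) == 1'u8        # 8 and 7 == 0, so no shift happens
doAssert maskedShl(1'i64, 65) == 2'i64     # 65 and 63 == 1
doAssert maskedAshr(-128'i8, 9) == -64'i8  # 9 and 7 == 1
```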