From 7f0e07492fa523fc5779fdedb653005a1c58a13d Mon Sep 17 00:00:00 2001 From: metagn Date: Mon, 21 Apr 2025 10:01:44 +0300 Subject: [PATCH] generally disallow recursive structural types, check proc param types (#24893) fixes #5631, fixes #8938, fixes #18855, fixes #19271, fixes #23885, fixes #24877 `isTupleRecursive`, previously only called to give an error for illegal recursions for: * tuple fields * types declared in type sections * explicitly instantiated generic types did not check for recursions in proc types. It now does, meaning proc types now need a nominal type layer to recurse over themselves. It is renamed to `isRecursiveStructuralType` to better reflect what it does, it is different from a recursive type that cannot exist due to a lack of pointer indirection which is possible for nominal types. It is now also called to check the param/return types of procs, similar to how tuple field types are checked. Pointer indirection checks are not needed since procs are pointers. I wondered if this would lead to a slowdown in the compiler but since it only skips structural types it shouldn't take too many iterations, not to mention only proc types are newly considered and aren't that common. But maybe something in the implementation could be inefficient, like the cycle detector using an IntSet. Note: The name `isRecursiveStructuralType` is not exactly correct because it still checks for `distinct` types. If it didn't, then the compiler would accept this: ```nim type A = distinct B B = ref A ``` But this breaks when attempting to write `var x: A`. However this is not the case for: ```nim type A = object x: B B = ref A ``` So a better description would be "types that are structural on the backend". A future step to deal with #14015 and #23224 might be to check the arguments of `tyGenericInst` as well but I don't know if this makes perfect sense. --- compiler/seminst.nim | 4 +++ compiler/semtypes.nim | 10 ++++-- compiler/semtypinst.nim | 2 +- compiler/types.nim | 23 +++++++++---- tests/errmsgs/trecursiveproctype1.nim | 10 ++++++ tests/errmsgs/trecursiveproctype2.nim | 18 ++++++++++ tests/errmsgs/trecursiveproctype3.nim | 9 +++++ tests/errmsgs/trecursiveproctype4.nim | 10 ++++++ tests/errmsgs/trecursiveproctype5.nim | 49 +++++++++++++++++++++++++++ tests/errmsgs/trecursiveproctype6.nim | 10 ++++++ 10 files changed, 135 insertions(+), 10 deletions(-) create mode 100644 tests/errmsgs/trecursiveproctype1.nim create mode 100644 tests/errmsgs/trecursiveproctype2.nim create mode 100644 tests/errmsgs/trecursiveproctype3.nim create mode 100644 tests/errmsgs/trecursiveproctype4.nim create mode 100644 tests/errmsgs/trecursiveproctype5.nim create mode 100644 tests/errmsgs/trecursiveproctype6.nim diff --git a/compiler/seminst.nim b/compiler/seminst.nim index ab00810f92..c23e3f80d8 100644 --- a/compiler/seminst.nim +++ b/compiler/seminst.nim @@ -308,6 +308,8 @@ proc instantiateProcType(c: PContext, pt: LayeredIdTable, param.typ = result[i] result.n[i] = newSymNode(param) + if isRecursiveStructuralType(result[i]): + localError(c.config, originalParams[i].sym.info, "illegal recursion in type '" & typeToString(result[i]) & "'") propagateToOwner(result, result[i]) addDecl(c, param) @@ -318,6 +320,8 @@ proc instantiateProcType(c: PContext, pt: LayeredIdTable, cl.isReturnType = false result.n[0] = originalParams[0].copyTree if result[0] != nil: + if isRecursiveStructuralType(result[0]): + localError(c.config, originalParams[0].info, "illegal recursion in type '" & typeToString(result[0]) & "'") propagateToOwner(result, result[0]) eraseVoidParams(result) diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim index 41189fc7f8..9ebc930079 100644 --- a/compiler/semtypes.nim +++ b/compiler/semtypes.nim @@ -576,7 +576,7 @@ proc semTuple(c: PContext, n: PNode, prev: PType): PType = styleCheckDef(c, a[j].info, field) onDef(field.info, field) if result.n.len == 0: result.n = nil - if isTupleRecursive(result): + if isRecursiveStructuralType(result): localError(c.config, n.info, errIllegalRecursionInTypeX % typeToString(result)) proc semIdentVis(c: PContext, kind: TSymKind, n: PNode, @@ -1500,6 +1500,8 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode, if isType: localError(c.config, a.info, "':' expected") if kind in {skTemplate, skMacro}: typ = newTypeS(tyUntyped, c) + elif isRecursiveStructuralType(typ): + localError(c.config, a[^2].info, errIllegalRecursionInTypeX % typeToString(typ)) elif skipTypes(typ, {tyGenericInst, tyAlias, tySink}).kind == tyVoid: continue @@ -1563,7 +1565,9 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode, if r != nil: # turn explicit 'void' return type into 'nil' because the rest of the # compiler only checks for 'nil': - if skipTypes(r, {tyGenericInst, tyAlias, tySink}).kind != tyVoid: + if isRecursiveStructuralType(r): + localError(c.config, n.info, errIllegalRecursionInTypeX % typeToString(r)) + elif skipTypes(r, {tyGenericInst, tyAlias, tySink}).kind != tyVoid: if kind notin {skMacro, skTemplate} and r.kind in {tyTyped, tyUntyped}: localError(c.config, n[0].info, "return type '" & typeToString(r) & "' is only valid for macros and templates") @@ -1751,7 +1755,7 @@ proc semGeneric(c: PContext, n: PNode, s: PSym, prev: PType): PType = # special check for generic object with # generic/partial specialized parent let tx = result.skipTypes(abstractPtrs, 50) - if tx.isNil or isTupleRecursive(tx): + if tx.isNil or isRecursiveStructuralType(tx): localError(c.config, n.info, "illegal recursion in type '$1'" % typeToString(result[0])) return errorType(c) if tx != result and tx.kind == tyObject: diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim index 4637ea4046..daee9ba4fc 100644 --- a/compiler/semtypinst.nim +++ b/compiler/semtypinst.nim @@ -28,7 +28,7 @@ proc checkConstructedType*(conf: ConfigRef; info: TLineInfo, typ: PType) = if t.kind in tyTypeClasses: discard elif t.kind in {tyVar, tyLent} and t.elementType.kind in {tyVar, tyLent}: localError(conf, info, "type 'var var' is not allowed") - elif computeSize(conf, t) == szIllegalRecursion or isTupleRecursive(t): + elif computeSize(conf, t) == szIllegalRecursion or isRecursiveStructuralType(t): localError(conf, info, "illegal recursion in type '" & typeToString(t) & "'") proc searchInstTypes*(g: ModuleGraph; key: PType): PType = diff --git a/compiler/types.nim b/compiler/types.nim index 6f098b1c3e..8744f173ce 100644 --- a/compiler/types.nim +++ b/compiler/types.nim @@ -1897,7 +1897,7 @@ proc typeMismatch*(conf: ConfigRef; info: TLineInfo, formal, actual: PType, n: P processPragmaAndCallConvMismatch(msg, a, b, conf) localError(conf, info, msg) -proc isTupleRecursive(t: PType, cycleDetector: var IntSet): bool = +proc isRecursiveStructuralType(t: PType, cycleDetector: var IntSet): bool = if t == nil: return false if cycleDetector.containsOrIncl(t.id): @@ -1908,19 +1908,30 @@ proc isTupleRecursive(t: PType, cycleDetector: var IntSet): bool = var cycleDetectorCopy: IntSet for a in t.kids: cycleDetectorCopy = cycleDetector - if isTupleRecursive(a, cycleDetectorCopy): + if isRecursiveStructuralType(a, cycleDetectorCopy): + return true + of tyProc: + result = false + var cycleDetectorCopy: IntSet + if t.returnType != nil: + cycleDetectorCopy = cycleDetector + if isRecursiveStructuralType(t.returnType, cycleDetectorCopy): + return true + for _, a in t.paramTypes: + cycleDetectorCopy = cycleDetector + if isRecursiveStructuralType(a, cycleDetectorCopy): return true of tyRef, tyPtr, tyVar, tyLent, tySink, tyArray, tyUncheckedArray, tySequence, tyDistinct: - return isTupleRecursive(t.elementType, cycleDetector) + return isRecursiveStructuralType(t.elementType, cycleDetector) of tyAlias, tyGenericInst: - return isTupleRecursive(t.skipModifier, cycleDetector) + return isRecursiveStructuralType(t.skipModifier, cycleDetector) else: return false -proc isTupleRecursive*(t: PType): bool = +proc isRecursiveStructuralType*(t: PType): bool = var cycleDetector = initIntSet() - isTupleRecursive(t, cycleDetector) + isRecursiveStructuralType(t, cycleDetector) proc isException*(t: PType): bool = # check if `y` is object type and it inherits from Exception diff --git a/tests/errmsgs/trecursiveproctype1.nim b/tests/errmsgs/trecursiveproctype1.nim new file mode 100644 index 0000000000..0bd5b8e0dc --- /dev/null +++ b/tests/errmsgs/trecursiveproctype1.nim @@ -0,0 +1,10 @@ +discard """ + errormsg: "illegal recursion in type 'Behavior'" + line: 10 +""" + +# issue #5631 + +type + Behavior = proc(): Effect + Effect = proc(behavior: Behavior): Behavior diff --git a/tests/errmsgs/trecursiveproctype2.nim b/tests/errmsgs/trecursiveproctype2.nim new file mode 100644 index 0000000000..60306278dd --- /dev/null +++ b/tests/errmsgs/trecursiveproctype2.nim @@ -0,0 +1,18 @@ +discard """ + errormsg: "illegal recursion in type 'B'" + line: 9 +""" + +# issue #8938 + +type + A = proc(acc, x: int, y: B): int + B = proc(acc, x: int, y: A): int + +proc fact(n: int): int = + proc g(acc, a: int, b: proc(acc, a: int, b: A): int): A = + if a == 0: + acc + else: + b(a * acc, a - 1, b) + g(1, n, g) diff --git a/tests/errmsgs/trecursiveproctype3.nim b/tests/errmsgs/trecursiveproctype3.nim new file mode 100644 index 0000000000..288bb27909 --- /dev/null +++ b/tests/errmsgs/trecursiveproctype3.nim @@ -0,0 +1,9 @@ +discard """ + errormsg: "illegal recursion in type 'ptr MyFunc'" + line: 9 +""" + +# issue #19271 + +type + MyFunc = proc(f: ptr MyFunc) diff --git a/tests/errmsgs/trecursiveproctype4.nim b/tests/errmsgs/trecursiveproctype4.nim new file mode 100644 index 0000000000..860ae313dd --- /dev/null +++ b/tests/errmsgs/trecursiveproctype4.nim @@ -0,0 +1,10 @@ +discard """ + errormsg: "illegal recursion in type 'BB'" + line: 9 +""" + +# issue #23885 + +type + EventHandler = proc(target: BB) + BB = (EventHandler,) diff --git a/tests/errmsgs/trecursiveproctype5.nim b/tests/errmsgs/trecursiveproctype5.nim new file mode 100644 index 0000000000..58237959b2 --- /dev/null +++ b/tests/errmsgs/trecursiveproctype5.nim @@ -0,0 +1,49 @@ +discard """ + errormsg: "illegal recursion in type 'seq[Shape[system.float32]]" + line: 20 +""" + +# issue #24877 + +type + ValT = float32|float64 + Square[T: ValT] = object + inner: seq[Shape[T]] + Circle[T: ValT] = object + inner: seq[Shape[T]] + + InnerShapesProc[T: ValT] = proc(): seq[Shape[T]] + Shape[T: ValT] = tuple[ + innerShapes: InnerShapesProc[T], + ] + +func newSquare[T: ValT](inner: seq[Shape[T]] = @[]): Square[T] = + Square[T](inner: inner) + +proc innerShapes[T: ValT](sq: Square[T]): seq[Shape[T]] = sq.inner +proc iInnerShapes[T: ValT](sq: Square[T]): InnerShapesProc[T] = + proc(): seq[Shape[T]] = sq.innerShapes() + +func toShape[T: ValT](sq: Square[T]): Shape[T] = + (innerShapes: sq.iInnerShapes()) + +func newCircle[T: ValT](inner: seq[Shape[T]] = @[]): Circle[T] = + Circle[T](inner: inner) + +proc innerShapes[T: ValT](c: Circle[T]): seq[Shape[T]] = c.inner +proc iInnerShapes[T: ValT](c: Circle[T]): InnerShapesProc[T] = + proc(): seq[Shape[T]] = c.innerShapes() + +func toShape[T: ValT](c: Circle[T]): Shape[T] = + (innerShapes: c.iInnerShapes()) + +const + sq1 = newSquare[float32]() + sq2 = newSquare[float32]() + sq3 = newSquare[float64]() + c1 = newCircle[float64](@[sq3]) + c2 = newCircle[float32](@[sq1, sq2]) + +let + shapes32 = @[sq1.toShape, sq2.toShape, c2.toShape] + shapes64 = @[sq3.toShape, c1.toShape] diff --git a/tests/errmsgs/trecursiveproctype6.nim b/tests/errmsgs/trecursiveproctype6.nim new file mode 100644 index 0000000000..2e1f2fa78e --- /dev/null +++ b/tests/errmsgs/trecursiveproctype6.nim @@ -0,0 +1,10 @@ +discard """ + errormsg: "illegal recursion in type 'Test" + line: 9 +""" + +# issue #18855 + +type + TestProc = proc(a: Test) + Test = Test