From e92fdb4a99bf9d27009dd35fdd074ff14facfc03 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Thu, 5 Mar 2020 20:34:30 +0000 Subject: [PATCH] `x if cond else y` and `x when cond else y` expressions --- core/c/c.odin | 6 +- core/fmt/fmt.odin | 10 +-- core/log/file_console_logger.odin | 2 +- core/log/log.odin | 2 +- core/math/linalg/general.odin | 6 +- core/math/linalg/specific.odin | 32 ++++++--- core/math/math.odin | 8 +-- core/odin/ast/ast.odin | 18 +++++ core/odin/parser/parser.odin | 39 ++++++++++- core/odin/tokenizer/token.odin | 6 +- core/reflect/types.odin | 2 +- core/runtime/internal.odin | 10 +-- core/strconv/generic_float.odin | 4 +- core/strings/strings.odin | 2 +- core/time/time.odin | 2 +- src/check_expr.cpp | 109 ++++++++++++++++++++++++++++++ src/check_type.cpp | 20 ++++++ src/ir.cpp | 43 ++++++++++++ src/parser.cpp | 53 +++++++++++++++ src/parser.hpp | 6 +- 20 files changed, 336 insertions(+), 44 deletions(-) diff --git a/core/c/c.odin b/core/c/c.odin index 3ed49c0c2..757abcac9 100644 --- a/core/c/c.odin +++ b/core/c/c.odin @@ -14,8 +14,8 @@ ushort :: b.u16; int :: b.i32; uint :: b.u32; -long :: (ODIN_OS == "windows" || size_of(b.rawptr) == 4) ? b.i32 : b.i64; -ulong :: (ODIN_OS == "windows" || size_of(b.rawptr) == 4) ? b.u32 : b.u64; +long :: b.i32 when (ODIN_OS == "windows" || size_of(b.rawptr) == 4) else b.i64; +ulong :: b.u32 when (ODIN_OS == "windows" || size_of(b.rawptr) == 4) else b.u64; longlong :: b.i64; ulonglong :: b.u64; @@ -32,4 +32,4 @@ ptrdiff_t :: b.int; uintptr_t :: b.uintptr; intptr_t :: b.int; -wchar_t :: (ODIN_OS == "windows") ? b.u16 : b.u32; +wchar_t :: b.u16 when (ODIN_OS == "windows") else b.u32; diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index c061c1c47..626b0797d 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -732,7 +732,7 @@ fmt_float :: proc(fi: ^Info, v: f64, bit_size: int, verb: rune) { } strings.write_string(fi.buf, "0h"); - _fmt_int(fi, u, 16, false, bit_size, verb == 'h' ? __DIGITS_LOWER : __DIGITS_UPPER); + _fmt_int(fi, u, 16, false, bit_size, __DIGITS_LOWER if verb == 'h' else __DIGITS_UPPER); case: @@ -1154,7 +1154,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { is_soa := b.soa_kind != .None; strings.write_string(fi.buf, info.name); - strings.write_byte(fi.buf, is_soa ? '[' : '{'); + strings.write_byte(fi.buf, '[' if is_soa else '{'); hash := fi.hash; defer fi.hash = hash; indent := fi.indent; defer fi.indent -= 1; @@ -1165,7 +1165,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { if hash do strings.write_byte(fi.buf, '\n'); defer { if hash do for in 0.. 0 ? fmt.tprintf(fmt_str, ..args) : fmt.tprint(fmt_str); //NOTE(Hoej): While tprint isn't thread-safe, no logging is. + str := fmt.tprintf(fmt_str, ..args) if len(args) > 0 else fmt.tprint(fmt_str); //NOTE(Hoej): While tprint isn't thread-safe, no logging is. logger.procedure(logger.data, level, str, logger.options, location); } diff --git a/core/math/linalg/general.odin b/core/math/linalg/general.odin index 921d56c80..c7b455781 100644 --- a/core/math/linalg/general.odin +++ b/core/math/linalg/general.odin @@ -54,11 +54,11 @@ normalize :: proc{vector_normalize, quaternion_normalize}; vector_normalize0 :: proc(v: $T/[$N]$E) -> T where IS_NUMERIC(E) { m := length(v); - return m == 0 ? 0 : v/m; + return 0 if m == 0 else v/m; } quaternion_normalize0 :: proc(q: $Q) -> Q where IS_QUATERNION(Q) { m := abs(q); - return m == 0 ? 0 : q/m; + return 0 if m == 0 else q/m; } normalize0 :: proc{vector_normalize0, quaternion_normalize0}; @@ -258,7 +258,7 @@ vector_mix :: proc(x, y, a: $V/[$N]$E) -> V where IS_NUMERIC(E) { vector_step :: proc(edge, x: $V/[$N]$E) -> V where IS_NUMERIC(E) { s: V; for i in 0.. Vector3 { y := abs(v.y); z := abs(v.z); - other: Vector3 = x < y ? (x < z ? {1, 0, 0} : {0, 0, 1}) : (y < z ? {0, 1, 0} : {0, 0, 1}); - + other: Vector3; + if x < y { + if x < z { + other = {1, 0, 0}; + } else { + other = {0, 0, 1}; + } + } else { + if y < z { + other = {0, 1, 0}; + } else { + other = {0, 0, 1}; + } + } return normalize(cross(v, other)); } @@ -124,7 +136,7 @@ vector4_hsl_to_rgb :: proc(h, s, l: Float, a: Float = 1) -> Vector4 { g = l; b = l; } else { - q := l < 0.5 ? l * (1+s) : l+s - l*s; + q := l * (1+s) if l < 0.5 else l+s - l*s; p := 2*l - q; r = hue_to_rgb(p, q, h + 1.0/3.0); g = hue_to_rgb(p, q, h); @@ -147,10 +159,10 @@ vector4_rgb_to_hsl :: proc(col: Vector4) -> Vector4 { if v_max != v_min { d: = v_max - v_min; - s = l > 0.5 ? d / (2.0 - v_max - v_min) : d / (v_max + v_min); + s = d / (2.0 - v_max - v_min) if l > 0.5 else d / (v_max + v_min); switch { case v_max == r: - h = (g - b) / d + (g < b ? 6.0 : 0.0); + h = (g - b) / d + (6.0 if g < b else 0.0); case v_max == g: h = (b - r) / d + 2.0; case v_max == b: @@ -627,9 +639,9 @@ matrix4_inverse :: proc(m: Matrix4) -> Matrix4 { matrix4_minor :: proc(m: Matrix4, c, r: int) -> Float { cut_down: Matrix3; for i in 0..<3 { - col := i < c ? i : i+1; + col := i if i < c else i+1; for j in 0..<3 { - row := j < r ? j : j+1; + row := j if j < r else j+1; cut_down[i][j] = m[col][row]; } } @@ -638,7 +650,7 @@ matrix4_minor :: proc(m: Matrix4, c, r: int) -> Float { matrix4_cofactor :: proc(m: Matrix4, c, r: int) -> Float { sign, minor: Float; - sign = (c + r) % 2 == 0 ? 1 : -1; + sign = 1 if (c + r) % 2 == 0 else -1; minor = matrix4_minor(m, c, r); return sign * minor; } diff --git a/core/math/math.odin b/core/math/math.odin index d5ec9f3ae..cb7231b86 100644 --- a/core/math/math.odin +++ b/core/math/math.odin @@ -115,7 +115,7 @@ unlerp :: proc{unlerp_f32, unlerp_f64}; wrap :: proc(x, y: $T) -> T where intrinsics.type_is_numeric(T), !intrinsics.type_is_array(T) { tmp := mod(x, y); - return tmp < 0 ? wrap + tmp : tmp; + return wrap + tmp if tmp < 0 else tmp; } angle_diff :: proc(a, b: $T) -> T where intrinsics.type_is_numeric(T), !intrinsics.type_is_array(T) { @@ -128,7 +128,7 @@ angle_lerp :: proc(a, b, t: $T) -> T where intrinsics.type_is_numeric(T), !intri } step :: proc(edge, x: $T) -> T where intrinsics.type_is_numeric(T), !intrinsics.type_is_array(T) { - return x < edge ? 0 : 1; + return 0 if x < edge else 1; } smoothstep :: proc(edge0, edge1, x: $T) -> T where intrinsics.type_is_numeric(T), !intrinsics.type_is_array(T) { @@ -246,10 +246,10 @@ trunc_f64 :: proc(x: f64) -> f64 { trunc :: proc{trunc_f32, trunc_f64}; round_f32 :: proc(x: f32) -> f32 { - return x < 0 ? ceil(x - 0.5) : floor(x + 0.5); + return ceil(x - 0.5) if x < 0 else floor(x + 0.5); } round_f64 :: proc(x: f64) -> f64 { - return x < 0 ? ceil(x - 0.5) : floor(x + 0.5); + return ceil(x - 0.5) if x < 0 else floor(x + 0.5); } round :: proc{round_f32, round_f64}; diff --git a/core/odin/ast/ast.odin b/core/odin/ast/ast.odin index 931000a7c..4bd50abd7 100644 --- a/core/odin/ast/ast.odin +++ b/core/odin/ast/ast.odin @@ -200,6 +200,24 @@ Ternary_Expr :: struct { y: ^Expr, } +Ternary_If_Expr :: struct { + using node: Expr, + x: ^Expr, + op1: tokenizer.Token, + cond: ^Expr, + op2: tokenizer.Token, + y: ^Expr, +} + +Ternary_When_Expr :: struct { + using node: Expr, + x: ^Expr, + op1: tokenizer.Token, + cond: ^Expr, + op2: tokenizer.Token, + y: ^Expr, +} + Type_Assertion :: struct { using node: Expr, expr: ^Expr, diff --git a/core/odin/parser/parser.odin b/core/odin/parser/parser.odin index 3f81cb61d..c4e6d8142 100644 --- a/core/odin/parser/parser.odin +++ b/core/odin/parser/parser.odin @@ -1028,7 +1028,7 @@ parse_stmt :: proc(p: ^Parser) -> ^ast.Stmt { if tok.kind != .Fallthrough && p.curr_tok.kind == .Ident { label = parse_ident(p); } - end := label != nil ? label.end : end_pos(tok); + end := label.end if label != nil else end_pos(tok); s := ast.new(ast.Branch_Stmt, tok.pos, end); expect_semicolon(p, s); return s; @@ -1132,6 +1132,8 @@ parse_stmt :: proc(p: ^Parser) -> ^ast.Stmt { token_precedence :: proc(p: ^Parser, kind: tokenizer.Token_Kind) -> int { #partial switch kind { case .Question: + case .If: + case .When: return 1; case .Ellipsis, .Range_Half: if !p.allow_range { @@ -2453,7 +2455,7 @@ parse_literal_value :: proc(p: ^Parser, type: ^ast.Expr) -> ^ast.Comp_Lit { close := expect_token_after(p, .Close_Brace, "compound literal"); - pos := type != nil ? type.pos : open.pos; + pos := type.pos if type != nil else open.pos; lit := ast.new(ast.Comp_Lit, pos, end_pos(close)); lit.type = type; lit.open = open.pos; @@ -2714,6 +2716,13 @@ parse_binary_expr :: proc(p: ^Parser, lhs: bool, prec_in: int) -> ^ast.Expr { if op_prec != prec { break; } + if op.kind == .If || op.kind == .When { + if p.prev_tok.pos.line < op.pos.line { + // NOTE(bill): Check to see if the `if` or `when` is on the same line of the `lhs` condition + break; + } + } + expect_operator(p); if op.kind == .Question { @@ -2728,6 +2737,32 @@ parse_binary_expr :: proc(p: ^Parser, lhs: bool, prec_in: int) -> ^ast.Expr { te.op2 = colon; te.y = y; + expr = te; + } else if op.kind == .If { + x := expr; + cond := parse_expr(p, lhs); + else_tok := expect_token(p, .Else); + y := parse_expr(p, lhs); + te := ast.new(ast.Ternary_If_Expr, expr.pos, end_pos(p.prev_tok)); + te.x = x; + te.op1 = op; + te.cond = cond; + te.op2 = else_tok; + te.y = y; + + expr = te; + } else if op.kind == .When { + x := expr; + cond := parse_expr(p, lhs); + op2 := expect_token(p, .Else); + y := parse_expr(p, lhs); + te := ast.new(ast.Ternary_When_Expr, expr.pos, end_pos(p.prev_tok)); + te.x = x; + te.op1 = op; + te.cond = cond; + te.op2 = else_tok; + te.y = y; + expr = te; } else { right := parse_binary_expr(p, false, prec+1); diff --git a/core/odin/tokenizer/token.odin b/core/odin/tokenizer/token.odin index c5f0247f4..997ca7ac1 100644 --- a/core/odin/tokenizer/token.odin +++ b/core/odin/tokenizer/token.odin @@ -17,13 +17,13 @@ Pos :: struct { pos_compare :: proc(lhs, rhs: Pos) -> int { if lhs.offset != rhs.offset { - return (lhs.offset < rhs.offset) ? -1 : +1; + return -1 if (lhs.offset < rhs.offset) else +1; } if lhs.line != rhs.line { - return (lhs.line < rhs.line) ? -1 : +1; + return -1 if (lhs.line < rhs.line) else +1; } if lhs.column != rhs.column { - return (lhs.column < rhs.column) ? -1 : +1; + return -1 if (lhs.column < rhs.column) else +1; } return strings.compare(lhs.file, rhs.file); } diff --git a/core/reflect/types.odin b/core/reflect/types.odin index 42ab9828a..3ea121bc3 100644 --- a/core/reflect/types.odin +++ b/core/reflect/types.odin @@ -321,7 +321,7 @@ write_type :: proc(buf: ^strings.Builder, ti: ^rt.Type_Info) { case uint: write_string(buf, "uint"); case uintptr: write_string(buf, "uintptr"); case: - write_byte(buf, info.signed ? 'i' : 'u'); + write_byte(buf, 'i' if info.signed else 'u'); write_i64(buf, i64(8*ti.size), 10); switch info.endianness { case .Platform: // Okay diff --git a/core/runtime/internal.odin b/core/runtime/internal.odin index 60359fd6c..68d613c1b 100644 --- a/core/runtime/internal.odin +++ b/core/runtime/internal.odin @@ -191,7 +191,7 @@ print_type :: proc(fd: os.Handle, ti: ^Type_Info) { case uint: os.write_string(fd, "uint"); case uintptr: os.write_string(fd, "uintptr"); case: - os.write_byte(fd, info.signed ? 'i' : 'u'); + os.write_byte(fd, 'i' if info.signed else 'u'); print_u64(fd, u64(8*ti.size)); } case Type_Info_Rune: @@ -421,7 +421,7 @@ memory_compare :: proc "contextless" (a, b: rawptr, n: int) -> int #no_bounds_ch a := (^byte)(x+pos)^; b := (^byte)(y+pos)^; if a ~ b != 0 { - return (int(a) - int(b)) < 0 ? -1 : +1; + return -1 if (int(a) - int(b)) < 0 else +1; } } } @@ -431,7 +431,7 @@ memory_compare :: proc "contextless" (a, b: rawptr, n: int) -> int #no_bounds_ch a := (^byte)(x+offset)^; b := (^byte)(y+offset)^; if a ~ b != 0 { - return (int(a) - int(b)) < 0 ? -1 : +1; + return -1 if (int(a) - int(b)) < 0 else +1; } } @@ -456,7 +456,7 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_ for pos := curr_block*SU; pos < n; pos += 1 { a := (^byte)(x+pos)^; if a ~ 0 != 0 { - return int(a) < 0 ? -1 : +1; + return -1 if int(a) < 0 else +1; } } } @@ -465,7 +465,7 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_ for /**/; offset < n; offset += 1 { a := (^byte)(x+offset)^; if a ~ 0 != 0 { - return int(a) < 0 ? -1 : +1; + return -1 if int(a) < 0 else +1; } } diff --git a/core/strconv/generic_float.odin b/core/strconv/generic_float.odin index 622dcd044..ad0c51880 100644 --- a/core/strconv/generic_float.odin +++ b/core/strconv/generic_float.odin @@ -110,7 +110,7 @@ format_digits :: proc(buf: []byte, shortest: bool, neg: bool, digs: Decimal_Slic switch fmt { case 'f', 'F': - add_bytes(&b, neg ? '-' : '+'); + add_bytes(&b, '-' if neg else '+'); // integer, padded with zeros when needed if digs.decimal_point > 0 { @@ -138,7 +138,7 @@ format_digits :: proc(buf: []byte, shortest: bool, neg: bool, digs: Decimal_Slic return to_bytes(b); case 'e', 'E': - add_bytes(&b, neg ? '-' : '+'); + add_bytes(&b, '-' if neg else '+'); ch := byte('0'); if digs.count != 0 { diff --git a/core/strings/strings.odin b/core/strings/strings.odin index a477b9e13..fd3808684 100644 --- a/core/strings/strings.odin +++ b/core/strings/strings.odin @@ -325,7 +325,7 @@ last_index :: proc(s, substr: string) -> int { case n == 1: return last_index_byte(s, substr[0]); case n == len(s): - return substr == s ? 0 : -1; + return 0 if substr == s else -1; case n > len(s): return -1; } diff --git a/core/time/time.odin b/core/time/time.odin index 51c77a468..93fa7f8b3 100644 --- a/core/time/time.odin +++ b/core/time/time.odin @@ -92,7 +92,7 @@ duration_round :: proc(d, m: Duration) -> Duration { return MAX_DURATION; } duration_truncate :: proc(d, m: Duration) -> Duration { - return m <= 0 ? d : d - d%m; + return d if m <= 0 else d - d%m; } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 1a01eef31..d73aebbe4 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -7762,6 +7762,99 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case_end; + case_ast_node(te, TernaryIfExpr, node); + Operand cond = {Addressing_Invalid}; + check_expr(c, &cond, te->cond); + node->viral_state_flags |= te->cond->viral_state_flags; + + if (cond.mode != Addressing_Invalid && !is_type_boolean(cond.type)) { + error(te->cond, "Non-boolean condition in ternary if expression"); + } + + Operand x = {Addressing_Invalid}; + Operand y = {Addressing_Invalid}; + check_expr_or_type(c, &x, te->x, type_hint); + node->viral_state_flags |= te->x->viral_state_flags; + + if (te->y != nullptr) { + check_expr_or_type(c, &y, te->y, type_hint); + node->viral_state_flags |= te->y->viral_state_flags; + } else { + error(node, "A ternary expression must have an else clause"); + return kind; + } + + if (x.type == nullptr || x.type == t_invalid || + y.type == nullptr || y.type == t_invalid) { + return kind; + } + + convert_to_typed(c, &x, y.type); + if (x.mode == Addressing_Invalid) { + return kind; + } + convert_to_typed(c, &y, x.type); + if (y.mode == Addressing_Invalid) { + x.mode = Addressing_Invalid; + return kind; + } + + if (!ternary_compare_types(x.type, y.type)) { + gbString its = type_to_string(x.type); + gbString ets = type_to_string(y.type); + error(node, "Mismatched types in ternary if expression, %s vs %s", its, ets); + gb_string_free(ets); + gb_string_free(its); + return kind; + } + + Type *type = x.type; + if (is_type_untyped_nil(type) || is_type_untyped_undef(type)) { + type = y.type; + } + + o->type = type; + o->mode = Addressing_Value; + + // if (cond.mode == Addressing_Constant && is_type_boolean(cond.type) && + // x.mode == Addressing_Constant && + // y.mode == Addressing_Constant) { + + // o->mode = Addressing_Constant; + + // if (cond.value.value_bool) { + // o->value = x.value; + // } else { + // o->value = y.value; + // } + // } + + case_end; + + case_ast_node(te, TernaryWhenExpr, node); + Operand cond = {}; + check_expr(c, &cond, te->cond); + node->viral_state_flags |= te->cond->viral_state_flags; + + if (cond.mode != Addressing_Constant || !is_type_boolean(cond.type)) { + error(te->cond, "Expected a constant boolean condition in ternary when expression"); + return kind; + } + + if (cond.value.value_bool) { + check_expr_or_type(c, o, te->x, type_hint); + node->viral_state_flags |= te->x->viral_state_flags; + } else { + if (te->y != nullptr) { + check_expr_or_type(c, o, te->y, type_hint); + node->viral_state_flags |= te->y->viral_state_flags; + } else { + error(node, "A ternary when expression must have an else clause"); + return kind; + } + } + case_end; + case_ast_node(cl, CompoundLit, node); Type *type = type_hint; bool is_to_be_determined_array_count = false; @@ -9333,6 +9426,22 @@ gbString write_expr_to_string(gbString str, Ast *node) { str = write_expr_to_string(str, te->y); case_end; + case_ast_node(te, TernaryIfExpr, node); + str = write_expr_to_string(str, te->x); + str = gb_string_appendc(str, " if "); + str = write_expr_to_string(str, te->cond); + str = gb_string_appendc(str, " else "); + str = write_expr_to_string(str, te->y); + case_end; + + case_ast_node(te, TernaryWhenExpr, node); + str = write_expr_to_string(str, te->x); + str = gb_string_appendc(str, " when "); + str = write_expr_to_string(str, te->cond); + str = gb_string_appendc(str, " else "); + str = write_expr_to_string(str, te->y); + case_end; + case_ast_node(pe, ParenExpr, node); str = gb_string_append_rune(str, '('); diff --git a/src/check_type.cpp b/src/check_type.cpp index 6194951c9..010b31f03 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -3454,6 +3454,26 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t return true; } case_end; + + case_ast_node(te, TernaryIfExpr, e); + Operand o = {}; + check_expr_or_type(ctx, &o, e); + if (o.mode == Addressing_Type) { + *type = o.type; + set_base_type(named_type, *type); + return true; + } + case_end; + + case_ast_node(te, TernaryWhenExpr, e); + Operand o = {}; + check_expr_or_type(ctx, &o, e); + if (o.mode == Addressing_Type) { + *type = o.type; + set_base_type(named_type, *type); + return true; + } + case_end; } *type = t_invalid; diff --git a/src/ir.cpp b/src/ir.cpp index 6a56eb387..ece960e2b 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -7189,6 +7189,49 @@ irValue *ir_build_expr_internal(irProcedure *proc, Ast *expr) { return ir_emit(proc, ir_instr_phi(proc, edges, type)); case_end; + case_ast_node(te, TernaryIfExpr, expr); + ir_emit_comment(proc, str_lit("TernaryIfExpr")); + + auto edges = array_make(ir_allocator(), 0, 2); + + GB_ASSERT(te->y != nullptr); + irBlock *then = ir_new_block(proc, nullptr, "if.then"); + irBlock *done = ir_new_block(proc, nullptr, "if.done"); // NOTE(bill): Append later + irBlock *else_ = ir_new_block(proc, nullptr, "if.else"); + + irValue *cond = ir_build_cond(proc, te->cond, then, else_); + ir_start_block(proc, then); + + Type *type = type_of_expr(expr); + + ir_open_scope(proc); + array_add(&edges, ir_emit_conv(proc, ir_build_expr(proc, te->x), type)); + ir_close_scope(proc, irDeferExit_Default, nullptr); + + ir_emit_jump(proc, done); + ir_start_block(proc, else_); + + ir_open_scope(proc); + array_add(&edges, ir_emit_conv(proc, ir_build_expr(proc, te->y), type)); + ir_close_scope(proc, irDeferExit_Default, nullptr); + + ir_emit_jump(proc, done); + ir_start_block(proc, done); + + return ir_emit(proc, ir_instr_phi(proc, edges, type)); + case_end; + + case_ast_node(te, TernaryWhenExpr, expr); + TypeAndValue tav = type_and_value_of_expr(te->cond); + GB_ASSERT(tav.mode == Addressing_Constant); + GB_ASSERT(tav.value.kind == ExactValue_Bool); + if (tav.value.value_bool) { + return ir_build_expr(proc, te->x); + } else { + return ir_build_expr(proc, te->y); + } + case_end; + case_ast_node(ta, TypeAssertion, expr); TokenPos pos = ast_token(expr).pos; Type *type = tv.type; diff --git a/src/parser.cpp b/src/parser.cpp index f89b5676b..83ae9743f 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -35,6 +35,8 @@ Token ast_token(Ast *node) { case Ast_FieldValue: return node->FieldValue.eq; case Ast_DerefExpr: return node->DerefExpr.op; case Ast_TernaryExpr: return ast_token(node->TernaryExpr.cond); + case Ast_TernaryIfExpr: return ast_token(node->TernaryIfExpr.x); + case Ast_TernaryWhenExpr: return ast_token(node->TernaryWhenExpr.x); case Ast_TypeAssertion: return ast_token(node->TypeAssertion.expr); case Ast_TypeCast: return node->TypeCast.token; case Ast_AutoCast: return node->AutoCast.token; @@ -198,6 +200,16 @@ Ast *clone_ast(Ast *node) { n->TernaryExpr.x = clone_ast(n->TernaryExpr.x); n->TernaryExpr.y = clone_ast(n->TernaryExpr.y); break; + case Ast_TernaryIfExpr: + n->TernaryIfExpr.x = clone_ast(n->TernaryIfExpr.x); + n->TernaryIfExpr.cond = clone_ast(n->TernaryIfExpr.cond); + n->TernaryIfExpr.y = clone_ast(n->TernaryIfExpr.y); + break; + case Ast_TernaryWhenExpr: + n->TernaryWhenExpr.x = clone_ast(n->TernaryWhenExpr.x); + n->TernaryWhenExpr.cond = clone_ast(n->TernaryWhenExpr.cond); + n->TernaryWhenExpr.y = clone_ast(n->TernaryWhenExpr.y); + break; case Ast_TypeAssertion: n->TypeAssertion.expr = clone_ast(n->TypeAssertion.expr); n->TypeAssertion.type = clone_ast(n->TypeAssertion.type); @@ -638,6 +650,21 @@ Ast *ast_ternary_expr(AstFile *f, Ast *cond, Ast *x, Ast *y) { result->TernaryExpr.y = y; return result; } +Ast *ast_ternary_if_expr(AstFile *f, Ast *x, Ast *cond, Ast *y) { + Ast *result = alloc_ast_node(f, Ast_TernaryIfExpr); + result->TernaryIfExpr.x = x; + result->TernaryIfExpr.cond = cond; + result->TernaryIfExpr.y = y; + return result; +} +Ast *ast_ternary_when_expr(AstFile *f, Ast *x, Ast *cond, Ast *y) { + Ast *result = alloc_ast_node(f, Ast_TernaryWhenExpr); + result->TernaryWhenExpr.x = x; + result->TernaryWhenExpr.cond = cond; + result->TernaryWhenExpr.y = y; + return result; +} + Ast *ast_type_assertion(AstFile *f, Ast *expr, Token dot, Ast *type) { Ast *result = alloc_ast_node(f, Ast_TypeAssertion); result->TypeAssertion.expr = expr; @@ -1199,6 +1226,8 @@ Token expect_operator(AstFile *f) { Token prev = f->curr_token; if ((prev.kind == Token_in || prev.kind == Token_not_in) && (f->expr_level >= 0 || f->allow_in_expr)) { // okay + } else if (prev.kind == Token_if || prev.kind == Token_when) { + // okay } else if (!gb_is_between(prev.kind, Token__OperatorBegin+1, Token__OperatorEnd-1)) { syntax_error(f->curr_token, "Expected an operator, got '%.*s'", LIT(token_strings[prev.kind])); @@ -2512,6 +2541,8 @@ bool is_ast_range(Ast *expr) { i32 token_precedence(AstFile *f, TokenKind t) { switch (t) { case Token_Question: + case Token_if: + case Token_when: return 1; case Token_Ellipsis: case Token_RangeHalf: @@ -2565,6 +2596,14 @@ Ast *parse_binary_expr(AstFile *f, bool lhs, i32 prec_in) { // NOTE(bill): This will also catch operators that are not valid "binary" operators break; } + if (op.kind == Token_if || op.kind == Token_when) { + Token prev = f->prev_token; + if (prev.pos.line < op.pos.line) { + // NOTE(bill): Check to see if the `if` or `when` is on the same line of the `lhs` condition + break; + } + } + expect_operator(f); // NOTE(bill): error checks too if (op.kind == Token_Question) { @@ -2574,6 +2613,20 @@ Ast *parse_binary_expr(AstFile *f, bool lhs, i32 prec_in) { Token token_c = expect_token(f, Token_Colon); Ast *y = parse_expr(f, lhs); expr = ast_ternary_expr(f, cond, x, y); + } else if (op.kind == Token_if) { + Ast *x = expr; + // Token_if + Ast *cond = parse_expr(f, lhs); + Token tok_else = expect_token(f, Token_else); + Ast *y = parse_expr(f, lhs); + expr = ast_ternary_if_expr(f, x, cond, y); + } else if (op.kind == Token_when) { + Ast *x = expr; + // Token_when + Ast *cond = parse_expr(f, lhs); + Token tok_else = expect_token(f, Token_else); + Ast *y = parse_expr(f, lhs); + expr = ast_ternary_when_expr(f, x, cond, y); } else { Ast *right = parse_binary_expr(f, false, prec+1); if (right == nullptr) { diff --git a/src/parser.hpp b/src/parser.hpp index 6426cc96b..3de848ca6 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -282,8 +282,10 @@ AST_KIND(_ExprBegin, "", bool) \ Token ellipsis; \ ProcInlining inlining; \ }) \ - AST_KIND(FieldValue, "field value", struct { Token eq; Ast *field, *value; }) \ - AST_KIND(TernaryExpr, "ternary expression", struct { Ast *cond, *x, *y; }) \ + AST_KIND(FieldValue, "field value", struct { Token eq; Ast *field, *value; }) \ + AST_KIND(TernaryExpr, "ternary expression", struct { Ast *cond, *x, *y; }) \ + AST_KIND(TernaryIfExpr, "ternary if expression", struct { Ast *x, *cond, *y; }) \ + AST_KIND(TernaryWhenExpr, "ternary when expression", struct { Ast *x, *cond, *y; }) \ AST_KIND(TypeAssertion, "type assertion", struct { Ast *expr; Token dot; Ast *type; }) \ AST_KIND(TypeCast, "type cast", struct { Token token; Ast *type, *expr; }) \ AST_KIND(AutoCast, "auto_cast", struct { Token token; Ast *expr; }) \