diff --git a/spec/compiler/formatter/formatter_spec.cr b/spec/compiler/formatter/formatter_spec.cr index fdd2b9012996..2e8366079a66 100644 --- a/spec/compiler/formatter/formatter_spec.cr +++ b/spec/compiler/formatter/formatter_spec.cr @@ -2094,10 +2094,10 @@ describe Crystal::Formatter do assert_format "G_(A, (B -> R))" assert_format "G_(A, ->)" assert_format "G_(A, (->))" - pending { assert_format "G_(A, () ->)" } # #16741 + assert_format "G_(A, () ->)" assert_format "G_(A, -> R)" assert_format "G_(-> R)" - pending { assert_format "G_(() -> R)" } # #16741 + assert_format "G_(() -> R)" assert_format "G_(A -> R | S)" assert_format "G_((A -> R | S))" @@ -2108,6 +2108,15 @@ describe Crystal::Formatter do assert_format "G_(A | B -> R)" assert_format "G_((A | B) -> C)" assert_format "G_(A | (B -> C))" + + assert_format "G_((A*) -> R)" + + assert_format "G_(((A) ->) ->)" + assert_format "G_((A ->) ->)" + assert_format "G_(A -> ->)" + assert_format "G_((A -> ->))" + assert_format "G_(((A) ->, B) ->)" + assert_format "G_((A) | B ->)" end assert_format "foo &.bar.is_a?(Baz)" diff --git a/spec/compiler/parser/parser_spec.cr b/spec/compiler/parser/parser_spec.cr index 408bfda6668f..0bca8a7143c9 100644 --- a/spec/compiler/parser/parser_spec.cr +++ b/spec/compiler/parser/parser_spec.cr @@ -491,14 +491,19 @@ module Crystal it_parses "def foo(var : self.class); end", Def.new("foo", [Arg.new("var", restriction: Metaclass.new(Self.new))]) it_parses "def foo(var : self*); end", Def.new("foo", [Arg.new("var", restriction: Self.new.pointer_of)]) it_parses "def foo(var : Int | Double); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.new(["Int".path, "Double".path] of ASTNode))]) + it_parses "def foo(var : (Int | Double)); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.parens(Crystal::Union.new(["Int".path, "Double".path] of ASTNode)))]) it_parses "def foo(var : Int?); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.new(["Int".path, "Nil".path(true)] of ASTNode))]) it_parses "def foo(var : Int*); end", Def.new("foo", [Arg.new("var", restriction: "Int".path.pointer_of)]) it_parses "def foo(var : Int**); end", Def.new("foo", [Arg.new("var", restriction: "Int".path.pointer_of.pointer_of)]) it_parses "def foo(var : Int -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path] of ASTNode, "Double".path))]) it_parses "def foo(var : Int, Float -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))]) - it_parses "def foo(var : (Int, Float -> Double)); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))]) + it_parses "def foo(var : (Int, Float -> Double)); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.parens(ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path)))]) it_parses "def foo(var : (Int, Float) -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))]) it_parses "def foo(var : () -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new([] of ASTNode, "Double".path))]) + it_parses "x : (A -> B)", TypeDeclaration.new("x".var, declared_type: Crystal::Union.parens(ProcNotation.new(["A".path] of ASTNode, "B".path))) + it_parses "x : (A -> B).class", TypeDeclaration.new("x".var, declared_type: Metaclass.new(Crystal::Union.parens(ProcNotation.new(["A".path] of ASTNode, "B".path)))) + it_parses "alias T = (A*) -> R", Alias.new("T".path, ProcNotation.new([Generic.new(Path.global("Pointer"), ["A".path] of ASTNode, suffix: :asterisk)] of ASTNode, "R".path)) + it_parses "alias T = (A -> ) ->", Alias.new("T".path, ProcNotation.new([ProcNotation.new(["A".path] of ASTNode)] of ASTNode)) it_parses "def foo(var : Char[256]); end", Def.new("foo", [Arg.new("var", restriction: "Char".static_array_of(256))]) it_parses "def foo(var : Char[N]); end", Def.new("foo", [Arg.new("var", restriction: "Char".static_array_of("N".path))]) it_parses "def foo(var : Int32 = 1); end", Def.new("foo", [Arg.new("var", 1.int32, "Int32".path)]) @@ -1254,7 +1259,8 @@ module Crystal it_parses "lib LibC\nfun getchar\nend", LibDef.new("LibC".path, [FunDef.new("getchar")] of ASTNode) it_parses "lib LibC\nfun getchar(...)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", varargs: true)] of ASTNode) it_parses "lib LibC\nfun getchar : Int\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: "Int".path)] of ASTNode) - it_parses "lib LibC\nfun getchar : (->)?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([ProcNotation.new, "Nil".path(true)] of ASTNode))] of ASTNode) + it_parses "lib LibC\nfun getchar : (->)?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([Crystal::Union.parens(ProcNotation.new), "Nil".path(true)] of ASTNode))] of ASTNode) + it_parses "lib LibC\nfun getchar : ((->))?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([Crystal::Union.parens(Crystal::Union.parens(ProcNotation.new)), "Nil".path(true)] of ASTNode))] of ASTNode) it_parses "lib LibC\nfun getchar(Int, Float)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("", restriction: "Int".path), Arg.new("", restriction: "Float".path)])] of ASTNode) it_parses "lib LibC\nfun getchar(a : Int, b : Float)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("a", restriction: "Int".path), Arg.new("b", restriction: "Float".path)])] of ASTNode) it_parses "lib LibC\nfun getchar(a : Int)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("a", restriction: "Int".path)])] of ASTNode) @@ -3932,6 +3938,24 @@ module Crystal node_source(source, node).should eq("::Foo") end + it "sets correct location of proc notation inputs" do + source = "alias T = ((A), (B)) -> R" + proc_notation = Parser.parse(source).as(Alias).value.should be_a(ProcNotation) + inputs = proc_notation.inputs.should be_a(Array(ASTNode)) + path = inputs.first.should(be_a(Union)).types.first.should be_a(Path) + node_source(source, path).should eq "A" + node_source(source, inputs[1]).should eq "(B)" + end + + it "sets correct location of proc notation inputs" do + source = "alias T = (A) -> R" + proc_notation = Parser.parse(source).as(Alias).value.should be_a(ProcNotation) + inputs = proc_notation.inputs.should be_a(Array(ASTNode)) + path = inputs.first.should be_a(Path) + node_source(source, path).should eq "A" + node_source(source, proc_notation).should eq "(A) -> R" + end + it "sets args_in_brackets to false for `a.b`" do parser = Parser.new("a.b") node = parser.parse.as(Call) diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index d5db72d0a9de..0edc2f88f264 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -123,9 +123,9 @@ describe "ASTNode#to_s" do # 14216 expect_to_s "def foo(x, **args, &block : _ -> _)\nend" - expect_to_s "def foo(x, **args, &block : (_ -> _))\nend", "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(x, **args, &block : (_ -> _))\nend" expect_to_s "def foo(& : ->)\nend" - expect_to_s "def foo(& : (->))\nend", "def foo(& : ->)\nend" + expect_to_s "def foo(& : (->))\nend" expect_to_s "def foo(x : (T -> U) -> V, *args : (T -> U) -> V, y : (T -> U) -> V, **opts : (T -> U) -> V, & : (T -> U) -> V) : ((T -> U) -> V)\nend" expect_to_s "foo(x : (T -> U) -> V, W)" expect_to_s "foo[x : (T -> U) -> V, W]" diff --git a/src/compiler/crystal/semantic/normalizer.cr b/src/compiler/crystal/semantic/normalizer.cr index d9420ed23484..a0b712af53e5 100644 --- a/src/compiler/crystal/semantic/normalizer.cr +++ b/src/compiler/crystal/semantic/normalizer.cr @@ -510,5 +510,14 @@ module Crystal values = [Var.new(var_name).at(expressions)] of ASTNode MultiAssign.new(targets, values).at(expressions) end + + def transform(node : Union) + if node.singleton? + # If the union has just one type, return that instead of a union + node.types.first + else + super + end + end end end diff --git a/src/compiler/crystal/syntax/ast.cr b/src/compiler/crystal/syntax/ast.cr index 71bae99be2ad..b6719706dc79 100644 --- a/src/compiler/crystal/syntax/ast.cr +++ b/src/compiler/crystal/syntax/ast.cr @@ -1846,8 +1846,23 @@ module Crystal class Union < ASTNode property types : Array(ASTNode) + property? parens : Bool - def initialize(@types) + def self.parens(type : ASTNode) + # Wrap existing union in parens if it doesn't already have parens + if type.is_a?(Union) && !type.parens? + return type.tap { |t| t.parens = true } + end + + new [type] of ASTNode, parens: true + end + + def initialize(@types, @parens = false) + end + + # A union with only one element represents parenthesis in the type grammar: `(A)` + def singleton? + types.size == 1 end def accept_children(visitor) @@ -1855,10 +1870,10 @@ module Crystal end def clone_without_location - Union.new(@types.clone) + Union.new(@types.clone, @parens) end - def_equals_and_hash types + def_equals_and_hash types, parens? end class Self < ASTNode diff --git a/src/compiler/crystal/syntax/location.cr b/src/compiler/crystal/syntax/location.cr index de40526faae7..2fdd2ccacd39 100644 --- a/src/compiler/crystal/syntax/location.cr +++ b/src/compiler/crystal/syntax/location.cr @@ -98,6 +98,13 @@ class Crystal::Location end end + def equals?(other) + return false unless (@filename || "") == (other.filename || "") + return false unless @line_number == other.line_number + return false unless @column_number == other.column_number + true + end + # Returns the number of lines between start and finish locations. def self.lines(start, finish) return unless start && finish && start.filename == finish.filename diff --git a/src/compiler/crystal/syntax/parser.cr b/src/compiler/crystal/syntax/parser.cr index a5f449351fba..fca16b6f5f8e 100644 --- a/src/compiler/crystal/syntax/parser.cr +++ b/src/compiler/crystal/syntax/parser.cr @@ -5112,15 +5112,15 @@ module Crystal end type = parse_type_splat { parse_union_type } - if type.is_a?(Union) - type.at(location).at_end(@token.location) - end if @token.type.op_rparen? + end_location = @token.location next_token_skip_space if @token.type.op_minus_gt? # `(A) -> B` case - type = parse_proc_type_output([type], location) + type = parse_proc_type_output([type] of ASTNode, location) elsif type.is_a?(Splat) raise "invalid type splat", type.location.not_nil! + else + type = Union.parens(type).at(location).at_end(end_location) end else input_types = [type] @@ -5133,6 +5133,11 @@ module Crystal type = parse_proc_type_output(input_types, input_types.first.location) check :OP_RPAREN next_token_skip_space + unless @token.type.op_minus_gt? + # Usually the parenthesis is encoded in a Union instance. + # But not when nesting proc notations like `(->) ->`. + type = Union.parens(type).at(type) + end else # `(A, B, C) -> D` case check :OP_RPAREN next_token_skip_space diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index ef0a613f469b..1df530a3538b 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -1149,15 +1149,20 @@ module Crystal end def visit(node : Union) - node.types.join(@str, " | ", &.accept self) + @str << "(" if node.parens? + + if node.singleton? + drop_parens_for_proc_notation(node.types.first, &.accept(self)) + else + node.types.join(@str, " | ", &.accept self) + end + + @str << ")" if node.parens? false end def visit(node : Metaclass) - needs_parens = node.name.is_a?(Union) - @str << '(' if needs_parens node.name.accept self - @str << ')' if needs_parens @str << ".class" false end @@ -1871,7 +1876,7 @@ module Crystal # call arguments node.declared_type.is_a?(ProcNotation) else - false + node.is_a?(ProcNotation) end drop_parens_for_proc_notation(outermost_type_is_proc_notation) { yield node } diff --git a/src/compiler/crystal/tools/formatter.cr b/src/compiler/crystal/tools/formatter.cr index 5b3675c0f509..6eb1deaa7712 100644 --- a/src/compiler/crystal/tools/formatter.cr +++ b/src/compiler/crystal/tools/formatter.cr @@ -129,13 +129,6 @@ module Crystal @exp_needs_indent = true @inside_def = 0 - # When we parse a type, parentheses information is not stored in ASTs, unlike - # for an Expressions node. So when we are printing a type (Path, ProcNotation, Union, etc.) - # we increment this when we find a '(', and decrement it when we find ')', but - # only if `paren_count > 0`: it might be the case of `def foo(x : A)`, but we don't - # want to print that last ')' when printing the type A. - @paren_count = 0 - # This stores the column number (if any) of each comment in every line @when_infos = [] of AlignInfo @hash_infos = [] of AlignInfo @@ -1063,29 +1056,7 @@ module Crystal false end - def check_open_paren - if @token.type.op_lparen? - while @token.type.op_lparen? - write "(" - next_token_skip_space - @paren_count += 1 - end - true - else - false - end - end - - def check_close_paren - while @token.type.op_rparen? && @paren_count > 0 - @paren_count -= 1 - write_token :OP_RPAREN - end - end - def visit(node : Path) - check_open_paren - # Sometimes the :: is not present because the parser generates ::Nil, for example if node.global? && @token.type.op_colon_colon? write "::" @@ -1104,14 +1075,10 @@ module Crystal end end - check_close_paren - false end def visit(node : Generic) - check_open_paren - name = node.name.as(Path) if name.global? && @token.type.op_colon_colon? write "::" @@ -1212,17 +1179,12 @@ module Crystal accept name skip_space_or_newline - # Given that generic type arguments are always inside parentheses - # we can start counting them from 0 inside them. - old_paren_count = @paren_count - @paren_count = 0 - if named_args = node.named_args write_token :OP_LPAREN skip_space has_newlines, _, _ = format_named_args([] of ASTNode, named_args, @indent + 2) # `format_named_args` doesn't skip trailing comma - if @paren_count == 0 && @token.type.op_comma? + if @token.type.op_comma? next_token_skip_space_or_newline if has_newlines write "," @@ -1230,28 +1192,23 @@ module Crystal write_indent end end - skip_space_or_newline if @paren_count == 0 + skip_space_or_newline write_token :OP_RPAREN else format_literal_elements(node.type_vars, :OP_LPAREN, :OP_RPAREN) end - # Restore the old parentheses count - @paren_count = old_paren_count - false - ensure - check_close_paren end def visit(node : Union) - check_open_paren + write_token :OP_LPAREN if node.parens? if @token.type.ident? && @token.value == "self?" && node.types.size == 2 && node.types[0].is_a?(Self) && node.types[1].to_s == "::Nil" write "self?" next_token - check_close_paren + write_token :OP_RPAREN if node.parens? return false end @@ -1280,7 +1237,7 @@ module Crystal skip_space end - check_close_paren + write_token :OP_RPAREN if node.parens? false end @@ -2402,18 +2359,13 @@ module Crystal end def visit(node : ProcNotation) - check_open_paren + inputs = node.inputs - paren_count = @paren_count + has_input_parens = @token.type.op_lparen? && !inputs.try(&.first?).try(&.location).try(&.equals?(@token.location)) - if inputs = node.inputs - # Check if it's ((X, Y) -> Z) - # ^ ^ - sub_paren_count = @paren_count - if check_open_paren - sub_paren_count = @paren_count - end + write_token :op_lparen if has_input_parens + if inputs inputs.each_with_index do |input, i| accept input @@ -2424,14 +2376,10 @@ module Crystal end end - if sub_paren_count != paren_count - check_close_paren - end + write_token :OP_RPAREN if has_input_parens end - skip_space_or_newline if paren_count == @paren_count - check_close_paren - skip_space + skip_space_or_newline write " " if inputs write_token :OP_MINUS_GT @@ -2444,15 +2392,11 @@ module Crystal skip_space end - check_close_paren - false end def visit(node : Self) - check_open_paren write_keyword :self - check_close_paren false end @@ -4918,8 +4862,6 @@ module Crystal end def finish - raise "BUG: unclosed parenthesis" if @paren_count > 0 - skip_space write_line skip_space_or_newline last: true