Skip to content
13 changes: 11 additions & 2 deletions spec/compiler/formatter/formatter_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2094,10 +2094,10 @@ describe Crystal::Formatter do
assert_format "G_(A, (B -> R))"
assert_format "G_(A, ->)"
assert_format "G_(A, (->))"
pending { assert_format "G_(A, () ->)" } # #16741
assert_format "G_(A, () ->)"
assert_format "G_(A, -> R)"
assert_format "G_(-> R)"
pending { assert_format "G_(() -> R)" } # #16741
assert_format "G_(() -> R)"

assert_format "G_(A -> R | S)"
assert_format "G_((A -> R | S))"
Expand All @@ -2108,6 +2108,15 @@ describe Crystal::Formatter do
assert_format "G_(A | B -> R)"
assert_format "G_((A | B) -> C)"
assert_format "G_(A | (B -> C))"

assert_format "G_((A*) -> R)"

assert_format "G_(((A) ->) ->)"
assert_format "G_((A ->) ->)"
assert_format "G_(A -> ->)"
assert_format "G_((A -> ->))"
assert_format "G_(((A) ->, B) ->)"
assert_format "G_((A) | B ->)"
end

assert_format "foo &.bar.is_a?(Baz)"
Expand Down
28 changes: 26 additions & 2 deletions spec/compiler/parser/parser_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -491,14 +491,19 @@ module Crystal
it_parses "def foo(var : self.class); end", Def.new("foo", [Arg.new("var", restriction: Metaclass.new(Self.new))])
it_parses "def foo(var : self*); end", Def.new("foo", [Arg.new("var", restriction: Self.new.pointer_of)])
it_parses "def foo(var : Int | Double); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.new(["Int".path, "Double".path] of ASTNode))])
it_parses "def foo(var : (Int | Double)); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.parens(Crystal::Union.new(["Int".path, "Double".path] of ASTNode)))])
it_parses "def foo(var : Int?); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.new(["Int".path, "Nil".path(true)] of ASTNode))])
it_parses "def foo(var : Int*); end", Def.new("foo", [Arg.new("var", restriction: "Int".path.pointer_of)])
it_parses "def foo(var : Int**); end", Def.new("foo", [Arg.new("var", restriction: "Int".path.pointer_of.pointer_of)])
it_parses "def foo(var : Int -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path] of ASTNode, "Double".path))])
it_parses "def foo(var : Int, Float -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))])
it_parses "def foo(var : (Int, Float -> Double)); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))])
it_parses "def foo(var : (Int, Float -> Double)); end", Def.new("foo", [Arg.new("var", restriction: Crystal::Union.parens(ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path)))])
it_parses "def foo(var : (Int, Float) -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new(["Int".path, "Float".path] of ASTNode, "Double".path))])
it_parses "def foo(var : () -> Double); end", Def.new("foo", [Arg.new("var", restriction: ProcNotation.new([] of ASTNode, "Double".path))])
it_parses "x : (A -> B)", TypeDeclaration.new("x".var, declared_type: Crystal::Union.parens(ProcNotation.new(["A".path] of ASTNode, "B".path)))
it_parses "x : (A -> B).class", TypeDeclaration.new("x".var, declared_type: Metaclass.new(Crystal::Union.parens(ProcNotation.new(["A".path] of ASTNode, "B".path))))
it_parses "alias T = (A*) -> R", Alias.new("T".path, ProcNotation.new([Generic.new(Path.global("Pointer"), ["A".path] of ASTNode, suffix: :asterisk)] of ASTNode, "R".path))
it_parses "alias T = (A -> ) ->", Alias.new("T".path, ProcNotation.new([ProcNotation.new(["A".path] of ASTNode)] of ASTNode))
it_parses "def foo(var : Char[256]); end", Def.new("foo", [Arg.new("var", restriction: "Char".static_array_of(256))])
it_parses "def foo(var : Char[N]); end", Def.new("foo", [Arg.new("var", restriction: "Char".static_array_of("N".path))])
it_parses "def foo(var : Int32 = 1); end", Def.new("foo", [Arg.new("var", 1.int32, "Int32".path)])
Expand Down Expand Up @@ -1254,7 +1259,8 @@ module Crystal
it_parses "lib LibC\nfun getchar\nend", LibDef.new("LibC".path, [FunDef.new("getchar")] of ASTNode)
it_parses "lib LibC\nfun getchar(...)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", varargs: true)] of ASTNode)
it_parses "lib LibC\nfun getchar : Int\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: "Int".path)] of ASTNode)
it_parses "lib LibC\nfun getchar : (->)?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([ProcNotation.new, "Nil".path(true)] of ASTNode))] of ASTNode)
it_parses "lib LibC\nfun getchar : (->)?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([Crystal::Union.parens(ProcNotation.new), "Nil".path(true)] of ASTNode))] of ASTNode)
it_parses "lib LibC\nfun getchar : ((->))?\nend", LibDef.new("LibC".path, [FunDef.new("getchar", return_type: Crystal::Union.new([Crystal::Union.parens(Crystal::Union.parens(ProcNotation.new)), "Nil".path(true)] of ASTNode))] of ASTNode)
it_parses "lib LibC\nfun getchar(Int, Float)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("", restriction: "Int".path), Arg.new("", restriction: "Float".path)])] of ASTNode)
it_parses "lib LibC\nfun getchar(a : Int, b : Float)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("a", restriction: "Int".path), Arg.new("b", restriction: "Float".path)])] of ASTNode)
it_parses "lib LibC\nfun getchar(a : Int)\nend", LibDef.new("LibC".path, [FunDef.new("getchar", [Arg.new("a", restriction: "Int".path)])] of ASTNode)
Expand Down Expand Up @@ -3932,6 +3938,24 @@ module Crystal
node_source(source, node).should eq("::Foo")
end

it "sets correct location of proc notation inputs" do
source = "alias T = ((A), (B)) -> R"
proc_notation = Parser.parse(source).as(Alias).value.should be_a(ProcNotation)
inputs = proc_notation.inputs.should be_a(Array(ASTNode))
path = inputs.first.should(be_a(Union)).types.first.should be_a(Path)
node_source(source, path).should eq "A"
node_source(source, inputs[1]).should eq "(B)"
end

it "sets correct location of proc notation inputs" do
source = "alias T = (A) -> R"
proc_notation = Parser.parse(source).as(Alias).value.should be_a(ProcNotation)
inputs = proc_notation.inputs.should be_a(Array(ASTNode))
path = inputs.first.should be_a(Path)
node_source(source, path).should eq "A"
node_source(source, proc_notation).should eq "(A) -> R"
end

it "sets args_in_brackets to false for `a.b`" do
parser = Parser.new("a.b")
node = parser.parse.as(Call)
Expand Down
4 changes: 2 additions & 2 deletions spec/compiler/parser/to_s_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ describe "ASTNode#to_s" do

# 14216
expect_to_s "def foo(x, **args, &block : _ -> _)\nend"
expect_to_s "def foo(x, **args, &block : (_ -> _))\nend", "def foo(x, **args, &block : _ -> _)\nend"
expect_to_s "def foo(x, **args, &block : (_ -> _))\nend"
expect_to_s "def foo(& : ->)\nend"
expect_to_s "def foo(& : (->))\nend", "def foo(& : ->)\nend"
expect_to_s "def foo(& : (->))\nend"
expect_to_s "def foo(x : (T -> U) -> V, *args : (T -> U) -> V, y : (T -> U) -> V, **opts : (T -> U) -> V, & : (T -> U) -> V) : ((T -> U) -> V)\nend"
expect_to_s "foo(x : (T -> U) -> V, W)"
expect_to_s "foo[x : (T -> U) -> V, W]"
Expand Down
9 changes: 9 additions & 0 deletions src/compiler/crystal/semantic/normalizer.cr
Original file line number Diff line number Diff line change
Expand Up @@ -510,5 +510,14 @@ module Crystal
values = [Var.new(var_name).at(expressions)] of ASTNode
MultiAssign.new(targets, values).at(expressions)
end

def transform(node : Union)
if node.singleton?
# If the union has just one type, return that instead of a union
node.types.first
else
super
end
end
end
end
21 changes: 18 additions & 3 deletions src/compiler/crystal/syntax/ast.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1846,19 +1846,34 @@ module Crystal

class Union < ASTNode
property types : Array(ASTNode)
property? parens : Bool

def initialize(@types)
def self.parens(type : ASTNode)
# Wrap existing union in parens if it doesn't already have parens
if type.is_a?(Union) && !type.parens?
return type.tap { |t| t.parens = true }
end

new [type] of ASTNode, parens: true
end

def initialize(@types, @parens = false)
end

# A union with only one element represents parenthesis in the type grammar: `(A)`
def singleton?
types.size == 1
end

def accept_children(visitor)
@types.each &.accept visitor
end

def clone_without_location
Union.new(@types.clone)
Union.new(@types.clone, @parens)
end

def_equals_and_hash types
def_equals_and_hash types, parens?
end

class Self < ASTNode
Expand Down
7 changes: 7 additions & 0 deletions src/compiler/crystal/syntax/location.cr
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ class Crystal::Location
end
end

def equals?(other)
return false unless (@filename || "") == (other.filename || "")
return false unless @line_number == other.line_number
return false unless @column_number == other.column_number
true
end

# Returns the number of lines between start and finish locations.
def self.lines(start, finish)
return unless start && finish && start.filename == finish.filename
Expand Down
13 changes: 9 additions & 4 deletions src/compiler/crystal/syntax/parser.cr
Original file line number Diff line number Diff line change
Expand Up @@ -5112,15 +5112,15 @@ module Crystal
end

type = parse_type_splat { parse_union_type }
if type.is_a?(Union)
type.at(location).at_end(@token.location)
end
if @token.type.op_rparen?
end_location = @token.location
next_token_skip_space
if @token.type.op_minus_gt? # `(A) -> B` case
type = parse_proc_type_output([type], location)
type = parse_proc_type_output([type] of ASTNode, location)
elsif type.is_a?(Splat)
raise "invalid type splat", type.location.not_nil!
else
type = Union.parens(type).at(location).at_end(end_location)
end
else
input_types = [type]
Expand All @@ -5133,6 +5133,11 @@ module Crystal
type = parse_proc_type_output(input_types, input_types.first.location)
check :OP_RPAREN
next_token_skip_space
unless @token.type.op_minus_gt?
# Usually the parenthesis is encoded in a Union instance.
# But not when nesting proc notations like `(->) ->`.
type = Union.parens(type).at(type)
end
else # `(A, B, C) -> D` case
check :OP_RPAREN
next_token_skip_space
Expand Down
15 changes: 10 additions & 5 deletions src/compiler/crystal/syntax/to_s.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1149,15 +1149,20 @@ module Crystal
end

def visit(node : Union)
node.types.join(@str, " | ", &.accept self)
@str << "(" if node.parens?

if node.singleton?
drop_parens_for_proc_notation(node.types.first, &.accept(self))
else
node.types.join(@str, " | ", &.accept self)
end

@str << ")" if node.parens?
false
end

def visit(node : Metaclass)
needs_parens = node.name.is_a?(Union)
@str << '(' if needs_parens
node.name.accept self
@str << ')' if needs_parens
@str << ".class"
false
end
Expand Down Expand Up @@ -1871,7 +1876,7 @@ module Crystal
# call arguments
node.declared_type.is_a?(ProcNotation)
else
false
node.is_a?(ProcNotation)
end

drop_parens_for_proc_notation(outermost_type_is_proc_notation) { yield node }
Expand Down
Loading
Loading