diff --git a/.rubocop.yml b/.rubocop.yml index f6ffbcd0..3323c741 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -55,6 +55,9 @@ Style/IdenticalConditionalBranches: Style/IfInsideElse: Enabled: false +Style/IfWithBooleanLiteralBranches: + Enabled: false + Style/KeywordParametersOrder: Enabled: false diff --git a/CHANGELOG.md b/CHANGELOG.md index f0ba115e..c4558185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) a ## [Unreleased] +## [4.0.0] - 2022-10-17 + +### Added + +- [#169](https://github.com/ruby-syntax-tree/syntax_tree/pull/169) - You can now pass `--ignore-files` multiple times. +- [#157](https://github.com/ruby-syntax-tree/syntax_tree/pull/157) - We now support tracking local variable definitions throughout the visitor. This allows you to access scope information while visiting the tree. +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - There is now an undocumented `STREE_FAST_FORMAT` environment variable checked when formatting. It has the effect of turning _off_ formatting call chains and ternaries in special ways. This improves performance quite a bit. I'm leaving it undocumented because ideally we just improve the performance as a whole. This is meant as a stopgap until we get there. + +### Changed + +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - We now require at least version `1.0.0` of `prettier_print`. This is to take advantage of the first-class string support in the doc tree. +- [#170](https://github.com/ruby-syntax-tree/syntax_tree/pull/170) - Pattern matching has been removed from usage internal to this library (excluding the language server). This should hopefully enable runtimes that don't have pattern matching fully implemented yet (e.g., TruffleRuby) to run this gem. + ## [3.6.3] - 2022-10-11 ### Changed @@ -370,7 +383,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) a - 🎉 Initial release! 🎉 -[unreleased]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.3...HEAD +[unreleased]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v4.0.0...HEAD +[4.0.0]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.3...v4.0.0 [3.6.3]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.2...v3.6.3 [3.6.2]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.1...v3.6.2 [3.6.1]: https://github.com/ruby-syntax-tree/syntax_tree/compare/v3.6.0...v3.6.1 diff --git a/Gemfile.lock b/Gemfile.lock index 6415fcb0..00ae409b 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,8 +1,8 @@ PATH remote: . specs: - syntax_tree (3.6.3) - prettier_print + syntax_tree (4.0.0) + prettier_print (>= 1.0.0) GEM remote: https://rubygems.org/ @@ -14,10 +14,10 @@ GEM parallel (1.22.1) parser (3.1.2.1) ast (~> 2.4.1) - prettier_print (0.1.0) + prettier_print (1.0.0) rainbow (3.1.1) rake (13.0.6) - regexp_parser (2.5.0) + regexp_parser (2.6.0) rexml (3.2.5) rubocop (1.36.0) json (~> 2.3) @@ -38,7 +38,7 @@ GEM simplecov_json_formatter (~> 0.1) simplecov-html (0.12.3) simplecov_json_formatter (0.1.4) - unicode-display_width (2.2.0) + unicode-display_width (2.3.0) PLATFORMS arm64-darwin-21 diff --git a/README.md b/README.md index afb65843..30c35ac8 100644 --- a/README.md +++ b/README.md @@ -368,16 +368,16 @@ program = SyntaxTree.parse("1 + 1") puts program.construct_keys # SyntaxTree::Program[ -# statements: SyntaxTree::Statements[ -# body: [ -# SyntaxTree::Binary[ -# left: SyntaxTree::Int[value: "1"], -# operator: :+, -# right: SyntaxTree::Int[value: "1"] -# ] -# ] -# ] -# ] +# statements: SyntaxTree::Statements[ +# body: [ +# SyntaxTree::Binary[ +# left: SyntaxTree::Int[value: "1"], +# operator: :+, +# right: SyntaxTree::Int[value: "1"] +# ] +# ] +# ] +# ] ``` ## Visitor @@ -447,6 +447,28 @@ end The visitor defined above will error out unless it's only visiting a `SyntaxTree::Int` node. This is useful in a couple of ways, e.g., if you're trying to define a visitor to handle the whole tree but it's currently a work-in-progress. +### WithEnvironment + +The `WithEnvironment` module can be included in visitors to automatically keep track of local variables and arguments +defined inside each environment. A `current_environment` accessor is made availble to the request, allowing it to find +all usages and definitions of a local. + +```ruby +class MyVisitor < Visitor + include WithEnvironment + + def visit_ident(node) + # find_local will return a Local for any local variables or arguments present in the current environment or nil if + # the identifier is not a local + local = current_environment.find_local(node) + + puts local.type # print the type of the local (:variable or :argument) + puts local.definitions # print the array of locations where this local is defined + puts local.usages # print the array of locations where this local occurs + end +end +``` + ## Language server Syntax Tree additionally ships with a language server conforming to the [language server protocol](https://microsoft.github.io/language-server-protocol/). It can be invoked through the CLI by running: diff --git a/bin/profile b/bin/profile index 0a1b6ade..15bd28ae 100755 --- a/bin/profile +++ b/bin/profile @@ -6,22 +6,21 @@ require "bundler/inline" gemfile do source "https://rubygems.org" gem "stackprof" + gem "prettier_print" end $:.unshift(File.expand_path("../lib", __dir__)) require "syntax_tree" -GC.disable - StackProf.run(mode: :cpu, out: "tmp/profile.dump", raw: true) do - filepath = File.expand_path("../lib/syntax_tree/node.rb", __dir__) - SyntaxTree.format(File.read(filepath)) + Dir[File.join(RbConfig::CONFIG["libdir"], "**/*.rb")].each do |filepath| + SyntaxTree.format(SyntaxTree.read(filepath)) + end end -GC.enable - File.open("tmp/flamegraph.html", "w") do |file| report = Marshal.load(IO.binread("tmp/profile.dump")) + StackProf::Report.new(report).print_text StackProf::Report.new(report).print_d3_flamegraph(file) end diff --git a/lib/syntax_tree.rb b/lib/syntax_tree.rb index 88c66369..52ec700b 100644 --- a/lib/syntax_tree.rb +++ b/lib/syntax_tree.rb @@ -1,6 +1,5 @@ # frozen_string_literal: true -require "delegate" require "etc" require "json" require "pp" @@ -10,7 +9,6 @@ require_relative "syntax_tree/formatter" require_relative "syntax_tree/node" -require_relative "syntax_tree/parser" require_relative "syntax_tree/version" require_relative "syntax_tree/basic_visitor" @@ -19,6 +17,20 @@ require_relative "syntax_tree/visitor/json_visitor" require_relative "syntax_tree/visitor/match_visitor" require_relative "syntax_tree/visitor/pretty_print_visitor" +require_relative "syntax_tree/visitor/environment" +require_relative "syntax_tree/visitor/with_environment" + +require_relative "syntax_tree/parser" + +# We rely on Symbol#name being available, which is only available in Ruby 3.0+. +# In case we're running on an older Ruby version, we polyfill it here. +unless :+.respond_to?(:name) + class Symbol # rubocop:disable Style/Documentation + def name + to_s.freeze + end + end +end # Syntax Tree is a suite of tools built on top of the internal CRuby parser. It # provides the ability to generate a syntax tree from source, as well as the diff --git a/lib/syntax_tree/cli.rb b/lib/syntax_tree/cli.rb index f3564e29..b839d562 100644 --- a/lib/syntax_tree/cli.rb +++ b/lib/syntax_tree/cli.rb @@ -290,7 +290,7 @@ class Options :target_ruby_version def initialize(print_width: DEFAULT_PRINT_WIDTH) - @ignore_files = "" + @ignore_files = [] @plugins = [] @print_width = print_width @scripts = [] @@ -313,7 +313,7 @@ def parser # Any of the CLI commands that operate on filenames will then ignore # this set of files. opts.on("--ignore-files=GLOB") do |glob| - @ignore_files = glob.match(/\A'(.*)'\z/) ? $1 : glob + @ignore_files << (glob.match(/\A'(.*)'\z/) ? $1 : glob) end # If there are any plugins specified on the command line, then load @@ -434,7 +434,7 @@ def run(argv) .glob(pattern) .each do |filepath| if File.readable?(filepath) && - !File.fnmatch?(options.ignore_files, filepath) + options.ignore_files.none? { File.fnmatch?(_1, filepath) } queue << FileItem.new(filepath) end end diff --git a/lib/syntax_tree/formatter.rb b/lib/syntax_tree/formatter.rb index 4c7a00db..f878490c 100644 --- a/lib/syntax_tree/formatter.rb +++ b/lib/syntax_tree/formatter.rb @@ -62,21 +62,39 @@ def format(node, stackable: true) # If there are comments, then we're going to format them around the node # so that they get printed properly. if node.comments.any? - leading, trailing = node.comments.partition(&:leading?) + trailing = [] + last_leading = nil - # Print all comments that were found before the node. - leading.each do |comment| - comment.format(self) - breakable(force: true) + # First, we're going to print all of the comments that were found before + # the node. We'll also gather up any trailing comments that we find. + node.comments.each do |comment| + if comment.leading? + comment.format(self) + breakable(force: true) + last_leading = comment + else + trailing << comment + end end # If the node has a stree-ignore comment right before it, then we're # going to just print out the node as it was seen in the source. doc = - if leading.last&.ignore? + if last_leading&.ignore? range = source[node.location.start_char...node.location.end_char] - separator = -> { breakable(indent: false, force: true) } - seplist(range.split(/\r?\n/, -1), separator) { |line| text(line) } + first = true + + range.each_line(chomp: true) do |line| + if first + first = false + else + breakable_return + end + + text(line) + end + + breakable_return if range.end_with?("\n") else node.format(self) end @@ -101,6 +119,10 @@ def format_each(nodes) nodes.each { |node| format(node) } end + def grandparent + stack[-3] + end + def parent stack[-2] end @@ -108,5 +130,42 @@ def parent def parents stack[0...-1].reverse_each end + + # This is a simplified version of prettyprint's group. It doesn't provide + # any of the more advanced options because we don't need them and they take + # up expensive computation time. + def group + contents = [] + doc = Group.new(0, contents: contents) + + groups << doc + target << doc + + with_target(contents) { yield } + groups.pop + doc + end + + # A similar version to the super, except that it calls back into the + # separator proc with the instance of `self`. + def seplist(list, sep = nil, iter_method = :each) + first = true + list.__send__(iter_method) do |*v| + if first + first = false + elsif sep + sep.call(self) + else + comma_breakable + end + yield(*v) + end + end + + # This is a much simplified version of prettyprint's text. It avoids + # calculating width by pushing the string directly onto the target. + def text(string) + target << string + end end end diff --git a/lib/syntax_tree/node.rb b/lib/syntax_tree/node.rb index 7ecd69ff..dcdd0275 100644 --- a/lib/syntax_tree/node.rb +++ b/lib/syntax_tree/node.rb @@ -177,10 +177,10 @@ def format(q) q.text("BEGIN ") q.format(lbrace) q.indent do - q.breakable + q.breakable_space q.format(statements) end - q.breakable + q.breakable_space q.text("}") end end @@ -280,10 +280,10 @@ def format(q) q.text("END ") q.format(lbrace) q.indent do - q.breakable + q.breakable_space q.format(statements) end - q.breakable + q.breakable_space q.text("}") end end @@ -327,10 +327,20 @@ def deconstruct_keys(_keys) def format(q) q.text("__END__") - q.breakable(force: true) + q.breakable_force - separator = -> { q.breakable(indent: false, force: true) } - q.seplist(value.split(/\r?\n/, -1), separator) { |line| q.text(line) } + first = true + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) + end + + q.breakable_return if value.end_with?("\n") end end @@ -412,7 +422,7 @@ def format(q) q.format(left_argument, stackable: false) q.group do q.nest(keyword.length) do - q.breakable(force: left_argument.comments.any?) + left_argument.comments.any? ? q.breakable_force : q.breakable_space q.format(AliasArgumentFormatter.new(right), stackable: false) end end @@ -476,10 +486,10 @@ def format(q) if index q.indent do - q.breakable("") + q.breakable_empty q.format(index) end - q.breakable("") + q.breakable_empty end q.text("]") @@ -537,10 +547,10 @@ def format(q) if index q.indent do - q.breakable("") + q.breakable_empty q.format(index) end - q.breakable("") + q.breakable_empty end q.text("]") @@ -593,25 +603,30 @@ def format(q) return end - q.group(0, "(", ")") do + q.text("(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(arguments) q.if_break { q.text(",") } if q.trailing_comma? && trailing_comma? end - q.breakable("") + q.breakable_empty end + q.text(")") end private def trailing_comma? - case arguments - in Args[parts: [*, ArgBlock]] + return false unless arguments.is_a?(Args) + parts = arguments.parts + + if parts.last.is_a?(ArgBlock) # If the last argument is a block, then we can't put a trailing comma # after it without resulting in a syntax error. false - in Args[parts: [Command | CommandCall]] + elsif (parts.length == 1) && (part = parts.first) && + (part.is_a?(Command) || part.is_a?(CommandCall)) # If the only argument is a command or command call, then a trailing # comma would be parsed as part of that expression instead of on this # one, so we don't want to add a trailing comma. @@ -790,6 +805,17 @@ def format(q) # [one, two, three] # class ArrayLiteral < Node + # It's very common to use seplist with ->(q) { q.breakable_space }. We wrap + # that pattern into an object to cut down on having to create a bunch of + # lambdas all over the place. + class BreakableSpaceSeparator + def call(q) + q.breakable_space + end + end + + BREAKABLE_SPACE_SEPARATOR = BreakableSpaceSeparator.new + # Formats an array of multiple simple string literals into the %w syntax. class QWordsFormatter # [Args] the contents of the array @@ -800,10 +826,11 @@ def initialize(contents) end def format(q) - q.group(0, "%w[", "]") do + q.text("%w[") + q.group do q.indent do - q.breakable("") - q.seplist(contents.parts, -> { q.breakable }) do |part| + q.breakable_empty + q.seplist(contents.parts, BREAKABLE_SPACE_SEPARATOR) do |part| if part.is_a?(StringLiteral) q.format(part.parts.first) else @@ -811,8 +838,9 @@ def format(q) end end end - q.breakable("") + q.breakable_empty end + q.text("]") end end @@ -826,15 +854,17 @@ def initialize(contents) end def format(q) - q.group(0, "%i[", "]") do + q.text("%i[") + q.group do q.indent do - q.breakable("") - q.seplist(contents.parts, -> { q.breakable }) do |part| + q.breakable_empty + q.seplist(contents.parts, BREAKABLE_SPACE_SEPARATOR) do |part| q.format(part.value) end end - q.breakable("") + q.breakable_empty end + q.text("]") end end @@ -861,6 +891,14 @@ def format(q) # # provided the line length was hit between `bar` and `baz`. class VarRefsFormatter + # The separator for the fill algorithm. + class Separator + def call(q) + q.text(",") + q.fill_breakable + end + end + # [Args] the contents of the array attr_reader :contents @@ -869,20 +907,16 @@ def initialize(contents) end def format(q) - q.group(0, "[", "]") do + q.text("[") + q.group do q.indent do - q.breakable("") - - separator = -> do - q.text(",") - q.fill_breakable - end - - q.seplist(contents.parts, separator) { |part| q.format(part) } + q.breakable_empty + q.seplist(contents.parts, Separator.new) { |part| q.format(part) } q.if_break { q.text(",") } if q.trailing_comma? end - q.breakable("") + q.breakable_empty end + q.text("]") end end @@ -902,11 +936,11 @@ def format(q) q.text("[") q.indent do lbracket.comments.each do |comment| - q.breakable(force: true) + q.breakable_force comment.format(q) end end - q.breakable(force: true) + q.breakable_force q.text("]") end end @@ -973,13 +1007,13 @@ def format(q) if contents q.indent do - q.breakable("") + q.breakable_empty q.format(contents) q.if_break { q.text(",") } if q.trailing_comma? end end - q.breakable("") + q.breakable_empty q.text("]") end end @@ -1127,7 +1161,7 @@ def format(q) q.format(constant) if constant q.text("[") q.indent do - q.breakable("") + q.breakable_empty parts = [*requireds] parts << RestFormatter.new(rest) if rest @@ -1135,7 +1169,7 @@ def format(q) q.seplist(parts) { |part| q.format(part) } end - q.breakable("") + q.breakable_empty q.text("]") end end @@ -1145,13 +1179,13 @@ def format(q) module AssignFormatting def self.skip_indent?(value) case value - in ArrayLiteral | HashLiteral | Heredoc | Lambda | QSymbols | QWords | - Symbols | Words + when ArrayLiteral, HashLiteral, Heredoc, Lambda, QSymbols, QWords, + Symbols, Words true - in Call[receiver:] - skip_indent?(receiver) - in DynaSymbol[quote:] - quote.start_with?("%s") + when Call + skip_indent?(value.receiver) + when DynaSymbol + value.quote.start_with?("%s") else false end @@ -1206,7 +1240,7 @@ def format(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end @@ -1277,7 +1311,7 @@ def format_contents(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end @@ -1404,17 +1438,22 @@ class Labels def format_key(q, key) case key - in Label + when Label q.format(key) - in SymbolLiteral + when SymbolLiteral q.format(key.value) q.text(":") - in DynaSymbol[parts: [TStringContent[value: LABEL] => part]] - q.format(part) - q.text(":") - in DynaSymbol - q.format(key) - q.text(":") + when DynaSymbol + parts = key.parts + + if parts.length == 1 && (part = parts.first) && + part.is_a?(TStringContent) && part.value.match?(LABEL) + q.format(part) + q.text(":") + else + q.format(key) + q.text(":") + end end end end @@ -1424,8 +1463,7 @@ class Rockets def format_key(q, key) case key when Label - q.text(":") - q.text(key.value.chomp(":")) + q.text(":#{key.value.chomp(":")}") when DynaSymbol q.text(":") q.format(key) @@ -1544,12 +1582,12 @@ def format(q) unless bodystmt.empty? q.indent do - q.breakable(force: true) unless bodystmt.statements.empty? + q.breakable_force unless bodystmt.statements.empty? q.format(bodystmt) end end - q.breakable(force: true) + q.breakable_force q.text("end") end end @@ -1592,10 +1630,10 @@ def format(q) q.text("^(") q.nest(1) do q.indent do - q.breakable("") + q.breakable_empty q.format(statement) end - q.breakable("") + q.breakable_empty q.text(")") end end @@ -1661,15 +1699,13 @@ def format(q) q.text(" ") unless power if operator == :<< - q.text(operator.to_s) - q.text(" ") + q.text("<< ") q.format(right) else q.group do - q.text(operator.to_s) - + q.text(operator.name) q.indent do - q.breakable(power ? "" : " ") + power ? q.breakable_empty : q.breakable_space q.format(right) end end @@ -1716,15 +1752,29 @@ def deconstruct_keys(_keys) { params: params, locals: locals, location: location, comments: comments } end + # Within the pipes of the block declaration, we don't want any spaces. So + # we'll separate the parameters with a comma and space but no breakables. + class Separator + def call(q) + q.text(", ") + end + end + + # We'll keep a single instance of this separator around for all block vars + # to cut down on allocations. + SEPARATOR = Separator.new + def format(q) - q.group(0, "|", "|") do + q.text("|") + q.group do q.remove_breaks(q.format(params)) if locals.any? q.text("; ") - q.seplist(locals, -> { q.text(", ") }) { |local| q.format(local) } + q.seplist(locals, SEPARATOR) { |local| q.format(local) } end end + q.text("|") end end @@ -1816,10 +1866,8 @@ def bind(start_char, start_column, end_char, end_column) end_column: end_column ) - parts = [rescue_clause, else_clause, ensure_clause] - # Here we're going to determine the bounds for the statements - consequent = parts.compact.first + consequent = rescue_clause || else_clause || ensure_clause statements.bind( start_char, start_column, @@ -1829,7 +1877,7 @@ def bind(start_char, start_column, end_char, end_column) # Next we're going to determine the rescue clause if there is one if rescue_clause - consequent = parts.drop(1).compact.first + consequent = else_clause || ensure_clause rescue_clause.bind_end( consequent ? consequent.location.start_char : end_char, consequent ? consequent.location.start_column : end_column @@ -1868,26 +1916,26 @@ def format(q) if rescue_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(rescue_clause) end end if else_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(else_keyword) end unless else_clause.empty? - q.breakable(force: true) + q.breakable_force q.format(else_clause) end end if ensure_clause q.nest(-2) do - q.breakable(force: true) + q.breakable_force q.format(ensure_clause) end end @@ -1955,12 +2003,11 @@ def format(q) # If the receiver of this block a Command or CommandCall node, then there # are no parentheses around the arguments to that command, so we need to # break the block. - case q.parent - in { call: Command | CommandCall } + case q.parent.call + when Command, CommandCall q.break_parent format_break(q, break_opening, break_closing) return - else end q.group do @@ -1980,9 +2027,9 @@ def unchangeable_bounds?(q) # know for certain we're going to get split over multiple lines # anyway. case parent - in Statements | ArgParen + when Statements, ArgParen break false - in Command | CommandCall + when Command, CommandCall true else false @@ -1993,8 +2040,8 @@ def unchangeable_bounds?(q) # If we're a sibling of a control-flow keyword, then we're going to have to # use the do..end bounds. def forced_do_end_bounds?(q) - case q.parent - in { call: Break | Next | Return | Super } + case q.parent.call + when Break, Next, Return, Super true else false @@ -2004,22 +2051,19 @@ def forced_do_end_bounds?(q) # If we're the predicate of a loop or conditional, then we're going to have # to go with the {..} bounds. def forced_brace_bounds?(q) - parents = q.parents.to_a - parents.each_with_index.any? do |parent, index| - # If we hit certain breakpoints then we know we're safe. - break false if [Paren, Statements].include?(parent.class) + previous = nil + q.parents.any? do |parent| + case parent + when Paren, Statements + # If we hit certain breakpoints then we know we're safe. + return false + when If, IfMod, IfOp, Unless, UnlessMod, While, WhileMod, Until, + UntilMod + return true if parent.predicate == previous + end - [ - If, - IfMod, - IfOp, - Unless, - UnlessMod, - While, - WhileMod, - Until, - UntilMod - ].include?(parent.class) && parent.predicate == parents[index - 1] + previous = parent + false end end @@ -2034,12 +2078,12 @@ def format_break(q, opening, closing) unless statements.empty? q.indent do - q.breakable + q.breakable_space q.format(statements) end end - q.breakable + q.breakable_space q.text(closing) end @@ -2048,17 +2092,17 @@ def format_flat(q, opening, closing) q.format(BlockOpenFormatter.new(opening, block_open), stackable: false) if node.block_var - q.breakable + q.breakable_space q.format(node.block_var) - q.breakable + q.breakable_space end if statements.empty? q.text(" ") if opening == "do" else - q.breakable unless node.block_var + q.breakable_space unless node.block_var q.format(statements) - q.breakable + q.breakable_space end q.text(closing) @@ -2133,105 +2177,128 @@ def format(q) q.group do q.text(keyword) - case node.arguments.parts - in [] + parts = node.arguments.parts + length = parts.length + + if length == 0 # Here there are no arguments at all, so we're not going to print # anything. This would be like if we had: # # break # - in [ - Paren[ - contents: { - body: [ArrayLiteral[contents: { parts: [_, _, *] }] => array] } - ] - ] - # Here we have a single argument that is a set of parentheses wrapping - # an array literal that has at least 2 elements. We're going to print - # the contents of the array directly. This would be like if we had: - # - # break([1, 2, 3]) - # - # which we will print as: - # - # break 1, 2, 3 - # - q.text(" ") - format_array_contents(q, array) - in [Paren[contents: { body: [ArrayLiteral => statement] }]] - # Here we have a single argument that is a set of parentheses wrapping - # an array literal that has 0 or 1 elements. We're going to skip the - # parentheses but print the array itself. This would be like if we - # had: - # - # break([1]) - # - # which we will print as: - # - # break [1] - # - q.text(" ") - q.format(statement) - in [Paren[contents: { body: [statement] }]] if skip_parens?(statement) - # Here we have a single argument that is a set of parentheses that - # themselves contain a single statement. That statement is a simple - # value that we can skip the parentheses for. This would be like if we - # had: - # - # break(1) - # - # which we will print as: - # - # break 1 - # - q.text(" ") - q.format(statement) - in [Paren => part] - # Here we have a single argument that is a set of parentheses. We're - # going to print the parentheses themselves as if they were the set of - # arguments. This would be like if we had: - # - # break(foo.bar) - # - q.format(part) - in [ArrayLiteral[contents: { parts: [_, _, *] }] => array] - # Here there is a single argument that is an array literal with at - # least two elements. We skip directly into the array literal's - # elements in order to print the contents. This would be like if we - # had: - # - # break [1, 2, 3] - # - # which we will print as: - # - # break 1, 2, 3 - # - q.text(" ") - format_array_contents(q, array) - in [ArrayLiteral => part] - # Here there is a single argument that is an array literal with 0 or 1 - # elements. In this case we're going to print the array as it is - # because skipping the brackets would change the remaining. This would - # be like if we had: - # - # break [] - # break [1] - # - q.text(" ") - q.format(part) - in [_] - # Here there is a single argument that hasn't matched one of our - # previous cases. We're going to print the argument as it is. This - # would be like if we had: - # - # break foo - # - format_arguments(q, "(", ")") - else + elsif length >= 2 # If there are multiple arguments, format them all. If the line is # going to break into multiple, then use brackets to start and end the # expression. format_arguments(q, " [", "]") + else + # If we get here, then we're formatting a single argument to the flow + # control keyword. + part = parts.first + + case part + when Paren + statements = part.contents.body + + if statements.length == 1 + statement = statements.first + + if statement.is_a?(ArrayLiteral) + contents = statement.contents + + if contents && contents.parts.length >= 2 + # Here we have a single argument that is a set of parentheses + # wrapping an array literal that has at least 2 elements. + # We're going to print the contents of the array directly. + # This would be like if we had: + # + # break([1, 2, 3]) + # + # which we will print as: + # + # break 1, 2, 3 + # + q.text(" ") + format_array_contents(q, statement) + else + # Here we have a single argument that is a set of parentheses + # wrapping an array literal that has 0 or 1 elements. We're + # going to skip the parentheses but print the array itself. + # This would be like if we had: + # + # break([1]) + # + # which we will print as: + # + # break [1] + # + q.text(" ") + q.format(statement) + end + elsif skip_parens?(statement) + # Here we have a single argument that is a set of parentheses + # that themselves contain a single statement. That statement is + # a simple value that we can skip the parentheses for. This + # would be like if we had: + # + # break(1) + # + # which we will print as: + # + # break 1 + # + q.text(" ") + q.format(statement) + else + # Here we have a single argument that is a set of parentheses. + # We're going to print the parentheses themselves as if they + # were the set of arguments. This would be like if we had: + # + # break(foo.bar) + # + q.format(part) + end + else + q.format(part) + end + when ArrayLiteral + contents = part.contents + + if contents && contents.parts.length >= 2 + # Here there is a single argument that is an array literal with at + # least two elements. We skip directly into the array literal's + # elements in order to print the contents. This would be like if + # we had: + # + # break [1, 2, 3] + # + # which we will print as: + # + # break 1, 2, 3 + # + q.text(" ") + format_array_contents(q, part) + else + # Here there is a single argument that is an array literal with 0 + # or 1 elements. In this case we're going to print the array as it + # is because skipping the brackets would change the remaining. + # This would be like if we had: + # + # break [] + # break [1] + # + q.text(" ") + q.format(part) + end + else + # Here there is a single argument that hasn't matched one of our + # previous cases. We're going to print the argument as it is. This + # would be like if we had: + # + # break foo + # + format_arguments(q, "(", ")") + end end end end @@ -2241,29 +2308,34 @@ def format(q) def format_array_contents(q, array) q.if_break { q.text("[") } q.indent do - q.breakable("") + q.breakable_empty q.format(array.contents) end - q.breakable("") + q.breakable_empty q.if_break { q.text("]") } end def format_arguments(q, opening, closing) q.if_break { q.text(opening) } q.indent do - q.breakable(" ") + q.breakable_space q.format(node.arguments) end - q.breakable("") + q.breakable_empty q.if_break { q.text(closing) } end def skip_parens?(node) case node - in FloatLiteral | Imaginary | Int | RationalLiteral - true - in VarRef[value: Const | CVar | GVar | IVar | Kw] + when FloatLiteral, Imaginary, Int, RationalLiteral true + when VarRef + case node.value + when Const, CVar, GVar, IVar, Kw + true + else + false + end else false end @@ -2326,8 +2398,10 @@ def comments def format(q) case operator - in :"::" | Op[value: "::"] + when :"::" q.text(".") + when Op + operator.value == "::" ? q.text(".") : operator.format(q) else operator.format(q) end @@ -2363,13 +2437,18 @@ def format(q) # First, walk down the chain until we get to the point where we're not # longer at a chainable node. loop do - case children.last - in Call[receiver: Call] - children << children.last.receiver - in Call[receiver: MethodAddBlock[call: Call]] - children << children.last.receiver - in MethodAddBlock[call: Call] - children << children.last.call + case (child = children.last) + when Call + case (receiver = child.receiver) + when Call + children << receiver + when MethodAddBlock + receiver.call.is_a?(Call) ? children << receiver : break + else + break + end + when MethodAddBlock + child.call.is_a?(Call) ? children << child.call : break else break end @@ -2388,10 +2467,9 @@ def format(q) # nodes. parent = parents[3] if parent.is_a?(DoBlock) - case parent - in MethodAddBlock[call: FCall[value: { value: "sig" }]] + if parent.is_a?(MethodAddBlock) && parent.call.is_a?(FCall) && + parent.call.value.value == "sig" threshold = 2 - else end end @@ -2434,20 +2512,21 @@ def format_chain(q, children) skip_operator = false while (child = children.pop) - case child - in Call[ - receiver: Call[message: { value: "where" }], - message: { value: "not" } - ] - # This is very specialized behavior wherein we group - # .where.not calls together because it looks better. For more - # information, see - # https://github.com/prettier/plugin-ruby/issues/862. - in Call - # If we're at a Call node and not a MethodAddBlock node in the - # chain then we're going to add a newline so it indents properly. - q.breakable("") - else + if child.is_a?(Call) + if child.receiver.is_a?(Call) && + (child.receiver.message != :call) && + (child.receiver.message.value == "where") && + (child.message.value == "not") + # This is very specialized behavior wherein we group + # .where.not calls together because it looks better. For more + # information, see + # https://github.com/prettier/plugin-ruby/issues/862. + else + # If we're at a Call node and not a MethodAddBlock node in the + # chain then we're going to add a newline so it indents + # properly. + q.breakable_empty + end end format_child( @@ -2460,9 +2539,9 @@ def format_chain(q, children) # If the parent call node has a comment on the message then we need # to print the operator trailing in order to keep it working. - case children.last - in Call[message: { comments: [_, *] }, operator:] - q.format(CallOperatorFormatter.new(operator)) + last_child = children.last + if last_child.is_a?(Call) && last_child.message.comments.any? + q.format(CallOperatorFormatter.new(last_child.operator)) skip_operator = true else skip_operator = false @@ -2477,18 +2556,22 @@ def format_chain(q, children) if empty_except_last case node - in Call + when Call node.format_arguments(q) - in MethodAddBlock[block:] - q.format(block) + when MethodAddBlock + q.format(node.block) end end end def self.chained?(node) + return false if ENV["STREE_FAST_FORMAT"] + case node - in Call | MethodAddBlock[call: Call] + when Call true + when MethodAddBlock + node.call.is_a?(Call) else false end @@ -2500,9 +2583,12 @@ def self.chained?(node) # want to indent the first call. So we'll pop off the first children and # format it separately here. def attach_directly?(node) - [ArrayLiteral, HashLiteral, Heredoc, If, Unless, XStringLiteral].include?( - node.receiver.class - ) + case node.receiver + when ArrayLiteral, HashLiteral, Heredoc, If, Unless, XStringLiteral + true + else + false + end end def format_child( @@ -2514,7 +2600,7 @@ def format_child( ) # First, format the actual contents of the child. case child - in Call + when Call q.group do unless skip_operator q.format(CallOperatorFormatter.new(child.operator)) @@ -2522,7 +2608,7 @@ def format_child( q.format(child.message) if child.message != :call child.format_arguments(q) unless skip_attached end - in MethodAddBlock + when MethodAddBlock q.format(child.block) unless skip_attached end @@ -2530,7 +2616,7 @@ def format_child( # them out here since we're bypassing the normal comment printing. if child.comments.any? && !skip_comments child.comments.each do |comment| - comment.inline? ? q.text(" ") : q.breakable + comment.inline? ? q.text(" ") : q.breakable_space comment.format(q) end @@ -2605,8 +2691,8 @@ def format(q) # If we're at the top of a call chain, then we're going to do some # specialized printing in case we can print it nicely. We _only_ do this # at the top of the chain to avoid weird recursion issues. - if !CallChainFormatter.chained?(q.parent) && - CallChainFormatter.chained?(receiver) + if CallChainFormatter.chained?(receiver) && + !CallChainFormatter.chained?(q.parent) q.group do q .if_break { CallChainFormatter.new(self).format(q) } @@ -2617,15 +2703,15 @@ def format(q) end end + # Print out the arguments to this call. If there are no arguments, then do + #nothing. def format_arguments(q) case arguments - in ArgParen + when ArgParen q.format(arguments) - in Args + when Args q.text(" ") q.format(arguments) - else - # Do nothing if there are no arguments. end end @@ -2642,7 +2728,7 @@ def format_contents(q) q.group do q.indent do if receiver.comments.any? || call_operator.comments.any? - q.breakable(force: true) + q.breakable_force end if call_operator.comments.empty? @@ -2719,9 +2805,9 @@ def format(q) q.format(value) end - q.breakable(force: true) + q.breakable_force q.format(consequent) - q.breakable(force: true) + q.breakable_force q.text("end") end @@ -2782,13 +2868,13 @@ def format(q) q.format(operator) case pattern - in AryPtn | FndPtn | HshPtn + when AryPtn, FndPtn, HshPtn q.text(" ") q.format(pattern) else q.group do q.indent do - q.breakable + q.breakable_space q.format(pattern) end end @@ -2872,38 +2958,40 @@ def deconstruct_keys(_keys) end def format(q) - declaration = -> do - q.group do - q.text("class ") - q.format(constant) - - if superclass - q.text(" < ") - q.format(superclass) - end - end - end - if bodystmt.empty? q.group do - declaration.call - q.breakable(force: true) + format_declaration(q) + q.breakable_force q.text("end") end else q.group do - declaration.call + format_declaration(q) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force q.text("end") end end end + + private + + def format_declaration(q) + q.group do + q.text("class ") + q.format(constant) + + if superclass + q.text(" < ") + q.format(superclass) + end + end + end end # Comma represents the use of the , operator. @@ -2983,15 +3071,31 @@ def format(q) private def align(q, node, &block) - case node.arguments - in Args[parts: [Def | Defs | DefEndless]] - q.text(" ") - yield - in Args[parts: [IfOp]] - q.if_flat { q.text(" ") } - yield - in Args[parts: [Command => command]] - align(q, command, &block) + arguments = node.arguments + + if arguments.is_a?(Args) + parts = arguments.parts + + if parts.size == 1 + part = parts.first + + case part + when Def, Defs, DefEndless + q.text(" ") + yield + when IfOp + q.if_flat { q.text(" ") } + yield + when Command + align(q, part, &block) + else + q.text(" ") + q.nest(message.value.length + 1) { yield } + end + else + q.text(" ") + q.nest(message.value.length + 1) { yield } + end else q.text(" ") q.nest(message.value.length + 1) { yield } @@ -3069,7 +3173,7 @@ def format(q) if message.comments.any?(&:leading?) q.format(CallOperatorFormatter.new(operator), stackable: false) q.indent do - q.breakable("") + q.breakable_empty q.format(message) end else @@ -3078,15 +3182,18 @@ def format(q) end end - case arguments - in Args[parts: [IfOp]] - q.if_flat { q.text(" ") } - q.format(arguments) - in Args - q.text(" ") - q.nest(argument_alignment(q, doc)) { q.format(arguments) } - else - # If there are no arguments, print nothing. + # Format the arguments for this command call here. If there are no + # arguments, then print nothing. + if arguments + parts = arguments.parts + + if parts.length == 1 && parts.first.is_a?(IfOp) + q.if_flat { q.text(" ") } + q.format(arguments) + else + q.text(" ") + q.nest(argument_alignment(q, doc)) { q.format(arguments) } + end end end end @@ -3155,7 +3262,7 @@ def trailing? end def ignore? - value[1..].strip == "stree-ignore" + value.match?(/\A#\s*stree-ignore\s*\z/) end def comments @@ -3455,12 +3562,12 @@ def format(q) unless bodystmt.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end end - q.breakable(force: true) + q.breakable_force q.text("end") end end @@ -3549,7 +3656,7 @@ def format(q) q.text(" =") q.group do q.indent do - q.breakable + q.breakable_space q.format(statement) end end @@ -3590,13 +3697,15 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "defined?(", ")") do + q.text("defined?(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(value) end - q.breakable("") + q.breakable_empty end + q.text(")") end end @@ -3678,12 +3787,12 @@ def format(q) unless bodystmt.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end end - q.breakable(force: true) + q.breakable_force q.text("end") end end @@ -3755,15 +3864,18 @@ def initialize(operator, node) end def format(q) - space = [If, IfMod, Unless, UnlessMod].include?(q.parent.class) - left = node.left right = node.right q.format(left) if left - q.text(" ") if space - q.text(operator) - q.text(" ") if space + + case q.parent + when If, IfMod, Unless, UnlessMod + q.text(" #{operator} ") + else + q.text(operator) + end + q.format(right) if right end end @@ -3948,19 +4060,30 @@ def deconstruct_keys(_keys) def format(q) opening_quote, closing_quote = quotes(q) - q.group(0, opening_quote, closing_quote) do + q.text(opening_quote) + q.group do parts.each do |part| if part.is_a?(TStringContent) value = Quotes.normalize(part.value, closing_quote) - separator = -> { q.breakable(force: true, indent: false) } - q.seplist(value.split(/\r?\n/, -1), separator) do |text| - q.text(text) + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) end + + q.breakable_return if value.end_with?("\n") else q.format(part) end end end + q.text(closing_quote) end private @@ -4056,7 +4179,7 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end @@ -4126,14 +4249,14 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent q.group do - q.breakable(force: true) + q.breakable_force q.format(consequent) end end @@ -4329,7 +4452,7 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end @@ -4588,7 +4711,7 @@ def format(q) q.text("[") q.indent do - q.breakable("") + q.breakable_empty q.text("*") q.format(left) @@ -4601,7 +4724,7 @@ def format(q) q.format(right) end - q.breakable("") + q.breakable_empty q.text("]") end end @@ -4663,12 +4786,12 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end - q.breakable(force: true) + q.breakable_force q.text("end") end end @@ -4731,11 +4854,11 @@ def format(q) q.text("{") q.indent do lbrace.comments.each do |comment| - q.breakable(force: true) + q.breakable_force comment.format(q) end end - q.breakable(force: true) + q.breakable_force q.text("}") end end @@ -4800,14 +4923,14 @@ def format_contents(q) q.format(lbrace) if assocs.empty? - q.breakable("") + q.breakable_empty else q.indent do - q.breakable + q.breakable_space q.seplist(assocs) { |assoc| q.format(assoc) } q.if_break { q.text(",") } if q.trailing_comma? end - q.breakable + q.breakable_space end q.text("}") @@ -4873,22 +4996,34 @@ def deconstruct_keys(_keys) } end - def format(q) - # This is a very specific behavior where you want to force a newline, but - # don't want to force the break parent. - breakable = -> { q.breakable(indent: false, force: :skip_break_parent) } + # This is a very specific behavior where you want to force a newline, but + # don't want to force the break parent. + SEPARATOR = PrettierPrint::Breakable.new(" ", 1, indent: false, force: true) + def format(q) q.group do q.format(beginning) q.line_suffix(priority: Formatter::HEREDOC_PRIORITY) do q.group do - breakable.call + q.target << SEPARATOR parts.each do |part| if part.is_a?(TStringContent) - texts = part.value.split(/\r?\n/, -1) - q.seplist(texts, breakable) { |text| q.text(text) } + value = part.value + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.target << SEPARATOR + end + + q.text(line) + end + + q.target << SEPARATOR if value.end_with?("\n") else q.format(part) end @@ -5077,18 +5212,7 @@ def deconstruct_keys(_keys) def format(q) parts = keywords.map { |(key, value)| KeywordFormatter.new(key, value) } parts << KeywordRestFormatter.new(keyword_rest) if keyword_rest - nested = PATTERNS.include?(q.parent.class) - contents = -> do - q.group { q.seplist(parts) { |part| q.format(part, stackable: false) } } - - # If there isn't a constant, and there's a blank keyword_rest, then we - # have an plain ** that needs to have a `then` after it in order to - # parse correctly on the next parse. - if !constant && keyword_rest && keyword_rest.value.nil? && !nested - q.text(" then") - end - end # If there is a constant, we're going to format to have the constant name # first and then use brackets. @@ -5097,10 +5221,10 @@ def format(q) q.format(constant) q.text("[") q.indent do - q.breakable("") - contents.call + q.breakable_empty + format_contents(q, parts, nested) end - q.breakable("") + q.breakable_empty q.text("]") end return @@ -5115,7 +5239,7 @@ def format(q) # If there's only one pair, then we'll just print the contents provided # we're not inside another pattern. if !nested && parts.size == 1 - contents.call + format_contents(q, parts, nested) return end @@ -5124,18 +5248,31 @@ def format(q) q.group do q.text("{") q.indent do - q.breakable - contents.call + q.breakable_space + format_contents(q, parts, nested) end if q.target_ruby_version < Gem::Version.new("2.7.3") q.text(" }") else - q.breakable + q.breakable_space q.text("}") end end end + + private + + def format_contents(q, parts, nested) + q.group { q.seplist(parts) { |part| q.format(part, stackable: false) } } + + # If there isn't a constant, and there's a blank keyword_rest, then we + # have an plain ** that needs to have a `then` after it in order to + # parse correctly on the next parse. + if !constant && keyword_rest && keyword_rest.value.nil? && !nested + q.text(" then") + end + end end # The list of nodes that represent patterns inside of pattern matching so that @@ -5188,8 +5325,12 @@ def self.call(parent) queue = [parent] while (node = queue.shift) - return true if [Assign, MAssign, OpAssign].include?(node.class) - queue += node.child_nodes.compact + case node + when Assign, MAssign, OpAssign + return true + else + node.child_nodes.each { |child| queue << child if child } + end end false @@ -5204,28 +5345,36 @@ def self.call(parent) module Ternaryable class << self def call(q, node) - case q.parents.take(2)[1] - in Paren[contents: Statements[body: [node]]] - # If this is a conditional inside of a parentheses as the only - # content, then we don't want to transform it into a ternary. - # Presumably the user wanted it to be an explicit conditional because - # there are parentheses around it. So we'll just leave it in place. - false - else - # Otherwise, we're going to check the conditional for certain cases. - case node - in predicate: Assign | Command | CommandCall | MAssign | OpAssign - false - in predicate: Not[parentheses: false] - false - in { - statements: { body: [truthy] }, - consequent: Else[statements: { body: [falsy] }] } - ternaryable?(truthy) && ternaryable?(falsy) - else - false - end + return false if ENV["STREE_FAST_FORMAT"] + + # If this is a conditional inside of a parentheses as the only content, + # then we don't want to transform it into a ternary. Presumably the user + # wanted it to be an explicit conditional because there are parentheses + # around it. So we'll just leave it in place. + grandparent = q.grandparent + if grandparent.is_a?(Paren) && (body = grandparent.contents.body) && + body.length == 1 && body.first == node + return false end + + # Otherwise, we'll check the type of predicate. For certain nodes we + # want to force it to not be a ternary, like if the predicate is an + # assignment because it's hard to read. + case node.predicate + when Assign, Command, CommandCall, MAssign, OpAssign + return false + when Not + return false unless node.predicate.parentheses? + end + + # If there's no Else, then this can't be represented as a ternary. + return false unless node.consequent.is_a?(Else) + + truthy_body = node.statements.body + falsy_body = node.consequent.statements.body + + (truthy_body.length == 1) && ternaryable?(truthy_body.first) && + (falsy_body.length == 1) && ternaryable?(falsy_body.first) end private @@ -5234,24 +5383,23 @@ def call(q, node) # parentheses around them. In this case we say they cannot be ternaried # and default instead to breaking them into multiple lines. def ternaryable?(statement) - # This is a list of nodes that should not be allowed to be a part of a - # ternary clause. - no_ternary = [ - Alias, Assign, Break, Command, CommandCall, Heredoc, If, IfMod, IfOp, - Lambda, MAssign, Next, OpAssign, RescueMod, Return, Return0, Super, - Undef, Unless, UnlessMod, Until, UntilMod, VarAlias, VoidStmt, While, - WhileMod, Yield, Yield0, ZSuper - ] - - # Here we're going to check that the only statement inside the - # statements node is no a part of our denied list of nodes that can be - # ternaries. - # - # If the user is using one of the lower precedence "and" or "or" - # operators, then we can't use a ternary expression as it would break - # the flow control. - !no_ternary.include?(statement.class) && - !(statement.is_a?(Binary) && %i[and or].include?(statement.operator)) + case statement + when Alias, Assign, Break, Command, CommandCall, Heredoc, If, IfMod, + IfOp, Lambda, MAssign, Next, OpAssign, RescueMod, Return, Return0, + Super, Undef, Unless, UnlessMod, Until, UntilMod, VarAlias, + VoidStmt, While, WhileMod, Yield, Yield0, ZSuper + # This is a list of nodes that should not be allowed to be a part of a + # ternary clause. + false + when Binary + # If the user is using one of the lower precedence "and" or "or" + # operators, then we can't use a ternary expression as it would break + # the flow control. + operator = statement.operator + operator != :and && operator != :or + else + true + end end end end @@ -5311,17 +5459,17 @@ def format_break(q, force:) unless node.statements.empty? q.indent do - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.format(node.statements) end end if node.consequent - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.format(node.consequent) end - q.breakable(force: force) + force ? q.breakable_force : q.breakable_space q.text("end") end @@ -5333,11 +5481,11 @@ def format_ternary(q) q.nest(keyword.length + 1) { q.format(node.predicate) } q.indent do - q.breakable + q.breakable_space q.format(node.statements) end - q.breakable + q.breakable_space q.group do q.format(node.consequent.keyword) q.indent do @@ -5351,7 +5499,7 @@ def format_ternary(q) end end - q.breakable + q.breakable_space q.text("end") end .if_flat do @@ -5371,8 +5519,11 @@ def format_ternary(q) end def contains_conditional? - case node - in statements: { body: [If | IfMod | IfOp | Unless | UnlessMod] } + statements = node.statements.body + return false if statements.length != 1 + + case statements.first + when If, IfMod, IfOp, Unless, UnlessMod true else false @@ -5507,19 +5658,19 @@ def format_break(q) q.nest("if ".length) { q.format(predicate) } q.indent do - q.breakable + q.breakable_space q.format(truthy) end - q.breakable + q.breakable_space q.text("else") q.indent do - q.breakable + q.breakable_space q.format(falsy) end - q.breakable + q.breakable_space q.text("end") end end @@ -5529,11 +5680,11 @@ def format_flat(q) q.text(" ?") q.indent do - q.breakable + q.breakable_space q.format(truthy) q.text(" :") - q.breakable + q.breakable_space q.format(falsy) end end @@ -5566,10 +5717,10 @@ def format_break(q) q.text("#{keyword} ") q.nest(keyword.length + 1) { q.format(node.predicate) } q.indent do - q.breakable + q.breakable_space q.format(node.statement) end - q.breakable + q.breakable_space q.text("end") end @@ -5720,13 +5871,13 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end @@ -5830,11 +5981,15 @@ class Kw < Node # [String] the value of the keyword attr_reader :value + # [Symbol] the symbol version of the value + attr_reader :name + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments def initialize(value:, location:, comments: []) @value = value + @name = value.to_sym @location = location @comments = comments end @@ -6013,7 +6168,8 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "->") do + q.text("->") + q.group do if params.is_a?(Paren) q.format(params) unless params.contents.empty? elsif params.empty? && params.comments.any? @@ -6039,10 +6195,10 @@ def format(q) unless statements.empty? q.indent do - q.breakable + q.breakable_space q.format(statements) end - q.breakable + q.breakable_space end q.text("}") @@ -6051,12 +6207,12 @@ def format(q) unless statements.empty? q.indent do - q.breakable + q.breakable_space q.format(statements) end end - q.breakable + q.breakable_space q.text("end") end end @@ -6123,7 +6279,7 @@ def format(q) if locals.any? q.text("; ") - q.seplist(locals, -> { q.text(", ") }) { |local| q.format(local) } + q.seplist(locals, BlockVar::SEPARATOR) { |local| q.format(local) } end end end @@ -6277,7 +6433,7 @@ def format(q) q.group { q.format(target) } q.text(" =") q.indent do - q.breakable + q.breakable_space q.format(value) end end @@ -6323,8 +6479,8 @@ def format(q) # If we're at the top of a call chain, then we're going to do some # specialized printing in case we can print it nicely. We _only_ do this # at the top of the chain to avoid weird recursion issues. - if !CallChainFormatter.chained?(q.parent) && - CallChainFormatter.chained?(call) + if CallChainFormatter.chained?(call) && + !CallChainFormatter.chained?(q.parent) q.group do q .if_break { CallChainFormatter.new(self).format(q) } @@ -6431,15 +6587,17 @@ def format(q) q.format(contents) q.text(",") if comma else - q.group(0, "(", ")") do + q.text("(") + q.group do q.indent do - q.breakable("") + q.breakable_empty q.format(contents) end q.text(",") if comma - q.breakable("") + q.breakable_empty end + q.text(")") end end end @@ -6486,33 +6644,35 @@ def deconstruct_keys(_keys) end def format(q) - declaration = -> do - q.group do - q.text("module ") - q.format(constant) - end - end - if bodystmt.empty? q.group do - declaration.call - q.breakable(force: true) + format_declaration(q) + q.breakable_force q.text("end") end else q.group do - declaration.call + format_declaration(q) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force q.text("end") end end end + + private + + def format_declaration(q) + q.group do + q.text("module ") + q.format(constant) + end + end end # MRHS represents the values that are being assigned on the right-hand side of @@ -6610,11 +6770,15 @@ class Op < Node # [String] the operator attr_reader :value + # [Symbol] the symbol version of the value + attr_reader :name + # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments def initialize(value:, location:, comments: []) @value = value + @name = value.to_sym @location = location @comments = comments end @@ -6696,7 +6860,7 @@ def format(q) q.format(value) else q.indent do - q.breakable + q.breakable_space q.format(value) end end @@ -6767,10 +6931,10 @@ def self.break(q) q.text("(") q.indent do - q.breakable("") + q.breakable_empty yield end - q.breakable("") + q.breakable_empty q.text(")") end end @@ -6962,23 +7126,35 @@ def format(q) parts << KeywordRestFormatter.new(keyword_rest) if keyword_rest parts << block if block - contents = -> do - q.seplist(parts) { |part| q.format(part) } - q.format(rest) if rest.is_a?(ExcessedComma) + if parts.empty? + q.nest(0) { format_contents(q, parts) } + return end - if ![Def, Defs, DefEndless].include?(q.parent.class) || parts.empty? - q.nest(0, &contents) - else - q.group(0, "(", ")") do - q.indent do - q.breakable("") - contents.call + case q.parent + when Def, Defs, DefEndless + q.nest(0) do + q.text("(") + q.group do + q.indent do + q.breakable_empty + format_contents(q, parts) + end + q.breakable_empty end - q.breakable("") + q.text(")") end + else + q.nest(0) { format_contents(q, parts) } end end + + private + + def format_contents(q, parts) + q.seplist(parts) { |part| q.format(part) } + q.format(rest) if rest.is_a?(ExcessedComma) + end end # Paren represents using balanced parentheses in a couple places in a Ruby @@ -7029,12 +7205,12 @@ def format(q) if contents && (!contents.is_a?(Params) || !contents.empty?) q.indent do - q.breakable("") + q.breakable_empty q.format(contents) end end - q.breakable("") + q.breakable_empty q.text(")") end end @@ -7108,7 +7284,7 @@ def format(q) # We're going to put a newline on the end so that it always has one unless # it ends with the special __END__ syntax. In that case we want to # replicate the text exactly so we will just let it be. - q.breakable(force: true) unless statements.body.last.is_a?(EndContent) + q.breakable_force unless statements.body.last.is_a?(EndContent) end end @@ -7160,15 +7336,18 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) end end @@ -7251,15 +7430,18 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) end end @@ -7781,13 +7963,13 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end @@ -7835,19 +8017,21 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "begin", "end") do + q.text("begin") + q.group do q.indent do - q.breakable(force: true) + q.breakable_force q.format(statement) end - q.breakable(force: true) + q.breakable_force q.text("rescue StandardError") q.indent do - q.breakable(force: true) + q.breakable_force q.format(value) end - q.breakable(force: true) + q.breakable_force end + q.text("end") end end @@ -8066,14 +8250,16 @@ def deconstruct_keys(_keys) end def format(q) - q.group(0, "class << ", "end") do + q.text("class << ") + q.group do q.format(target) q.indent do - q.breakable(force: true) + q.breakable_force q.format(bodystmt) end - q.breakable(force: true) + q.breakable_force end + q.text("end") end end @@ -8179,30 +8365,26 @@ def format(q) end end - access_controls = - Hash.new do |hash, node| - hash[node] = node.is_a?(VCall) && - %w[private protected public].include?(node.value.value) - end - - body.each_with_index do |statement, index| + previous = nil + body.each do |statement| next if statement.is_a?(VoidStmt) if line.nil? q.format(statement) elsif (statement.location.start_line - line) > 1 - q.breakable(force: true) - q.breakable(force: true) + q.breakable_force + q.breakable_force q.format(statement) - elsif access_controls[statement] || access_controls[body[index - 1]] - q.breakable(force: true) - q.breakable(force: true) + elsif (statement.is_a?(VCall) && statement.access_control?) || + (previous.is_a?(VCall) && previous.access_control?) + q.breakable_force + q.breakable_force q.format(statement) elsif statement.location.start_line != line - q.breakable(force: true) + q.breakable_force q.format(statement) elsif !q.parent.is_a?(StringEmbExpr) - q.breakable(force: true) + q.breakable_force q.format(statement) else q.text("; ") @@ -8210,6 +8392,7 @@ def format(q) end line = statement.location.end_line + previous = statement end end @@ -8327,7 +8510,7 @@ def format(q) q.format(left) q.text(" \\") q.indent do - q.breakable(force: true) + q.breakable_force q.format(right) end end @@ -8413,15 +8596,21 @@ def format(q) # same line in the source, then we're going to leave them in place and # assume that's the way the developer wanted this expression # represented. - q.remove_breaks(q.group(0, '#{', "}") { q.format(statements) }) + q.remove_breaks( + q.group do + q.text('#{') + q.format(statements) + q.text("}") + end + ) else q.group do q.text('#{') q.indent do - q.breakable("") + q.breakable_empty q.format(statements) end - q.breakable("") + q.breakable_empty q.text("}") end end @@ -8479,19 +8668,30 @@ def format(q) [quote, quote] end - q.group(0, opening_quote, closing_quote) do + q.text(opening_quote) + q.group do parts.each do |part| if part.is_a?(TStringContent) value = Quotes.normalize(part.value, closing_quote) - separator = -> { q.breakable(force: true, indent: false) } - q.seplist(value.split(/\r?\n/, -1), separator) do |text| - q.text(text) + first = true + + value.each_line(chomp: true) do |line| + if first + first = false + else + q.breakable_return + end + + q.text(line) end + + q.breakable_return if value.end_with?("\n") else q.format(part) end end end + q.text(closing_quote) end end @@ -8698,15 +8898,18 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) end end @@ -9000,6 +9203,7 @@ class Not < Node # [boolean] whether or not parentheses were used attr_reader :parentheses + alias parentheses? parentheses # [Array[ Comment | EmbDoc ]] the comments attached to this node attr_reader :comments @@ -9031,27 +9235,26 @@ def deconstruct_keys(_keys) end def format(q) - parent = q.parents.take(2)[1] - ternary = - (parent.is_a?(If) || parent.is_a?(Unless)) && - Ternaryable.call(q, parent) - q.text("not") if parentheses q.text("(") - elsif ternary - q.if_break { q.text(" ") }.if_flat { q.text("(") } - else - q.text(" ") - end - - q.format(statement) if statement - - if parentheses + q.format(statement) if statement q.text(")") - elsif ternary - q.if_flat { q.text(")") } + else + grandparent = q.grandparent + ternary = + (grandparent.is_a?(If) || grandparent.is_a?(Unless)) && + Ternaryable.call(q, grandparent) + + if ternary + q.if_break { q.text(" ") }.if_flat { q.text("(") } + q.format(statement) if statement + q.if_flat { q.text(")") } if ternary + else + q.text(" ") + q.format(statement) if statement + end end end end @@ -9316,10 +9519,10 @@ def format_break(q) q.text("#{keyword} ") q.nest(keyword.length + 1) { q.format(node.predicate) } q.indent do - q.breakable("") + q.breakable_empty q.format(statements) end - q.breakable("") + q.breakable_empty q.text("end") end end @@ -9372,7 +9575,7 @@ def format(q) q.group do q.text(keyword) q.nest(keyword.length) { q.format(predicate) } - q.breakable(force: true) + q.breakable_force q.text("end") end else @@ -9572,6 +9775,29 @@ def deconstruct_keys(_keys) def format(q) q.format(value) end + + # Oh man I hate this so much. Basically, ripper doesn't provide enough + # functionality to actually know where pins are within an expression. So we + # have to walk the tree ourselves and insert more information. In doing so, + # we have to replace this node by a pinned node when necessary. + # + # To be clear, this method should just not exist. It's not good. It's a + # place of shame. But it's necessary for now, so I'm keeping it. + def pin(parent) + replace = PinnedVarRef.new(value: value, location: location) + + parent + .deconstruct_keys([]) + .each do |key, value| + if value == self + parent.instance_variable_set(:"@#{key}", replace) + break + elsif value.is_a?(Array) && (index = value.index(self)) + parent.public_send(key)[index] = replace + break + end + end + end end # PinnedVarRef represents a pinned variable reference within a pattern @@ -9653,6 +9879,10 @@ def deconstruct_keys(_keys) def format(q) q.format(value) end + + def access_control? + @access_control ||= %w[private protected public].include?(value.value) + end end # VoidStmt represents an empty lexical block of code. @@ -9742,6 +9972,22 @@ def deconstruct_keys(_keys) } end + # We have a special separator here for when clauses which causes them to + # fill as much of the line as possible as opposed to everything breaking + # into its own line as soon as you hit the print limit. + class Separator + def call(q) + q.group do + q.text(",") + q.breakable_space + end + end + end + + # We're going to keep a single instance of this separator around so we don't + # have to allocate a new one every time we format a when clause. + SEPARATOR = Separator.new + def format(q) keyword = "when " @@ -9752,8 +9998,7 @@ def format(q) if arguments.comments.any? q.format(arguments) else - separator = -> { q.group { q.comma_breakable } } - q.seplist(arguments.parts, separator) { |part| q.format(part) } + q.seplist(arguments.parts, SEPARATOR) { |part| q.format(part) } end # Very special case here. If you're inside of a when clause and the @@ -9768,13 +10013,13 @@ def format(q) unless statements.empty? q.indent do - q.breakable(force: true) + q.breakable_force q.format(statements) end end if consequent - q.breakable(force: true) + q.breakable_force q.format(consequent) end end @@ -9829,7 +10074,7 @@ def format(q) q.group do q.text(keyword) q.nest(keyword.length) { q.format(predicate) } - q.breakable(force: true) + q.breakable_force q.text("end") end else @@ -9995,15 +10240,18 @@ def format(q) closing = Quotes.matching(opening[2]) end - q.group(0, opening, closing) do + q.text(opening) + q.group do q.indent do - q.breakable("") - q.seplist(elements, -> { q.breakable }) do |element| - q.format(element) - end + q.breakable_empty + q.seplist( + elements, + ArrayLiteral::BREAKABLE_SPACE_SEPARATOR + ) { |element| q.format(element) } end - q.breakable("") + q.breakable_empty end + q.text(closing) end end @@ -10147,10 +10395,10 @@ def format(q) else q.if_break { q.text("(") }.if_flat { q.text(" ") } q.indent do - q.breakable("") + q.breakable_empty q.format(arguments) end - q.breakable("") + q.breakable_empty q.if_break { q.text(")") } end end diff --git a/lib/syntax_tree/parser.rb b/lib/syntax_tree/parser.rb index 94ce115a..61a7ca57 100644 --- a/lib/syntax_tree/parser.rb +++ b/lib/syntax_tree/parser.rb @@ -60,29 +60,46 @@ def [](byteindex) # This represents all of the tokens coming back from the lexer. It is # replacing a simple array because it keeps track of the last deleted token # from the list for better error messages. - class TokenList < SimpleDelegator - attr_reader :last_deleted + class TokenList + attr_reader :tokens, :last_deleted - def initialize(object) - super + def initialize + @tokens = [] @last_deleted = nil end + def <<(token) + tokens << token + end + + def [](index) + tokens[index] + end + + def any?(&block) + tokens.any?(&block) + end + + def reverse_each(&block) + tokens.reverse_each(&block) + end + + def rindex(&block) + tokens.rindex(&block) + end + def delete(value) - @last_deleted = super || @last_deleted + @last_deleted = tokens.delete(value) || @last_deleted end def delete_at(index) - @last_deleted = super + @last_deleted = tokens.delete_at(index) end end # [String] the source being parsed attr_reader :source - # [Array[ String ]] the list of lines in the source - attr_reader :lines - # [Array[ SingleByteString | MultiByteString ]] the list of objects that # represent the start of each line in character offsets attr_reader :line_counts @@ -105,12 +122,6 @@ def initialize(source, *) # example. @source = source - # Similarly, we keep the lines of the source string around to be able to - # check if certain lines contain certain characters. For example, we'll - # use this to generate the content that goes after the __END__ keyword. - # Or we'll use this to check if a comment has other content on its line. - @lines = source.split(/\r?\n/) - # This is the full set of comments that have been found by the parser. # It's a running list. At the end of every block of statements, they will # go in and attempt to grab any comments that are on their own line and @@ -144,7 +155,7 @@ def initialize(source, *) # Most of the time, when a parser event consumes one of these events, it # will be deleted from the list. So ideally, this list stays pretty short # over the course of parsing a source string. - @tokens = TokenList.new([]) + @tokens = TokenList.new # Here we're going to build up a list of SingleByteString or # MultiByteString objects. They're each going to represent a string in the @@ -153,7 +164,7 @@ def initialize(source, *) @line_counts = [] last_index = 0 - @source.lines.each do |line| + @source.each_line do |line| @line_counts << if line.size == line.bytesize SingleByteString.new(last_index) else @@ -233,28 +244,55 @@ def find_token_error(location) # "module" (which would happen to be the innermost keyword). Then the outer # one would only be able to grab the first one. In this way all of the # tokens act as their own stack. - def find_token(type, value = :any, consume: true, location: nil) - index = - tokens.rindex do |token| - token.is_a?(type) && (value == :any || (token.value == value)) - end + # + # If we're expecting to be able to find a token and consume it, but can't + # actually find it, then we need to raise an error. This is _usually_ caused + # by a syntax error in the source that we're printing. It could also be + # caused by accidentally attempting to consume a token twice by two + # different parser event handlers. - if consume - # If we're expecting to be able to find a token and consume it, but - # can't actually find it, then we need to raise an error. This is - # _usually_ caused by a syntax error in the source that we're printing. - # It could also be caused by accidentally attempting to consume a token - # twice by two different parser event handlers. - unless index - token = value == :any ? type.name.split("::", 2).last : value - message = "Cannot find expected #{token}" - raise ParseError.new(message, *find_token_error(location)) - end + def find_token(type) + index = tokens.rindex { |token| token.is_a?(type) } + tokens[index] if index + end - tokens.delete_at(index) - elsif index - tokens[index] - end + def find_keyword(name) + index = tokens.rindex { |token| token.is_a?(Kw) && (token.name == name) } + tokens[index] if index + end + + def find_operator(name) + index = tokens.rindex { |token| token.is_a?(Op) && (token.name == name) } + tokens[index] if index + end + + def consume_error(name, location) + message = "Cannot find expected #{name}" + raise ParseError.new(message, *find_token_error(location)) + end + + def consume_token(type) + index = tokens.rindex { |token| token.is_a?(type) } + consume_error(type.name.split("::", 2).last, nil) unless index + tokens.delete_at(index) + end + + def consume_tstring_end(location) + index = tokens.rindex { |token| token.is_a?(TStringEnd) } + consume_error("string ending", location) unless index + tokens.delete_at(index) + end + + def consume_keyword(name) + index = tokens.rindex { |token| token.is_a?(Kw) && (token.name == name) } + consume_error(name, nil) unless index + tokens.delete_at(index) + end + + def consume_operator(name) + index = tokens.rindex { |token| token.is_a?(Op) && (token.name == name) } + consume_error(name, nil) unless index + tokens.delete_at(index) end # A helper function to find a :: operator. We do special handling instead of @@ -283,13 +321,18 @@ def find_colon2_before(const) # By finding the next non-space character, we can make sure that the bounds # of the statement list are correct. def find_next_statement_start(position) - remaining = source[position..] - - if remaining.sub(/\A +/, "")[0] == "#" - return position + remaining.index("\n") + maximum = source.length + + position.upto(maximum) do |pound_index| + case source[pound_index] + when "#" + return source.index("\n", pound_index + 1) || maximum + when " " + # continue + else + return position + end end - - position end # -------------------------------------------------------------------------- @@ -300,8 +343,8 @@ def find_next_statement_start(position) # :call-seq: # on_BEGIN: (Statements statements) -> BEGINBlock def on_BEGIN(statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) start_char = find_next_statement_start(lbrace.location.end_char) statements.bind( @@ -311,7 +354,7 @@ def on_BEGIN(statements) rbrace.location.start_column ) - keyword = find_token(Kw, "BEGIN") + keyword = consume_keyword(:BEGIN) BEGINBlock.new( lbrace: lbrace, @@ -338,8 +381,8 @@ def on_CHAR(value) # :call-seq: # on_END: (Statements statements) -> ENDBlock def on_END(statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) start_char = find_next_statement_start(lbrace.location.end_char) statements.bind( @@ -349,7 +392,7 @@ def on_END(statements) rbrace.location.start_column ) - keyword = find_token(Kw, "END") + keyword = consume_keyword(:END) ENDBlock.new( lbrace: lbrace, @@ -380,7 +423,7 @@ def on___end__(value) # (DynaSymbol | SymbolLiteral) right # ) -> Alias def on_alias(left, right) - keyword = find_token(Kw, "alias") + keyword = consume_keyword(:alias) Alias.new( left: left, @@ -392,8 +435,8 @@ def on_alias(left, right) # :call-seq: # on_aref: (untyped collection, (nil | Args) index) -> ARef def on_aref(collection, index) - find_token(LBracket) - rbracket = find_token(RBracket) + consume_token(LBracket) + rbracket = consume_token(RBracket) ARef.new( collection: collection, @@ -408,8 +451,8 @@ def on_aref(collection, index) # (nil | Args) index # ) -> ARefField def on_aref_field(collection, index) - find_token(LBracket) - rbracket = find_token(RBracket) + consume_token(LBracket) + rbracket = consume_token(RBracket) ARefField.new( collection: collection, @@ -427,8 +470,8 @@ def on_aref_field(collection, index) # (nil | Args | ArgsForward) arguments # ) -> ArgParen def on_arg_paren(arguments) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) # If the arguments exceed the ending of the parentheses, then we know we # have a heredoc in the arguments, and we need to use the bounds of the @@ -470,23 +513,26 @@ def on_args_add(arguments, argument) # (false | untyped) block # ) -> Args def on_args_add_block(arguments, block) + end_char = arguments.parts.any? && arguments.location.end_char + # First, see if there is an & operator that could potentially be # associated with the block part of this args_add_block. If there is not, # then just return the arguments. - operator = find_token(Op, "&", consume: false) - return arguments unless operator - - # If there are any arguments and the operator we found from the list is - # not after them, then we're going to return the arguments as-is because - # we're looking at an & that occurs before the arguments are done. - if arguments.parts.any? && - operator.location.start_char < arguments.location.end_char - return arguments - end + index = + tokens.rindex do |token| + # If there are any arguments and the operator we found from the list + # is not after them, then we're going to return the arguments as-is + # because we're looking at an & that occurs before the arguments are + # done. + return arguments if end_char && token.location.start_char < end_char + token.is_a?(Op) && (token.name == :&) + end + + return arguments unless index # Now we know we have an & operator, so we're going to delete it from the # list of tokens to make sure it doesn't get confused with anything else. - tokens.delete(operator) + operator = tokens.delete_at(index) # Construct the location that represents the block argument. location = operator.location @@ -505,7 +551,7 @@ def on_args_add_block(arguments, block) # :call-seq: # on_args_add_star: (Args arguments, untyped star) -> Args def on_args_add_star(arguments, argument) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = argument || beginning location = @@ -527,7 +573,7 @@ def on_args_add_star(arguments, argument) # :call-seq: # on_args_forward: () -> ArgsForward def on_args_forward - op = find_token(Op, "...") + op = consume_operator(:"...") ArgsForward.new(value: op.value, location: op.location) end @@ -547,8 +593,8 @@ def on_args_new # ArrayLiteral | QSymbols | QWords | Symbols | Words def on_array(contents) if !contents || contents.is_a?(Args) - lbracket = find_token(LBracket) - rbracket = find_token(RBracket) + lbracket = consume_token(LBracket) + rbracket = consume_token(RBracket) ArrayLiteral.new( lbracket: lbracket, @@ -556,8 +602,7 @@ def on_array(contents) location: lbracket.location.to(rbracket.location) ) else - tstring_end = - find_token(TStringEnd, location: contents.beginning.location) + tstring_end = consume_tstring_end(contents.beginning.location) contents.class.new( beginning: contents.beginning, @@ -567,6 +612,56 @@ def on_array(contents) end end + # Ugh... I really do not like this class. Basically, ripper doesn't provide + # enough information about where pins are located in the tree. It only gives + # events for ^ ops and var_ref nodes. You have to piece it together + # yourself. + # + # Note that there are edge cases here that we straight up do not address, + # because I honestly think it's going to be faster to write a new parser + # than to address them. For example, this will not work properly: + # + # foo in ^((bar = 0; bar; baz)) + # + # If someone actually does something like that, we'll have to find another + # way to make this work. + class PinVisitor < Visitor + attr_reader :pins, :stack + + def initialize(pins) + @pins = pins + @stack = [] + end + + def visit(node) + return if pins.empty? + stack << node + super + stack.pop + end + + def visit_var_ref(node) + pins.shift + node.pin(stack[-2]) + end + + def self.visit(node, tokens) + start_char = node.location.start_char + allocated = [] + + tokens.reverse_each do |token| + char = token.location.start_char + break if char <= start_char + + if token.is_a?(Op) && token.value == "^" + allocated.unshift(tokens.delete(token)) + end + end + + new(allocated).visit(node) if allocated.any? + end + end + # :call-seq: # on_aryptn: ( # (nil | VarRef) constant, @@ -583,7 +678,7 @@ def on_aryptn(constant, requireds, rest, posts) # of the various parts. location = if parts.empty? - find_token(LBracket).location.to(find_token(RBracket).location) + consume_token(LBracket).location.to(consume_token(RBracket).location) else parts[0].location.to(parts[-1].location) end @@ -594,12 +689,13 @@ def on_aryptn(constant, requireds, rest, posts) if rest.is_a?(VarField) && rest.value.nil? tokens.rindex do |rtoken| case rtoken - in Op[value: "*"] - rest = VarField.new(value: nil, location: rtoken.location) + when Comma break - in Comma - break - else + when Op + if rtoken.value == "*" + rest = VarField.new(value: nil, location: rtoken.location) + break + end end end end @@ -644,7 +740,7 @@ def on_assoc_new(key, value) # :call-seq: # on_assoc_splat: (untyped value) -> AssocSplat def on_assoc_splat(value) - operator = find_token(Op, "**") + operator = consume_operator(:**) AssocSplat.new( value: value, @@ -704,23 +800,23 @@ def on_bare_assoc_hash(assocs) # :call-seq: # on_begin: (untyped bodystmt) -> Begin | PinnedBegin def on_begin(bodystmt) - pin = find_token(Op, "^", consume: false) + pin = find_operator(:^) if pin && pin.location.start_char < bodystmt.location.start_char tokens.delete(pin) - find_token(LParen) + consume_token(LParen) - rparen = find_token(RParen) + rparen = consume_token(RParen) location = pin.location.to(rparen.location) PinnedBegin.new(statement: bodystmt, location: location) else - keyword = find_token(Kw, "begin") + keyword = consume_keyword(:begin) end_location = if bodystmt.else_clause bodystmt.location else - find_token(Kw, "end").location + consume_keyword(:end).location end bodystmt.bind( @@ -746,13 +842,11 @@ def on_binary(left, operator, right) # Here, we're going to search backward for the token that's between the # two operands that matches the operator so we can delete it from the # list. + range = (left.location.end_char + 1)...right.location.start_char index = tokens.rindex do |token| - location = token.location - - token.is_a?(Op) && token.value == operator.to_s && - location.start_char > left.location.end_char && - location.end_char < right.location.start_char + token.is_a?(Op) && token.name == operator && + range.cover?(token.location.start_char) end tokens.delete_at(index) if index @@ -795,7 +889,7 @@ def on_block_var(params, locals) # :call-seq: # on_blockarg: (Ident name) -> BlockArg def on_blockarg(name) - operator = find_token(Op, "&") + operator = consume_operator(:&) location = operator.location location = location.to(name.location) if name @@ -814,7 +908,7 @@ def on_bodystmt(statements, rescue_clause, else_clause, ensure_clause) BodyStmt.new( statements: statements, rescue_clause: rescue_clause, - else_keyword: else_clause && find_token(Kw, "else"), + else_keyword: else_clause && consume_keyword(:else), else_clause: else_clause, ensure_clause: ensure_clause, location: @@ -828,8 +922,8 @@ def on_bodystmt(statements, rescue_clause, else_clause, ensure_clause) # Statements statements # ) -> BraceBlock def on_brace_block(block_var, statements) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) location = (block_var || lbrace).location start_char = find_next_statement_start(location.end_char) @@ -864,7 +958,7 @@ def on_brace_block(block_var, statements) # :call-seq: # on_break: (Args arguments) -> Break def on_break(arguments) - keyword = find_token(Kw, "break") + keyword = consume_keyword(:break) location = keyword.location location = location.to(arguments.location) if arguments.parts.any? @@ -900,7 +994,7 @@ def on_call(receiver, operator, message) # :call-seq: # on_case: (untyped value, untyped consequent) -> Case | RAssign def on_case(value, consequent) - if (keyword = find_token(Kw, "case", consume: false)) + if (keyword = find_keyword(:case)) tokens.delete(keyword) Case.new( @@ -911,18 +1005,22 @@ def on_case(value, consequent) ) else operator = - if (keyword = find_token(Kw, "in", consume: false)) + if (keyword = find_keyword(:in)) tokens.delete(keyword) else - find_token(Op, "=>") + consume_operator(:"=>") end - RAssign.new( - value: value, - operator: operator, - pattern: consequent, - location: value.location.to(consequent.location) - ) + node = + RAssign.new( + value: value, + operator: operator, + pattern: consequent, + location: value.location.to(consequent.location) + ) + + PinVisitor.visit(node, tokens) + node end end @@ -933,8 +1031,8 @@ def on_case(value, consequent) # BodyStmt bodystmt # ) -> ClassDeclaration def on_class(constant, superclass, bodystmt) - beginning = find_token(Kw, "class") - ending = find_token(Kw, "end") + beginning = consume_keyword(:class) + ending = consume_keyword(:end) location = (superclass || constant).location start_char = find_next_statement_start(location.end_char) @@ -1004,20 +1102,20 @@ def on_command_call(receiver, operator, message, arguments) # :call-seq: # on_comment: (String value) -> Comment def on_comment(value) - line = lineno - comment = - Comment.new( - value: value.chomp, - inline: value.strip != lines[line - 1].strip, - location: - Location.token( - line: line, - char: char_pos, - column: current_column, - size: value.size - 1 - ) + char = char_pos + location = + Location.token( + line: lineno, + char: char, + column: current_column, + size: value.size - 1 ) + index = source.rindex(/[^\t ]/, char - 1) if char != 0 + inline = index && (source[index] != "\n") + comment = + Comment.new(value: value.chomp, inline: inline, location: location) + @comments << comment comment end @@ -1092,7 +1190,7 @@ def on_def(name, params, bodystmt) # Find the beginning of the method definition, which works for single-line # and normal method definitions. - beginning = find_token(Kw, "def") + beginning = consume_keyword(:def) # If there aren't any params then we need to correct the params node # location information @@ -1112,7 +1210,7 @@ def on_def(name, params, bodystmt) params = Params.new(location: location) end - ending = find_token(Kw, "end", consume: false) + ending = find_keyword(:end) if ending tokens.delete(ending) @@ -1150,13 +1248,13 @@ def on_def(name, params, bodystmt) # :call-seq: # on_defined: (untyped value) -> Defined def on_defined(value) - beginning = find_token(Kw, "defined?") + beginning = consume_keyword(:defined?) ending = value range = beginning.location.end_char...value.location.start_char if source[range].include?("(") - find_token(LParen) - ending = find_token(RParen) + consume_token(LParen) + ending = consume_token(RParen) end Defined.new( @@ -1197,8 +1295,8 @@ def on_defs(target, operator, name, params, bodystmt) params = Params.new(location: location) end - beginning = find_token(Kw, "def") - ending = find_token(Kw, "end", consume: false) + beginning = consume_keyword(:def) + ending = find_keyword(:end) if ending tokens.delete(ending) @@ -1238,8 +1336,8 @@ def on_defs(target, operator, name, params, bodystmt) # :call-seq: # on_do_block: (BlockVar block_var, BodyStmt bodystmt) -> DoBlock def on_do_block(block_var, bodystmt) - beginning = find_token(Kw, "do") - ending = find_token(Kw, "end") + beginning = consume_keyword(:do) + ending = consume_keyword(:end) location = (block_var || beginning).location start_char = find_next_statement_start(location.end_char) @@ -1261,7 +1359,7 @@ def on_do_block(block_var, bodystmt) # :call-seq: # on_dot2: ((nil | untyped) left, (nil | untyped) right) -> Dot2 def on_dot2(left, right) - operator = find_token(Op, "..") + operator = consume_operator(:"..") beginning = left || operator ending = right || operator @@ -1276,7 +1374,7 @@ def on_dot2(left, right) # :call-seq: # on_dot3: ((nil | untyped) left, (nil | untyped) right) -> Dot3 def on_dot3(left, right) - operator = find_token(Op, "...") + operator = consume_operator(:"...") beginning = left || operator ending = right || operator @@ -1291,10 +1389,10 @@ def on_dot3(left, right) # :call-seq: # on_dyna_symbol: (StringContent string_content) -> DynaSymbol def on_dyna_symbol(string_content) - if find_token(SymBeg, consume: false) + if (symbeg = find_token(SymBeg)) # A normal dynamic symbol - symbeg = find_token(SymBeg) - tstring_end = find_token(TStringEnd, location: symbeg.location) + tokens.delete(symbeg) + tstring_end = consume_tstring_end(symbeg.location) DynaSymbol.new( quote: symbeg.value, @@ -1303,8 +1401,8 @@ def on_dyna_symbol(string_content) ) else # A dynamic symbol as a hash key - tstring_beg = find_token(TStringBeg) - label_end = find_token(LabelEnd) + tstring_beg = consume_token(TStringBeg) + label_end = consume_token(LabelEnd) DynaSymbol.new( parts: string_content.parts, @@ -1317,7 +1415,7 @@ def on_dyna_symbol(string_content) # :call-seq: # on_else: (Statements statements) -> Else def on_else(statements) - keyword = find_token(Kw, "else") + keyword = consume_keyword(:else) # else can either end with an end keyword (in which case we'll want to # consume that event) or it can end with an ensure keyword (in which case @@ -1357,8 +1455,8 @@ def on_else(statements) # (nil | Elsif | Else) consequent # ) -> Elsif def on_elsif(predicate, statements, consequent) - beginning = find_token(Kw, "elsif") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:elsif) + ending = consequent || consume_keyword(:end) start_char = find_next_statement_start(predicate.location.end_char) statements.bind( @@ -1478,11 +1576,11 @@ def on_embvar(value) # :call-seq: # on_ensure: (Statements statements) -> Ensure def on_ensure(statements) - keyword = find_token(Kw, "ensure") + keyword = consume_keyword(:ensure) # We don't want to consume the :@kw event, because that would break # def..ensure..end chains. - ending = find_token(Kw, "end", consume: false) + ending = find_keyword(:end) start_char = find_next_statement_start(keyword.location.end_char) statements.bind( start_char, @@ -1504,7 +1602,7 @@ def on_ensure(statements) # :call-seq: # on_excessed_comma: () -> ExcessedComma def on_excessed_comma(*) - comma = find_token(Comma) + comma = consume_token(Comma) ExcessedComma.new(value: comma.value, location: comma.location) end @@ -1557,20 +1655,18 @@ def on_fndptn(constant, left, values, right) # right left parenthesis, or the left splat. We're going to use this to # determine how to find the closing of the pattern, as well as determining # the location of the node. - opening = - find_token(LBracket, consume: false) || - find_token(LParen, consume: false) || left + opening = find_token(LBracket) || find_token(LParen) || left # The closing is based on the opening, which is either the matched # punctuation or the right splat. closing = case opening - in LBracket + when LBracket tokens.delete(opening) - find_token(RBracket) - in LParen + consume_token(RBracket) + when LParen tokens.delete(opening) - find_token(RParen) + consume_token(RParen) else right end @@ -1591,13 +1687,13 @@ def on_fndptn(constant, left, values, right) # Statements statements # ) -> For def on_for(index, collection, statements) - beginning = find_token(Kw, "for") - in_keyword = find_token(Kw, "in") - ending = find_token(Kw, "end") + beginning = consume_keyword(:for) + in_keyword = consume_keyword(:in) + ending = consume_keyword(:end) # Consume the do keyword if it exists so that it doesn't get confused for # some other block - keyword = find_token(Kw, "do", consume: false) + keyword = find_keyword(:do) if keyword && keyword.location.start_char > collection.location.end_char && keyword.location.end_char < ending.location.start_char @@ -1645,8 +1741,8 @@ def on_gvar(value) # :call-seq: # on_hash: ((nil | Array[AssocNew | AssocSplat]) assocs) -> HashLiteral def on_hash(assocs) - lbrace = find_token(LBrace) - rbrace = find_token(RBrace) + lbrace = consume_token(LBrace) + rbrace = consume_token(RBrace) HashLiteral.new( lbrace: lbrace, @@ -1730,8 +1826,8 @@ def on_hshptn(constant, keywords, keyword_rest) if keyword_rest # We're doing this to delete the token from the list so that it doesn't # confuse future patterns by thinking they have an extra ** on the end. - find_token(Op, "**") - elsif (token = find_token(Op, "**", consume: false)) + consume_operator(:**) + elsif (token = find_operator(:**)) tokens.delete(token) # Create an artificial VarField if we find an extra ** on the end. This @@ -1744,8 +1840,8 @@ def on_hshptn(constant, keywords, keyword_rest) # If there's no constant, there may be braces, so we're going to look for # those to get our bounds. unless constant - lbrace = find_token(LBrace, consume: false) - rbrace = find_token(RBrace, consume: false) + lbrace = find_token(LBrace) + rbrace = find_token(RBrace) if lbrace && rbrace parts = [lbrace, *parts, rbrace] @@ -1784,8 +1880,8 @@ def on_ident(value) # (nil | Elsif | Else) consequent # ) -> If def on_if(predicate, statements, consequent) - beginning = find_token(Kw, "if") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:if) + ending = consequent || consume_keyword(:end) start_char = find_next_statement_start(predicate.location.end_char) statements.bind( @@ -1817,7 +1913,7 @@ def on_ifop(predicate, truthy, falsy) # :call-seq: # on_if_mod: (untyped predicate, untyped statement) -> IfMod def on_if_mod(predicate, statement) - find_token(Kw, "if") + consume_keyword(:if) IfMod.new( statement: statement, @@ -1860,11 +1956,11 @@ def on_in(pattern, statements, consequent) # Here we have a rightward assignment return pattern unless statements - beginning = find_token(Kw, "in") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:in) + ending = consequent || consume_keyword(:end) statements_start = pattern - if (token = find_token(Kw, "then", consume: false)) + if (token = find_keyword(:then)) tokens.delete(token) statements_start = token end @@ -1878,12 +1974,16 @@ def on_in(pattern, statements, consequent) ending.location.start_column ) - In.new( - pattern: pattern, - statements: statements, - consequent: consequent, - location: beginning.location.to(ending.location) - ) + node = + In.new( + pattern: pattern, + statements: statements, + consequent: consequent, + location: beginning.location.to(ending.location) + ) + + PinVisitor.visit(node, tokens) + node end # :call-seq: @@ -1938,7 +2038,7 @@ def on_kw(value) # :call-seq: # on_kwrest_param: ((nil | Ident) name) -> KwRestParam def on_kwrest_param(name) - location = find_token(Op, "**").location + location = consume_operator(:**).location location = location.to(name.location) if name KwRestParam.new(name: name, location: location) @@ -1984,7 +2084,7 @@ def on_label_end(value) # (BodyStmt | Statements) statements # ) -> Lambda def on_lambda(params, statements) - beginning = find_token(TLambda) + beginning = consume_token(TLambda) braces = tokens.any? do |token| token.is_a?(TLamBeg) && @@ -1995,7 +2095,7 @@ def on_lambda(params, statements) # capturing lambda var until 3.2, we need to normalize all of that here. params = case params - in Paren[contents: Params] + when Paren # In this case we've gotten to the <3.2 parentheses wrapping a set of # parameters case. Here we need to manually scan for lambda locals. range = (params.location.start_char + 1)...params.location.end_char @@ -2015,23 +2115,23 @@ def on_lambda(params, statements) location: params.location, comments: params.comments ) - in Params + when Params # In this case we've gotten to the <3.2 plain set of parameters. In # this case there cannot be lambda locals, so we will wrap the # parameters into a lambda var that has no locals. LambdaVar.new(params: params, locals: [], location: params.location) - in LambdaVar + when LambdaVar # In this case we've gotten to 3.2+ lambda var. In this case we don't # need to do anything and can just the value as given. params end if braces - opening = find_token(TLamBeg) - closing = find_token(RBrace) + opening = consume_token(TLamBeg) + closing = consume_token(RBrace) else - opening = find_token(Kw, "do") - closing = find_token(Kw, "end") + opening = consume_keyword(:do) + closing = consume_keyword(:end) end start_char = find_next_statement_start(opening.location.end_char) @@ -2262,7 +2362,7 @@ def on_mlhs_add_post(left, right) # (nil | ARefField | Field | Ident | VarField) part # ) -> MLHS def on_mlhs_add_star(mlhs, part) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = part || beginning location = beginning.location.to(ending.location) @@ -2285,8 +2385,8 @@ def on_mlhs_new # :call-seq: # on_mlhs_paren: ((MLHS | MLHSParen) contents) -> MLHSParen def on_mlhs_paren(contents) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) comma_range = lparen.location.end_char...rparen.location.start_char contents.comma = true if source[comma_range].strip.end_with?(",") @@ -2303,8 +2403,8 @@ def on_mlhs_paren(contents) # BodyStmt bodystmt # ) -> ModuleDeclaration def on_module(constant, bodystmt) - beginning = find_token(Kw, "module") - ending = find_token(Kw, "end") + beginning = consume_keyword(:module) + ending = consume_keyword(:end) start_char = find_next_statement_start(constant.location.end_char) bodystmt.bind( @@ -2343,7 +2443,7 @@ def on_mrhs_add(mrhs, part) # :call-seq: # on_mrhs_add_star: (MRHS mrhs, untyped value) -> MRHS def on_mrhs_add_star(mrhs, value) - beginning = find_token(Op, "*") + beginning = consume_operator(:*) ending = value || beginning arg_star = @@ -2371,7 +2471,7 @@ def on_mrhs_new_from_args(arguments) # :call-seq: # on_next: (Args arguments) -> Next def on_next(arguments) - keyword = find_token(Kw, "next") + keyword = consume_keyword(:next) location = keyword.location location = location.to(arguments.location) if arguments.parts.any? @@ -2486,8 +2586,8 @@ def on_params( # :call-seq: # on_paren: (untyped contents) -> Paren def on_paren(contents) - lparen = find_token(LParen) - rparen = find_token(RParen) + lparen = consume_token(LParen) + rparen = consume_token(RParen) if contents.is_a?(Params) location = contents.location @@ -2551,13 +2651,13 @@ def on_period(value) # :call-seq: # on_program: (Statements statements) -> Program def on_program(statements) - last_column = source.length - line_counts[lines.length - 1].start + last_column = source.length - line_counts.last.start location = Location.new( start_line: 1, start_char: 0, start_column: 0, - end_line: lines.length, + end_line: line_counts.length - 1, end_char: source.length, end_column: last_column ) @@ -2692,7 +2792,7 @@ def on_qsymbols_beg(value) # :call-seq: # on_qsymbols_new: () -> QSymbols def on_qsymbols_new - beginning = find_token(QSymbolsBeg) + beginning = consume_token(QSymbolsBeg) QSymbols.new( beginning: beginning, @@ -2733,7 +2833,7 @@ def on_qwords_beg(value) # :call-seq: # on_qwords_new: () -> QWords def on_qwords_new - beginning = find_token(QWordsBeg) + beginning = consume_token(QWordsBeg) QWords.new( beginning: beginning, @@ -2798,7 +2898,7 @@ def on_rbracket(value) # :call-seq: # on_redo: () -> Redo def on_redo - keyword = find_token(Kw, "redo") + keyword = consume_keyword(:redo) Redo.new(value: keyword.value, location: keyword.location) end @@ -2874,7 +2974,7 @@ def on_regexp_literal(regexp_content, ending) # :call-seq: # on_regexp_new: () -> RegexpContent def on_regexp_new - regexp_beg = find_token(RegexpBeg) + regexp_beg = consume_token(RegexpBeg) RegexpContent.new( beginning: regexp_beg.value, @@ -2891,7 +2991,7 @@ def on_regexp_new # (nil | Rescue) consequent # ) -> Rescue def on_rescue(exceptions, variable, statements, consequent) - keyword = find_token(Kw, "rescue") + keyword = consume_keyword(:rescue) exceptions = exceptions[0] if exceptions.is_a?(Array) last_node = variable || exceptions || keyword @@ -2943,7 +3043,7 @@ def on_rescue(exceptions, variable, statements, consequent) # :call-seq: # on_rescue_mod: (untyped statement, untyped value) -> RescueMod def on_rescue_mod(statement, value) - find_token(Kw, "rescue") + consume_keyword(:rescue) RescueMod.new( statement: statement, @@ -2955,7 +3055,7 @@ def on_rescue_mod(statement, value) # :call-seq: # on_rest_param: ((nil | Ident) name) -> RestParam def on_rest_param(name) - location = find_token(Op, "*").location + location = consume_operator(:*).location location = location.to(name.location) if name RestParam.new(name: name, location: location) @@ -2964,7 +3064,7 @@ def on_rest_param(name) # :call-seq: # on_retry: () -> Retry def on_retry - keyword = find_token(Kw, "retry") + keyword = consume_keyword(:retry) Retry.new(value: keyword.value, location: keyword.location) end @@ -2972,7 +3072,7 @@ def on_retry # :call-seq: # on_return: (Args arguments) -> Return def on_return(arguments) - keyword = find_token(Kw, "return") + keyword = consume_keyword(:return) Return.new( arguments: arguments, @@ -2983,7 +3083,7 @@ def on_return(arguments) # :call-seq: # on_return0: () -> Return0 def on_return0 - keyword = find_token(Kw, "return") + keyword = consume_keyword(:return) Return0.new(value: keyword.value, location: keyword.location) end @@ -3010,8 +3110,8 @@ def on_rparen(value) # :call-seq: # on_sclass: (untyped target, BodyStmt bodystmt) -> SClass def on_sclass(target, bodystmt) - beginning = find_token(Kw, "class") - ending = find_token(Kw, "end") + beginning = consume_keyword(:class) + ending = consume_keyword(:end) start_char = find_next_statement_start(target.location.end_char) bodystmt.bind( @@ -3109,7 +3209,7 @@ def on_string_content # :call-seq: # on_string_dvar: ((Backref | VarRef) variable) -> StringDVar def on_string_dvar(variable) - embvar = find_token(EmbVar) + embvar = consume_token(EmbVar) StringDVar.new( variable: variable, @@ -3120,8 +3220,8 @@ def on_string_dvar(variable) # :call-seq: # on_string_embexpr: (Statements statements) -> StringEmbExpr def on_string_embexpr(statements) - embexpr_beg = find_token(EmbExprBeg) - embexpr_end = find_token(EmbExprEnd) + embexpr_beg = consume_token(EmbExprBeg) + embexpr_end = consume_token(EmbExprEnd) statements.bind( embexpr_beg.location.end_char, @@ -3162,8 +3262,8 @@ def on_string_literal(string) location: heredoc.location ) else - tstring_beg = find_token(TStringBeg) - tstring_end = find_token(TStringEnd, location: tstring_beg.location) + tstring_beg = consume_token(TStringBeg) + tstring_end = consume_tstring_end(tstring_beg.location) location = Location.new( @@ -3189,7 +3289,7 @@ def on_string_literal(string) # :call-seq: # on_super: ((ArgParen | Args) arguments) -> Super def on_super(arguments) - keyword = find_token(Kw, "super") + keyword = consume_keyword(:super) Super.new( arguments: arguments, @@ -3236,7 +3336,7 @@ def on_symbol(value) # ) -> SymbolLiteral def on_symbol_literal(value) if value.is_a?(SymbolContent) - symbeg = find_token(SymBeg) + symbeg = consume_token(SymBeg) SymbolLiteral.new( value: value.value, @@ -3280,7 +3380,7 @@ def on_symbols_beg(value) # :call-seq: # on_symbols_new: () -> Symbols def on_symbols_new - beginning = find_token(SymbolsBeg) + beginning = consume_token(SymbolsBeg) Symbols.new( beginning: beginning, @@ -3410,13 +3510,13 @@ def on_unary(operator, statement) # We have somewhat special handling of the not operator since if it has # parentheses they don't get reported as a paren node for some reason. - beginning = find_token(Kw, "not") + beginning = consume_keyword(:not) ending = statement || beginning parentheses = source[beginning.location.end_char] == "(" if parentheses - find_token(LParen) - ending = find_token(RParen) + consume_token(LParen) + ending = consume_token(RParen) end Not.new( @@ -3449,7 +3549,7 @@ def on_unary(operator, statement) # :call-seq: # on_undef: (Array[DynaSymbol | SymbolLiteral] symbols) -> Undef def on_undef(symbols) - keyword = find_token(Kw, "undef") + keyword = consume_keyword(:undef) Undef.new( symbols: symbols, @@ -3464,8 +3564,8 @@ def on_undef(symbols) # ((nil | Elsif | Else) consequent) # ) -> Unless def on_unless(predicate, statements, consequent) - beginning = find_token(Kw, "unless") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:unless) + ending = consequent || consume_keyword(:end) start_char = find_next_statement_start(predicate.location.end_char) statements.bind( @@ -3486,7 +3586,7 @@ def on_unless(predicate, statements, consequent) # :call-seq: # on_unless_mod: (untyped predicate, untyped statement) -> UnlessMod def on_unless_mod(predicate, statement) - find_token(Kw, "unless") + consume_keyword(:unless) UnlessMod.new( statement: statement, @@ -3498,12 +3598,12 @@ def on_unless_mod(predicate, statement) # :call-seq: # on_until: (untyped predicate, Statements statements) -> Until def on_until(predicate, statements) - beginning = find_token(Kw, "until") - ending = find_token(Kw, "end") + beginning = consume_keyword(:until) + ending = consume_keyword(:end) # Consume the do keyword if it exists so that it doesn't get confused for # some other block - keyword = find_token(Kw, "do", consume: false) + keyword = find_keyword(:do) if keyword && keyword.location.start_char > predicate.location.end_char && keyword.location.end_char < ending.location.start_char tokens.delete(keyword) @@ -3528,7 +3628,7 @@ def on_until(predicate, statements) # :call-seq: # on_until_mod: (untyped predicate, untyped statement) -> UntilMod def on_until_mod(predicate, statement) - find_token(Kw, "until") + consume_keyword(:until) UntilMod.new( statement: statement, @@ -3540,7 +3640,7 @@ def on_until_mod(predicate, statement) # :call-seq: # on_var_alias: (GVar left, (Backref | GVar) right) -> VarAlias def on_var_alias(left, right) - keyword = find_token(Kw, "alias") + keyword = consume_keyword(:alias) VarAlias.new( left: left, @@ -3569,17 +3669,7 @@ def on_var_field(value) # :call-seq: # on_var_ref: ((Const | CVar | GVar | Ident | IVar | Kw) value) -> VarRef def on_var_ref(value) - pin = find_token(Op, "^", consume: false) - - if pin && pin.location.start_char == value.location.start_char - 1 - tokens.delete(pin) - PinnedVarRef.new( - value: value, - location: pin.location.to(value.location) - ) - else - VarRef.new(value: value, location: value.location) - end + VarRef.new(value: value, location: value.location) end # :call-seq: @@ -3604,11 +3694,11 @@ def on_void_stmt # (nil | Else | When) consequent # ) -> When def on_when(arguments, statements, consequent) - beginning = find_token(Kw, "when") - ending = consequent || find_token(Kw, "end") + beginning = consume_keyword(:when) + ending = consequent || consume_keyword(:end) statements_start = arguments - if (token = find_token(Kw, "then", consume: false)) + if (token = find_keyword(:then)) tokens.delete(token) statements_start = token end @@ -3634,12 +3724,12 @@ def on_when(arguments, statements, consequent) # :call-seq: # on_while: (untyped predicate, Statements statements) -> While def on_while(predicate, statements) - beginning = find_token(Kw, "while") - ending = find_token(Kw, "end") + beginning = consume_keyword(:while) + ending = consume_keyword(:end) # Consume the do keyword if it exists so that it doesn't get confused for # some other block - keyword = find_token(Kw, "do", consume: false) + keyword = find_keyword(:do) if keyword && keyword.location.start_char > predicate.location.end_char && keyword.location.end_char < ending.location.start_char tokens.delete(keyword) @@ -3664,7 +3754,7 @@ def on_while(predicate, statements) # :call-seq: # on_while_mod: (untyped predicate, untyped statement) -> WhileMod def on_while_mod(predicate, statement) - find_token(Kw, "while") + consume_keyword(:while) WhileMod.new( statement: statement, @@ -3727,7 +3817,7 @@ def on_words_beg(value) # :call-seq: # on_words_new: () -> Words def on_words_new - beginning = find_token(WordsBeg) + beginning = consume_token(WordsBeg) Words.new( beginning: beginning, @@ -3761,7 +3851,7 @@ def on_xstring_new if heredoc && heredoc.beginning.value.include?("`") heredoc.location else - find_token(Backtick).location + consume_token(Backtick).location end XString.new(parts: [], location: location) @@ -3781,7 +3871,7 @@ def on_xstring_literal(xstring) location: heredoc.location ) else - ending = find_token(TStringEnd, location: xstring.location) + ending = consume_tstring_end(xstring.location) XStringLiteral.new( parts: xstring.parts, @@ -3793,7 +3883,7 @@ def on_xstring_literal(xstring) # :call-seq: # on_yield: ((Args | Paren) arguments) -> Yield def on_yield(arguments) - keyword = find_token(Kw, "yield") + keyword = consume_keyword(:yield) Yield.new( arguments: arguments, @@ -3804,7 +3894,7 @@ def on_yield(arguments) # :call-seq: # on_yield0: () -> Yield0 def on_yield0 - keyword = find_token(Kw, "yield") + keyword = consume_keyword(:yield) Yield0.new(value: keyword.value, location: keyword.location) end @@ -3812,7 +3902,7 @@ def on_yield0 # :call-seq: # on_zsuper: () -> ZSuper def on_zsuper - keyword = find_token(Kw, "super") + keyword = consume_keyword(:super) ZSuper.new(value: keyword.value, location: keyword.location) end diff --git a/lib/syntax_tree/version.rb b/lib/syntax_tree/version.rb index ec6dcd3e..8456abd4 100644 --- a/lib/syntax_tree/version.rb +++ b/lib/syntax_tree/version.rb @@ -1,5 +1,5 @@ # frozen_string_literal: true module SyntaxTree - VERSION = "3.6.3" + VERSION = "4.0.0" end diff --git a/lib/syntax_tree/visitor/environment.rb b/lib/syntax_tree/visitor/environment.rb new file mode 100644 index 00000000..dfcf0a80 --- /dev/null +++ b/lib/syntax_tree/visitor/environment.rb @@ -0,0 +1,81 @@ +# frozen_string_literal: true + +module SyntaxTree + # The environment class is used to keep track of local variables and arguments + # inside a particular scope + class Environment + # [Array[Local]] The local variables and arguments defined in this + # environment + attr_reader :locals + + # This class tracks the occurrences of a local variable or argument + class Local + # [Symbol] The type of the local (e.g. :argument, :variable) + attr_reader :type + + # [Array[Location]] The locations of all definitions and assignments of + # this local + attr_reader :definitions + + # [Array[Location]] The locations of all usages of this local + attr_reader :usages + + # initialize: (Symbol type) -> void + def initialize(type) + @type = type + @definitions = [] + @usages = [] + end + + # add_definition: (Location location) -> void + def add_definition(location) + @definitions << location + end + + # add_usage: (Location location) -> void + def add_usage(location) + @usages << location + end + end + + # initialize: (Environment | nil parent) -> void + def initialize(parent = nil) + @locals = {} + @parent = parent + end + + # Adding a local definition will either insert a new entry in the locals + # hash or append a new definition location to an existing local. Notice that + # it's not possible to change the type of a local after it has been + # registered + # add_local_definition: (Ident | Label identifier, Symbol type) -> void + def add_local_definition(identifier, type) + name = identifier.value.delete_suffix(":") + + @locals[name] ||= Local.new(type) + @locals[name].add_definition(identifier.location) + end + + # Adding a local usage will either insert a new entry in the locals + # hash or append a new usage location to an existing local. Notice that + # it's not possible to change the type of a local after it has been + # registered + # add_local_usage: (Ident | Label identifier, Symbol type) -> void + def add_local_usage(identifier, type) + name = identifier.value.delete_suffix(":") + + @locals[name] ||= Local.new(type) + @locals[name].add_usage(identifier.location) + end + + # Try to find the local given its name in this environment or any of its + # parents + # find_local: (String name) -> Local | nil + def find_local(name) + local = @locals[name] + return local unless local.nil? + + @parent&.find_local(name) + end + end +end diff --git a/lib/syntax_tree/visitor/with_environment.rb b/lib/syntax_tree/visitor/with_environment.rb new file mode 100644 index 00000000..62e59c98 --- /dev/null +++ b/lib/syntax_tree/visitor/with_environment.rb @@ -0,0 +1,141 @@ +# frozen_string_literal: true + +module SyntaxTree + # WithEnvironment is a module intended to be included in classes inheriting + # from Visitor. The module overrides a few visit methods to automatically keep + # track of local variables and arguments defined in the current environment. + # Example usage: + # class MyVisitor < Visitor + # include WithEnvironment + # + # def visit_ident(node) + # # Check if we're visiting an identifier for an argument, a local + # variable or something else + # local = current_environment.find_local(node) + # + # if local.type == :argument + # # handle identifiers for arguments + # elsif local.type == :variable + # # handle identifiers for variables + # else + # # handle other identifiers, such as method names + # end + # end + module WithEnvironment + def current_environment + @current_environment ||= Environment.new + end + + def with_new_environment + previous_environment = @current_environment + @current_environment = Environment.new(previous_environment) + yield + ensure + @current_environment = previous_environment + end + + # Visits for nodes that create new environments, such as classes, modules + # and method definitions + def visit_class(node) + with_new_environment { super } + end + + def visit_module(node) + with_new_environment { super } + end + + def visit_method_add_block(node) + with_new_environment { super } + end + + def visit_def(node) + with_new_environment { super } + end + + def visit_defs(node) + with_new_environment { super } + end + + def visit_def_endless(node) + with_new_environment { super } + end + + # Visit for keeping track of local arguments, such as method and block + # arguments + def visit_params(node) + node.requireds.each do |param| + @current_environment.add_local_definition(param, :argument) + end + + node.posts.each do |param| + @current_environment.add_local_definition(param, :argument) + end + + node.keywords.each do |param| + @current_environment.add_local_definition(param.first, :argument) + end + + node.optionals.each do |param| + @current_environment.add_local_definition(param.first, :argument) + end + + super + end + + def visit_rest_param(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + def visit_kwrest_param(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + def visit_blockarg(node) + name = node.name + @current_environment.add_local_definition(name, :argument) if name + + super + end + + # Visit for keeping track of local variable definitions + def visit_var_field(node) + value = node.value + + if value.is_a?(SyntaxTree::Ident) + @current_environment.add_local_definition(value, :variable) + end + + super + end + + alias visit_pinned_var_ref visit_var_field + + # Visits for keeping track of variable and argument usages + def visit_aref_field(node) + name = node.collection.value + @current_environment.add_local_usage(name, :variable) if name + + super + end + + def visit_var_ref(node) + value = node.value + + if value.is_a?(SyntaxTree::Ident) + definition = @current_environment.find_local(value.value) + + if definition + @current_environment.add_local_usage(value, definition.type) + end + end + + super + end + end +end diff --git a/syntax_tree.gemspec b/syntax_tree.gemspec index 2b461dfd..ec7d57ef 100644 --- a/syntax_tree.gemspec +++ b/syntax_tree.gemspec @@ -25,7 +25,7 @@ Gem::Specification.new do |spec| spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) } spec.require_paths = %w[lib] - spec.add_dependency "prettier_print" + spec.add_dependency "prettier_print", ">= 1.0.0" spec.add_development_dependency "bundler" spec.add_development_dependency "minitest" diff --git a/test/cli_test.rb b/test/cli_test.rb index 3734e734..03293333 100644 --- a/test/cli_test.rb +++ b/test/cli_test.rb @@ -139,7 +139,7 @@ def test_inline_script def test_multiple_inline_scripts stdio, = capture_io { SyntaxTree::CLI.run(%w[format -e 1+1 -e 2+2]) } - assert_equal("1 + 1\n2 + 2\n", stdio) + assert_equal(["1 + 1", "2 + 2"], stdio.split("\n").sort) end def test_generic_error diff --git a/test/interface_test.rb b/test/interface_test.rb index 49a74e92..5086680e 100644 --- a/test/interface_test.rb +++ b/test/interface_test.rb @@ -54,8 +54,12 @@ def instantiate(klass) case klass.name when "SyntaxTree::Binary" klass.new(**params, operator: :+) + when "SyntaxTree::Kw" + klass.new(**params, value: "kw") when "SyntaxTree::Label" klass.new(**params, value: "label:") + when "SyntaxTree::Op" + klass.new(**params, value: "+") when "SyntaxTree::RegexpLiteral" klass.new(**params, ending: "/") when "SyntaxTree::Statements" diff --git a/test/node_test.rb b/test/node_test.rb index 07c2fe26..1a5af125 100644 --- a/test/node_test.rb +++ b/test/node_test.rb @@ -951,7 +951,7 @@ def test_var_field guard_version("3.1.0") do def test_pinned_var_ref source = "foo in ^bar" - at = location(chars: 7..11) + at = location(chars: 8..11) assert_node(PinnedVarRef, source, at: at, &:pattern) end diff --git a/test/quotes_test.rb b/test/quotes_test.rb new file mode 100644 index 00000000..2e2e0243 --- /dev/null +++ b/test/quotes_test.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class QuotesTest < Minitest::Test + def test_normalize + content = "'aaa' \"bbb\" \\'ccc\\' \\\"ddd\\\"" + enclosing = "\"" + + result = Quotes.normalize(content, enclosing) + assert_equal "'aaa' \\\"bbb\\\" \\'ccc\\' \\\"ddd\\\"", result + end + end +end diff --git a/test/visitor_with_environment_test.rb b/test/visitor_with_environment_test.rb new file mode 100644 index 00000000..915b2143 --- /dev/null +++ b/test/visitor_with_environment_test.rb @@ -0,0 +1,410 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +module SyntaxTree + class VisitorWithEnvironmentTest < Minitest::Test + class Collector < Visitor + include WithEnvironment + + attr_reader :variables, :arguments + + def initialize + @variables = {} + @arguments = {} + end + + def visit_ident(node) + local = current_environment.find_local(node.value) + return unless local + + value = node.value.delete_suffix(":") + + case local.type + when :argument + @arguments[value] = local + when :variable + @variables[value] = local + end + end + + def visit_label(node) + value = node.value.delete_suffix(":") + local = current_environment.find_local(value) + return unless local + + @arguments[value] = node if local.type == :argument + end + end + + def test_collecting_simple_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = 1 + a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_aref_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = [] + a[1] + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_multi_assign_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a, b = [1, 2] + puts a + puts b + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(2, visitor.variables.length) + + variable_a = visitor.variables["a"] + assert_equal(1, variable_a.definitions.length) + assert_equal(1, variable_a.usages.length) + + assert_equal(2, variable_a.definitions[0].start_line) + assert_equal(3, variable_a.usages[0].start_line) + + variable_b = visitor.variables["b"] + assert_equal(1, variable_b.definitions.length) + assert_equal(1, variable_b.usages.length) + + assert_equal(2, variable_b.definitions[0].start_line) + assert_equal(4, variable_b.usages[0].start_line) + end + + def test_collecting_pattern_matching_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + case [1, 2] + in Integer => a, Integer + puts a + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + # There are two occurrences, one on line 3 for pinning and one on line 4 + # for reference + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + + # Assignment a + assert_equal(3, variable.definitions[0].start_line) + assert_equal(4, variable.usages[0].start_line) + end + + def test_collecting_pinned_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + a = 18 + case [1, 2] + in ^a, *rest + puts a + puts rest + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(2, visitor.variables.length) + + variable_a = visitor.variables["a"] + assert_equal(2, variable_a.definitions.length) + assert_equal(1, variable_a.usages.length) + + assert_equal(2, variable_a.definitions[0].start_line) + assert_equal(4, variable_a.definitions[1].start_line) + assert_equal(5, variable_a.usages[0].start_line) + + variable_rest = visitor.variables["rest"] + assert_equal(1, variable_rest.definitions.length) + assert_equal(4, variable_rest.definitions[0].start_line) + + # Rest is considered a vcall by the parser instead of a var_ref + # assert_equal(1, variable_rest.usages.length) + # assert_equal(6, variable_rest.usages[0].start_line) + end + + if RUBY_VERSION >= "3.1" + def test_collecting_one_line_pattern_matching_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo + [1] => a + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.variables.length) + + variable = visitor.variables["a"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + + assert_equal(2, variable.definitions[0].start_line) + assert_equal(3, variable.usages[0].start_line) + end + + def test_collecting_endless_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) = puts a + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(1, argument.usages[0].start_line) + end + end + + def test_collecting_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_singleton_method_arguments + tree = SyntaxTree.parse(<<~RUBY) + def self.foo(a) + puts a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["a"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_method_arguments_all_types + tree = SyntaxTree.parse(<<~RUBY) + def foo(a, b = 1, *c, d, e: 1, **f, &block) + puts a + puts b + puts c + puts d + puts e + puts f + block.call + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(7, visitor.arguments.length) + + argument_a = visitor.arguments["a"] + assert_equal(1, argument_a.definitions.length) + assert_equal(1, argument_a.usages.length) + assert_equal(1, argument_a.definitions[0].start_line) + assert_equal(2, argument_a.usages[0].start_line) + + argument_b = visitor.arguments["b"] + assert_equal(1, argument_b.definitions.length) + assert_equal(1, argument_b.usages.length) + assert_equal(1, argument_b.definitions[0].start_line) + assert_equal(3, argument_b.usages[0].start_line) + + argument_c = visitor.arguments["c"] + assert_equal(1, argument_c.definitions.length) + assert_equal(1, argument_c.usages.length) + assert_equal(1, argument_c.definitions[0].start_line) + assert_equal(4, argument_c.usages[0].start_line) + + argument_d = visitor.arguments["d"] + assert_equal(1, argument_d.definitions.length) + assert_equal(1, argument_d.usages.length) + assert_equal(1, argument_d.definitions[0].start_line) + assert_equal(5, argument_d.usages[0].start_line) + + argument_e = visitor.arguments["e"] + assert_equal(1, argument_e.definitions.length) + assert_equal(1, argument_e.usages.length) + assert_equal(1, argument_e.definitions[0].start_line) + assert_equal(6, argument_e.usages[0].start_line) + + argument_f = visitor.arguments["f"] + assert_equal(1, argument_f.definitions.length) + assert_equal(1, argument_f.usages.length) + assert_equal(1, argument_f.definitions[0].start_line) + assert_equal(7, argument_f.usages[0].start_line) + + argument_block = visitor.arguments["block"] + assert_equal(1, argument_block.definitions.length) + assert_equal(1, argument_block.usages.length) + assert_equal(1, argument_block.definitions[0].start_line) + assert_equal(8, argument_block.usages[0].start_line) + end + + def test_collecting_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + [].each do |i| + puts i + end + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(2, argument.definitions[0].start_line) + assert_equal(3, argument.usages[0].start_line) + end + + def test_collecting_one_line_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + [].each { |i| puts i } + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(2, argument.definitions[0].start_line) + assert_equal(2, argument.usages[0].start_line) + end + + def test_collecting_shadowed_block_arguments + tree = SyntaxTree.parse(<<~RUBY) + def foo + i = "something" + + [].each do |i| + puts i + end + + i + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + assert_equal(1, visitor.arguments.length) + assert_equal(1, visitor.variables.length) + + argument = visitor.arguments["i"] + assert_equal(1, argument.definitions.length) + assert_equal(1, argument.usages.length) + assert_equal(4, argument.definitions[0].start_line) + assert_equal(5, argument.usages[0].start_line) + + variable = visitor.variables["i"] + assert_equal(1, variable.definitions.length) + assert_equal(1, variable.usages.length) + assert_equal(2, variable.definitions[0].start_line) + assert_equal(8, variable.usages[0].start_line) + end + + def test_collecting_shadowed_local_variables + tree = SyntaxTree.parse(<<~RUBY) + def foo(a) + puts a + a = 123 + a + end + RUBY + + visitor = Collector.new + visitor.visit(tree) + + # All occurrences are considered arguments, despite overriding the + # argument value + assert_equal(1, visitor.arguments.length) + assert_equal(0, visitor.variables.length) + + argument = visitor.arguments["a"] + assert_equal(2, argument.definitions.length) + assert_equal(2, argument.usages.length) + + assert_equal(1, argument.definitions[0].start_line) + assert_equal(3, argument.definitions[1].start_line) + assert_equal(2, argument.usages[0].start_line) + assert_equal(4, argument.usages[1].start_line) + end + end +end