diff --git a/lib/syntax_tree/visitor/compiler.rb b/lib/syntax_tree/visitor/compiler.rb index c56e553d..10c59a77 100644 --- a/lib/syntax_tree/visitor/compiler.rb +++ b/lib/syntax_tree/visitor/compiler.rb @@ -56,6 +56,13 @@ class RubyVisitor < BasicVisitor class CompilationError < StandardError end + # This will attempt to compile the given node. If it's possible, then + # it will return the compiled object. Otherwise it will return nil. + def self.compile(node) + node.accept(new) + rescue CompilationError + end + def visit_array(node) visit_all(node.contents.parts) end @@ -194,22 +201,86 @@ def change_by(value) end end - # This class is meant to mirror RubyVM::InstructionSequence. It contains a - # list of instructions along with the metadata pertaining to them. It also - # functions as a builder for the instruction sequence. - class InstructionSequence - # This is a small data class that captures the level of a local variable - # table (the number of scopes to traverse) and the index of the local - # variable in that table. - class LocalVariable - attr_reader :level, :index + # This represents every local variable associated with an instruction + # sequence. There are two kinds of locals: plain locals that are what you + # expect, and block proxy locals, which represent local variables + # associated with blocks that were passed into the current instruction + # sequence. + class LocalTable + # A local representing a block passed into the current instruction + # sequence. + class BlockProxyLocal + attr_reader :name - def initialize(level, index) - @level = level + def initialize(name) + @name = name + end + end + + # A regular local variable. + class PlainLocal + attr_reader :name + + def initialize(name) + @name = name + end + end + + # The result of looking up a local variable in the current local table. + class Lookup + attr_reader :local, :index, :level + + def initialize(local, index, level) + @local = local @index = index + @level = level end end + attr_reader :locals + + def initialize + @locals = [] + end + + def find(name, level) + index = locals.index { |local| local.name == name } + Lookup.new(locals[index], index, level) if index + end + + def has?(name) + locals.any? { |local| local.name == name } + end + + def names + locals.map(&:name) + end + + def size + locals.length + end + + # Add a BlockProxyLocal to the local table. + def block_proxy(name) + locals << BlockProxyLocal.new(name) unless has?(name) + end + + # Add a PlainLocal to the local table. + def plain(name) + locals << PlainLocal.new(name) unless has?(name) + end + + # This is the offset from the top of the stack where this local variable + # lives. + def offset(index) + size - (index - 3) - 1 + end + end + + # This class is meant to mirror RubyVM::InstructionSequence. It contains a + # list of instructions along with the metadata pertaining to them. It also + # functions as a builder for the instruction sequence. + class InstructionSequence # The type of the instruction sequence. attr_reader :type @@ -230,9 +301,8 @@ def initialize(level, index) # The list of instructions for this instruction sequence. attr_reader :insns - # The array of symbols corresponding to the local variables of this - # instruction sequence. - attr_reader :local_variables + # The table of local variables. + attr_reader :local_table # The hash of names of instance and class variables pointing to the # index of their associated inline storage. @@ -254,7 +324,7 @@ def initialize(type, name, parent_iseq, location) @argument_size = 0 @argument_options = {} - @local_variables = [] + @local_table = LocalTable.new @inline_storages = {} @insns = [] @storage_index = 0 @@ -262,8 +332,8 @@ def initialize(type, name, parent_iseq, location) end def local_variable(name, level = 0) - if (index = local_variables.index(name)) - LocalVariable.new(level, index) + if (lookup = local_table.find(name, level)) + lookup elsif parent_iseq parent_iseq.local_variable(name, level + 1) else @@ -312,7 +382,7 @@ def to_a 1, { arg_size: argument_size, - local_size: local_variables.length, + local_size: local_table.size, stack_max: stack.maximum_size }, name, @@ -320,7 +390,7 @@ def to_a "", 1, type, - local_variables, + local_table.names, argument_options, [], insns.map { |insn| serialize(insn) } @@ -331,21 +401,21 @@ def to_a def serialize(insn) case insn[0] - when :getlocal_WC_0, :getlocal_WC_1, :getlocal, :setlocal_WC_0, - :setlocal_WC_1, :setlocal + when :checkkeyword, :getblockparamproxy, :getlocal_WC_0, + :getlocal_WC_1, :getlocal, :setlocal_WC_0, :setlocal_WC_1, + :setlocal iseq = self case insn[0] when :getlocal_WC_1, :setlocal_WC_1 iseq = iseq.parent_iseq - when :getlocal, :setlocal + when :getblockparamproxy, :getlocal, :setlocal insn[2].times { iseq = iseq.parent_iseq } end # Here we need to map the local variable index to the offset # from the top of the stack where it will be stored. - index = iseq.local_variables.length - (insn[1] - 3) - 1 - [insn[0], index, *insn[2..]] + [insn[0], iseq.local_table.offset(insn[1]), *insn[2..]] when :defineclass [insn[0], insn[1], insn[2].to_a, insn[3]] when :definemethod @@ -411,6 +481,16 @@ def branchunless(index) iseq.push([:branchunless, index]) end + def checkkeyword(index, keyword_index) + stack.change_by(+1) + iseq.push([:checkkeyword, index, keyword_index]) + end + + def concatarray + stack.change_by(-2 + 1) + iseq.push([:concatarray]) + end + def concatstrings(number) stack.change_by(-number + 1) iseq.push([:concatstrings, number]) @@ -451,6 +531,11 @@ def dupn(number) iseq.push([:dupn, number]) end + def getblockparamproxy(index, level) + stack.change_by(+1) + iseq.push([:getblockparamproxy, index, level]) + end + def getclassvariable(name) stack.change_by(+1) @@ -513,6 +598,11 @@ def intern iseq.push([:intern]) end + def invokeblock(method_id, argc, flag) + stack.change_by(-argc + 1) + iseq.push([:invokeblock, call_data(method_id, argc, flag)]) + end + def invokesuper(method_id, argc, flag, block_iseq) stack.change_by(-(argc + 1) + 1) @@ -584,11 +674,59 @@ def opt_getinlinecache(offset, inline_storage) iseq.push([:opt_getinlinecache, offset, inline_storage]) end + def opt_newarray_max(length) + if specialized_instruction + stack.change_by(-length + 1) + iseq.push([:opt_newarray_max, length]) + else + newarray(length) + send(:max, 0, VM_CALL_ARGS_SIMPLE) + end + end + + def opt_newarray_min(length) + if specialized_instruction + stack.change_by(-length + 1) + iseq.push([:opt_newarray_min, length]) + else + newarray(length) + send(:min, 0, VM_CALL_ARGS_SIMPLE) + end + end + def opt_setinlinecache(inline_storage) stack.change_by(-1 + 1) iseq.push([:opt_setinlinecache, inline_storage]) end + def opt_str_freeze(value) + if specialized_instruction + stack.change_by(+1) + iseq.push( + [ + :opt_str_freeze, + value, + call_data(:freeze, 0, VM_CALL_ARGS_SIMPLE) + ] + ) + else + putstring(value) + send(:freeze, 0, VM_CALL_ARGS_SIMPLE) + end + end + + def opt_str_uminus(value) + if specialized_instruction + stack.change_by(+1) + iseq.push( + [:opt_str_uminus, value, call_data(:-@, 0, VM_CALL_ARGS_SIMPLE)] + ) + else + putstring(value) + send(:-@, 0, VM_CALL_ARGS_SIMPLE) + end + end + def pop stack.change_by(-1) iseq.push([:pop]) @@ -894,10 +1032,31 @@ def visit_args(node) end def visit_array(node) - builder.duparray(node.accept(RubyVisitor.new)) - rescue RubyVisitor::CompilationError - visit_all(node.contents.parts) - builder.newarray(node.contents.parts.length) + if (compiled = RubyVisitor.compile(node)) + builder.duparray(compiled) + else + length = 0 + + node.contents.parts.each do |part| + if part.is_a?(ArgStar) + if length > 0 + builder.newarray(length) + length = 0 + end + + visit(part.value) + builder.concatarray + else + visit(part) + length += 1 + end + end + + builder.newarray(length) if length > 0 + if length > 0 && length != node.contents.parts.length + builder.concatarray + end + end end def visit_assign(node) @@ -985,9 +1144,11 @@ def visit_backref(node) end def visit_bare_assoc_hash(node) - builder.duphash(node.accept(RubyVisitor.new)) - rescue RubyVisitor::CompilationError - visit_all(node.assocs) + if (compiled = RubyVisitor.compile(node)) + builder.duphash(compiled) + else + visit_all(node.assocs) + end end def visit_binary(node) @@ -1017,15 +1178,88 @@ def visit_binary(node) end end + def visit_block(node) + with_instruction_sequence( + :block, + "block in #{current_iseq.name}", + current_iseq, + node + ) do + visit(node.block_var) + visit(node.bodystmt) + builder.leave + end + end + + def visit_block_var(node) + params = node.params + + if params.requireds.length == 1 && params.optionals.empty? && + !params.rest && params.posts.empty? && params.keywords.empty? && + !params.keyword_rest && !params.block + current_iseq.argument_options[:ambiguous_param0] = true + end + + visit(node.params) + + node.locals.each do |local| + current_iseq.local_table.plain(local.value.to_sym) + end + end + + def visit_blockarg(node) + current_iseq.argument_options[:block_start] = current_iseq.argument_size + current_iseq.local_table.block_proxy(node.name.value.to_sym) + current_iseq.argument_size += 1 + end + def visit_bodystmt(node) visit(node.statements) end def visit_call(node) + arg_parts = argument_parts(node.arguments) + + # First we're going to check if we're calling a method on an array + # literal without any arguments. In that case there are some + # specializations we might be able to perform. + if arg_parts.empty? && + (node.message.is_a?(Ident) || node.message.is_a?(Op)) + case node.receiver + when ArrayLiteral + parts = node.receiver.contents&.parts || [] + + if parts.none? { |part| part.is_a?(ArgStar) } && + RubyVisitor.compile(node.receiver).nil? + case node.message.value + when "max" + visit(node.receiver.contents) + builder.opt_newarray_max(parts.length) + return + when "min" + visit(node.receiver.contents) + builder.opt_newarray_min(parts.length) + return + end + end + when StringLiteral + if RubyVisitor.compile(node.receiver).nil? + case node.message.value + when "-@" + builder.opt_str_uminus(node.receiver.parts.first.value) + return + when "freeze" + builder.opt_str_freeze(node.receiver.parts.first.value) + return + end + end + end + end + node.receiver ? visit(node.receiver) : builder.putself visit(node.arguments) - arg_parts = argument_parts(node.arguments) + block_iseq = visit(node.block) if node.respond_to?(:block) && node.block if arg_parts.last.is_a?(ArgBlock) flag = node.receiver.nil? ? VM_CALL_FCALL : 0 @@ -1039,7 +1273,12 @@ def visit_call(node) flag |= VM_CALL_KW_SPLAT end - builder.send(node.message.value.to_sym, arg_parts.length - 1, flag) + builder.send( + node.message.value.to_sym, + arg_parts.length - 1, + flag, + block_iseq + ) else flag = 0 arg_parts.each do |arg_part| @@ -1051,38 +1290,77 @@ def visit_call(node) end end - flag |= VM_CALL_ARGS_SIMPLE if flag == 0 + flag |= VM_CALL_ARGS_SIMPLE if block_iseq.nil? && flag == 0 flag |= VM_CALL_FCALL if node.receiver.nil? - builder.send(node.message.value.to_sym, arg_parts.length, flag) + builder.send( + node.message.value.to_sym, + arg_parts.length, + flag, + block_iseq + ) end end + def visit_class(node) + name = node.constant.constant.value.to_sym + class_iseq = + with_instruction_sequence( + :class, + "", + current_iseq, + node + ) do + visit(node.bodystmt) + builder.leave + end + + flags = VM_DEFINECLASS_TYPE_CLASS + + case node.constant + when ConstPathRef + flags |= VM_DEFINECLASS_FLAG_SCOPED + visit(node.constant.parent) + when ConstRef + builder.putspecialobject(VM_SPECIAL_OBJECT_CONST_BASE) + when TopConstRef + flags |= VM_DEFINECLASS_FLAG_SCOPED + builder.putobject(Object) + end + + if node.superclass + flags |= VM_DEFINECLASS_FLAG_HAS_SUPERCLASS + visit(node.superclass) + else + builder.putnil + end + + builder.defineclass(name, class_iseq, flags) + end + def visit_command(node) - call_node = - CallNode.new( + visit_call( + CommandCall.new( receiver: nil, operator: nil, message: node.message, arguments: node.arguments, + block: node.block, location: node.location ) - - call_node.comments.concat(node.comments) - visit_call(call_node) + ) end def visit_command_call(node) - call_node = - CallNode.new( + visit_call( + CommandCall.new( receiver: node.receiver, operator: node.operator, message: node.message, arguments: node.arguments, + block: node.block, location: node.location ) - - call_node.comments.concat(node.comments) - visit_call(call_node) + ) end def visit_const_path_field(node) @@ -1119,10 +1397,7 @@ def visit_defined(node) # that we put it into the local table. if node.value.target.is_a?(VarField) && node.value.target.value.is_a?(Ident) - name = node.value.target.value.value.to_sym - unless current_iseq.local_variables.include?(name) - current_iseq.local_variables << name - end + current_iseq.local_table.plain(node.value.target.value.value.to_sym) end builder.putobject("assignment") @@ -1184,6 +1459,17 @@ def visit_else(node) builder.pop unless last_statement? end + def visit_elsif(node) + visit_if( + IfNode.new( + predicate: node.predicate, + statements: node.statements, + consequent: node.consequent, + location: node.location + ) + ) + end + def visit_field(node) visit(node.parent) end @@ -1196,9 +1482,7 @@ def visit_for(node) visit(node.collection) name = node.index.value.value.to_sym - unless current_iseq.local_variables.include?(name) - current_iseq.local_variables << name - end + current_iseq.local_table.plain(name) block_iseq = with_instruction_sequence( @@ -1212,7 +1496,7 @@ def visit_for(node) current_iseq.argument_options[:ambiguous_param0] = true current_iseq.argument_size += 1 - current_iseq.local_variables << 2 + current_iseq.local_table.plain(2) builder.getlocal(0, 0) @@ -1269,6 +1553,22 @@ def visit_if(node) end end + def visit_if_op(node) + visit_if( + IfNode.new( + predicate: node.predicate, + statements: node.truthy, + consequent: + Else.new( + keyword: Kw.new(value: "else", location: Location.default), + statements: node.falsy, + location: Location.default + ), + location: Location.default + ) + ) + end + def visit_imaginary(node) builder.putobject(node.accept(RubyVisitor.new)) end @@ -1277,10 +1577,29 @@ def visit_int(node) builder.putobject(node.accept(RubyVisitor.new)) end + def visit_kwrest_param(node) + current_iseq.argument_options[:kwrest] = current_iseq.argument_size + current_iseq.argument_size += 1 + current_iseq.local_table.plain(node.name.value.to_sym) + end + def visit_label(node) builder.putobject(node.accept(RubyVisitor.new)) end + def visit_method_add_block(node) + visit_call( + CommandCall.new( + receiver: node.call.receiver, + operator: node.call.operator, + message: node.call.message, + arguments: node.call.arguments, + block: node.block, + location: node.location + ) + ) + end + def visit_module(node) name = node.constant.constant.value.to_sym module_iseq = @@ -1389,17 +1708,17 @@ def visit_params(node) argument_options[:lead_num] = 0 node.requireds.each do |required| - current_iseq.local_variables << required.value.to_sym + current_iseq.local_table.plain(required.value.to_sym) current_iseq.argument_size += 1 argument_options[:lead_num] += 1 end end node.optionals.each do |(optional, value)| - index = current_iseq.local_variables.length + index = current_iseq.local_table.size name = optional.value.to_sym - current_iseq.local_variables << name + current_iseq.local_table.plain(name) current_iseq.argument_size += 1 unless argument_options.key?(:opt) @@ -1418,11 +1737,52 @@ def visit_params(node) argument_options[:post_num] = 0 node.posts.each do |post| - current_iseq.local_variables << post.value.to_sym + current_iseq.local_table.plain(post.value.to_sym) current_iseq.argument_size += 1 argument_options[:post_num] += 1 end end + + if node.keywords.any? + argument_options[:kwbits] = 0 + argument_options[:keyword] = [] + checkkeywords = [] + + node.keywords.each_with_index do |(keyword, value), keyword_index| + name = keyword.value.chomp(":").to_sym + index = current_iseq.local_table.size + + current_iseq.local_table.plain(name) + current_iseq.argument_size += 1 + argument_options[:kwbits] += 1 + + if value.nil? + argument_options[:keyword] << name + else + begin + compiled = value.accept(RubyVisitor.new) + argument_options[:keyword] << [name, compiled] + rescue RubyVisitor::CompilationError + argument_options[:keyword] << [name] + checkkeywords << builder.checkkeyword(-1, keyword_index) + branchif = builder.branchif(-1) + visit(value) + builder.setlocal(index, 0) + branchif[1] = builder.label + end + end + end + + name = node.keyword_rest ? 3 : 2 + current_iseq.argument_size += 1 + current_iseq.local_table.plain(name) + + lookup = current_iseq.local_table.find(name, 0) + checkkeywords.each { |checkkeyword| checkkeyword[1] = lookup.index } + end + + visit(node.keyword_rest) if node.keyword_rest + visit(node.block) if node.block end def visit_paren(node) @@ -1496,7 +1856,7 @@ def visit_regexp_literal(node) end def visit_rest_param(node) - current_iseq.local_variables << node.name.value.to_sym + current_iseq.local_table.plain(node.name.value.to_sym) current_iseq.argument_options[:rest_start] = current_iseq.argument_size current_iseq.argument_size += 1 end @@ -1592,17 +1952,24 @@ def visit_tstring_content(node) end def visit_unary(node) - visit(node.statement) - method_id = case node.operator when "+", "-" - :"#{node.operator}@" + "#{node.operator}@" else - node.operator.to_sym + node.operator end - builder.send(method_id, 0, VM_CALL_ARGS_SIMPLE) + visit_call( + CommandCall.new( + receiver: node.statement, + operator: nil, + message: Ident.new(value: method_id, location: Location.default), + arguments: nil, + block: nil, + location: Location.default + ) + ) end def visit_undef(node) @@ -1622,9 +1989,7 @@ def visit_var_field(node) current_iseq.inline_storage_for(name) when Ident name = node.value.value.to_sym - unless current_iseq.local_variables.include?(name) - current_iseq.local_variables << name - end + current_iseq.local_table.plain(name) current_iseq.local_variable(name) end end @@ -1639,8 +2004,14 @@ def visit_var_ref(node) when GVar builder.getglobal(node.value.value.to_sym) when Ident - local_variable = current_iseq.local_variable(node.value.value.to_sym) - builder.getlocal(local_variable.index, local_variable.level) + lookup = current_iseq.local_variable(node.value.value.to_sym) + + case lookup.local + when LocalTable::BlockProxyLocal + builder.getblockparamproxy(lookup.index, lookup.level) + when LocalTable::PlainLocal + builder.getlocal(lookup.index, lookup.level) + end when IVar name = node.value.value.to_sym builder.getinstancevariable(name) @@ -1723,6 +2094,12 @@ def visit_xstring_literal(node) builder.send(:`, 1, VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE) end + def visit_yield(node) + parts = argument_parts(node.arguments) + visit_all(parts) + builder.invokeblock(nil, parts.length, VM_CALL_ARGS_SIMPLE) + end + def visit_zsuper(_node) builder.putself builder.invokesuper( @@ -1746,6 +2123,8 @@ def argument_parts(node) node.parts when ArgParen node.arguments.parts + when Paren + node.contents.parts end end diff --git a/test/compiler_test.rb b/test/compiler_test.rb index 1c6cde38..fe0bd1f6 100644 --- a/test/compiler_test.rb +++ b/test/compiler_test.rb @@ -92,6 +92,11 @@ class CompilerTest < Minitest::Test "foo.size", "foo.succ", "/foo/ =~ \"foo\" && $1", + "\"foo\".freeze", + "\"foo\".freeze(1)", + "-\"foo\"", + "\"foo\".-@", + "\"foo\".-@(1)", # Various method calls "foo?", "foo.bar", @@ -252,9 +257,11 @@ class CompilerTest < Minitest::Test "foo || bar", "if foo then bar end", "if foo then bar else baz end", + "if foo then bar elsif baz then qux end", "foo if bar", "foo while bar", "for i in [1, 2, 3] do i end", + "foo ? bar : baz", # Constructed values "foo..bar", "foo...bar", @@ -266,6 +273,7 @@ class CompilerTest < Minitest::Test "%W[foo \#{bar} baz]", "%I[foo \#{bar} baz]", "[foo, bar] + [baz, qux]", + "[foo, bar, *baz, qux]", "{ foo: bar, baz: qux }", "{ :foo => bar, :baz => qux }", "{ foo => bar, baz => qux }", @@ -273,16 +281,25 @@ class CompilerTest < Minitest::Test "[$1, $2, $3, $4, $5, $6, $7, $8, $9]", "/foo \#{bar} baz/", "%r{foo \#{bar} baz}", + "[1, 2, 3].max", + "[foo, bar, baz].max", + "[foo, bar, baz].max(1)", + "[1, 2, 3].min", + "[foo, bar, baz].min", + "[foo, bar, baz].min(1)", # Core method calls "alias foo bar", "alias :foo :bar", + "super", + "super(1)", + "super(1, 2, 3)", "undef foo", "undef :foo", "undef foo, bar, baz", "undef :foo, :bar, :baz", - "super", - "super(1)", - "super(1, 2, 3)", + "def foo; yield; end", + "def foo; yield(1); end", + "def foo; yield(1, 2, 3); end", # defined? usage "defined?(foo)", "defined?(\"foo\")", @@ -317,12 +334,49 @@ class CompilerTest < Minitest::Test "def foo(*bar, baz, qux); end", "def foo(bar, *baz, qux); end", "def foo(bar, baz, *qux, quaz); end", + "def foo(bar, baz, &qux); end", + "def foo(bar, *baz, &qux); end", + "def foo(&qux); qux.call; end", + "def foo(bar:); end", + "def foo(bar:, baz:); end", + "def foo(bar: 1); end", + "def foo(bar: 1, baz: 2); end", + "def foo(bar: baz); end", + "def foo(bar: 1, baz: qux); end", + "def foo(bar: qux, baz: 1); end", + "def foo(bar: baz, qux: qaz); end", + "def foo(**rest); end", + "def foo(bar:, **rest); end", + "def foo(bar:, baz:, **rest); end", + "def foo(bar: 1, **rest); end", + "def foo(bar: 1, baz: 2, **rest); end", + "def foo(bar: baz, **rest); end", + "def foo(bar: 1, baz: qux, **rest); end", + "def foo(bar: qux, baz: 1, **rest); end", + "def foo(bar: baz, qux: qaz, **rest); end", # Class/module definitions "module Foo; end", "module ::Foo; end", "module Foo::Bar; end", "module ::Foo::Bar; end", - "module Foo; module Bar; end; end" + "module Foo; module Bar; end; end", + "class Foo; end", + "class ::Foo; end", + "class Foo::Bar; end", + "class ::Foo::Bar; end", + "class Foo; class Bar; end; end", + "class Foo < Baz; end", + "class ::Foo < Baz; end", + "class Foo::Bar < Baz; end", + "class ::Foo::Bar < Baz; end", + "class Foo; class Bar < Baz; end; end", + "class Foo < baz; end", + # Block + "foo do end", + "foo {}", + "foo do |bar| end", + "foo { |bar| }", + "foo { |bar; baz| }" ] # These are the combinations of instructions that we're going to test.