diff --git a/lib/syntax_tree/visitor/compiler.rb b/lib/syntax_tree/visitor/compiler.rb index 10c59a77..bac8b914 100644 --- a/lib/syntax_tree/visitor/compiler.rb +++ b/lib/syntax_tree/visitor/compiler.rb @@ -94,6 +94,10 @@ def visit_label(node) node.value.chomp(":").to_sym end + def visit_mrhs(node) + visit_all(node.parts) + end + def visit_qsymbols(node) node.elements.map { |element| visit(element).to_sym } end @@ -209,7 +213,7 @@ def change_by(value) class LocalTable # A local representing a block passed into the current instruction # sequence. - class BlockProxyLocal + class BlockLocal attr_reader :name def initialize(name) @@ -260,9 +264,9 @@ def size locals.length end - # Add a BlockProxyLocal to the local table. - def block_proxy(name) - locals << BlockProxyLocal.new(name) unless has?(name) + # Add a BlockLocal to the local table. + def block(name) + locals << BlockLocal.new(name) unless has?(name) end # Add a PlainLocal to the local table. @@ -336,8 +340,6 @@ def local_variable(name, level = 0) lookup elsif parent_iseq parent_iseq.local_variable(name, level + 1) - else - raise "Unknown local variable: #{name}" end end @@ -388,7 +390,7 @@ def to_a name, "", "", - 1, + location.start_line, type, local_table.names, argument_options, @@ -401,15 +403,15 @@ def to_a def serialize(insn) case insn[0] - when :checkkeyword, :getblockparamproxy, :getlocal_WC_0, - :getlocal_WC_1, :getlocal, :setlocal_WC_0, :setlocal_WC_1, - :setlocal + when :checkkeyword, :getblockparam, :getblockparamproxy, + :getlocal_WC_0, :getlocal_WC_1, :getlocal, :setlocal_WC_0, + :setlocal_WC_1, :setlocal iseq = self case insn[0] when :getlocal_WC_1, :setlocal_WC_1 iseq = iseq.parent_iseq - when :getblockparamproxy, :getlocal, :setlocal + when :getblockparam, :getblockparamproxy, :getlocal, :setlocal insn[2].times { iseq = iseq.parent_iseq } end @@ -418,12 +420,14 @@ def serialize(insn) [insn[0], iseq.local_table.offset(insn[1]), *insn[2..]] when :defineclass [insn[0], insn[1], insn[2].to_a, insn[3]] - when :definemethod + when :definemethod, :definesmethod [insn[0], insn[1], insn[2].to_a] when :send # For any instructions that push instruction sequences onto the # stack, we need to call #to_a on them as well. [insn[0], insn[1], (insn[2].to_a if insn[2])] + when :once + [insn[0], insn[1].to_a, insn[2]] else insn end @@ -476,6 +480,11 @@ def branchif(index) iseq.push([:branchif, index]) end + def branchnil(index) + stack.change_by(-1) + iseq.push([:branchnil, index]) + end + def branchunless(index) stack.change_by(-1) iseq.push([:branchunless, index]) @@ -511,6 +520,11 @@ def definemethod(name, method_iseq) iseq.push([:definemethod, name, method_iseq]) end + def definesmethod(name, method_iseq) + stack.change_by(-1) + iseq.push([:definesmethod, name, method_iseq]) + end + def dup stack.change_by(-1 + 2) iseq.push([:dup]) @@ -531,6 +545,16 @@ def dupn(number) iseq.push([:dupn, number]) end + def expandarray(length, flag) + stack.change_by(-1 + length) + iseq.push([:expandarray, length, flag]) + end + + def getblockparam(index, level) + stack.change_by(+1) + iseq.push([:getblockparam, index, level]) + end + def getblockparamproxy(index, level) stack.change_by(+1) iseq.push([:getblockparamproxy, index, level]) @@ -645,6 +669,11 @@ def objtostring(method_id, argc, flag) iseq.push([:objtostring, call_data(method_id, argc, flag)]) end + def once(postexe_iseq, inline_storage) + stack.change_by(+1) + iseq.push([:once, postexe_iseq, inline_storage]) + end + def opt_getconstant_path(names) if RUBY_VERSION >= "3.2" stack.change_by(+1) @@ -992,6 +1021,10 @@ def initialize( @last_statement = false end + def visit_BEGIN(node) + visit(node.statements) + end + def visit_CHAR(node) if frozen_string_literal builder.putobject(node.value[1..]) @@ -1000,6 +1033,27 @@ def visit_CHAR(node) end end + def visit_END(node) + name = "block in #{current_iseq.name}" + once_iseq = + with_instruction_sequence(:block, name, current_iseq, node) do + postexe_iseq = + with_instruction_sequence(:block, name, current_iseq, node) do + *statements, last_statement = node.statements.body + visit_all(statements) + with_last_statement { visit(last_statement) } + builder.leave + end + + builder.putspecialobject(VM_SPECIAL_OBJECT_VMCORE) + builder.send(:"core#set_postexe", 0, VM_CALL_FCALL, postexe_iseq) + builder.leave + end + + builder.once(once_iseq, current_iseq.inline_storage) + builder.pop + end + def visit_alias(node) builder.putspecialobject(VM_SPECIAL_OBJECT_VMCORE) builder.putspecialobject(VM_SPECIAL_OBJECT_CBASE) @@ -1209,7 +1263,7 @@ def visit_block_var(node) def visit_blockarg(node) current_iseq.argument_options[:block_start] = current_iseq.argument_size - current_iseq.local_table.block_proxy(node.name.value.to_sym) + current_iseq.local_table.block(node.name.value.to_sym) current_iseq.argument_size += 1 end @@ -1218,13 +1272,28 @@ def visit_bodystmt(node) end def visit_call(node) + if node.is_a?(CallNode) + return( + visit_call( + CommandCall.new( + receiver: node.receiver, + operator: node.operator, + message: node.message, + arguments: node.arguments, + block: nil, + location: node.location + ) + ) + ) + end + arg_parts = argument_parts(node.arguments) + argc = arg_parts.length # First we're going to check if we're calling a method on an array # literal without any arguments. In that case there are some # specializations we might be able to perform. - if arg_parts.empty? && - (node.message.is_a?(Ident) || node.message.is_a?(Op)) + if argc == 0 && (node.message.is_a?(Ident) || node.message.is_a?(Op)) case node.receiver when ArrayLiteral parts = node.receiver.contents&.parts || [] @@ -1256,48 +1325,98 @@ def visit_call(node) end end - node.receiver ? visit(node.receiver) : builder.putself + if node.receiver + if node.receiver.is_a?(VarRef) && + ( + lookup = + current_iseq.local_variable(node.receiver.value.value.to_sym) + ) && lookup.local.is_a?(LocalTable::BlockLocal) + builder.getblockparamproxy(lookup.index, lookup.level) + else + visit(node.receiver) + end + else + builder.putself + end - visit(node.arguments) - block_iseq = visit(node.block) if node.respond_to?(:block) && node.block + branchnil = + if node.operator&.value == "&." + builder.dup + builder.branchnil(-1) + end - if arg_parts.last.is_a?(ArgBlock) - flag = node.receiver.nil? ? VM_CALL_FCALL : 0 - flag |= VM_CALL_ARGS_BLOCKARG + flag = 0 - if arg_parts.any? { |part| part.is_a?(ArgStar) } + arg_parts.each do |arg_part| + case arg_part + when ArgBlock + argc -= 1 + flag |= VM_CALL_ARGS_BLOCKARG + visit(arg_part) + when ArgStar flag |= VM_CALL_ARGS_SPLAT - end + visit(arg_part) + when ArgsForward + flag |= VM_CALL_ARGS_SPLAT | VM_CALL_ARGS_BLOCKARG + + lookup = current_iseq.local_table.find(:*, 0) + builder.getlocal(lookup.index, lookup.level) + builder.splatarray(arg_parts.length != 1) - if arg_parts.any? { |part| part.is_a?(BareAssocHash) } + lookup = current_iseq.local_table.find(:&, 0) + builder.getblockparamproxy(lookup.index, lookup.level) + when BareAssocHash flag |= VM_CALL_KW_SPLAT + visit(arg_part) + else + visit(arg_part) end + end - builder.send( - node.message.value.to_sym, - arg_parts.length - 1, - flag, - block_iseq - ) - else - flag = 0 - arg_parts.each do |arg_part| - case arg_part - when ArgStar - flag |= VM_CALL_ARGS_SPLAT - when BareAssocHash - flag |= VM_CALL_KW_SPLAT - end + block_iseq = visit(node.block) if node.block + flag |= VM_CALL_ARGS_SIMPLE if block_iseq.nil? && flag == 0 + flag |= VM_CALL_FCALL if node.receiver.nil? + + builder.send(node.message.value.to_sym, argc, flag, block_iseq) + branchnil[1] = builder.label if branchnil + end + + def visit_case(node) + visit(node.value) if node.value + + clauses = [] + else_clause = nil + + current = node.consequent + + while current + clauses << current + + if (current = current.consequent).is_a?(Else) + else_clause = current + break end + end - flag |= VM_CALL_ARGS_SIMPLE if block_iseq.nil? && flag == 0 - flag |= VM_CALL_FCALL if node.receiver.nil? - builder.send( - node.message.value.to_sym, - arg_parts.length, - flag, - block_iseq - ) + branches = + clauses.map do |clause| + visit(clause.arguments) + builder.topn(1) + builder.send(:===, 1, VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE) + [clause, builder.branchif(:label_00)] + end + + builder.pop + + else_clause ? visit(else_clause) : builder.putnil + + builder.leave + + branches.each_with_index do |(clause, branchif), index| + builder.leave if index != 0 + branchif[1] = builder.label + builder.pop + visit(clause) end end @@ -1386,7 +1505,14 @@ def visit_def(node) end name = node.name.value.to_sym - builder.definemethod(name, method_iseq) + + if node.target + visit(node.target) + builder.definesmethod(name, method_iseq) + else + builder.definemethod(name, method_iseq) + end + builder.putobject(name) end @@ -1524,8 +1650,8 @@ def visit_heredoc(node) elsif node.parts.length == 1 && node.parts.first.is_a?(TStringContent) visit(node.parts.first) else - visit_string_parts(node) - builder.concatstrings(node.parts.length) + length = visit_string_parts(node) + builder.concatstrings(length) end end @@ -1587,6 +1713,33 @@ def visit_label(node) builder.putobject(node.accept(RubyVisitor.new)) end + def visit_lambda(node) + lambda_iseq = + with_instruction_sequence( + :block, + "block in #{current_iseq.name}", + current_iseq, + node + ) do + visit(node.params) + visit(node.statements) + builder.leave + end + + builder.putspecialobject(VM_SPECIAL_OBJECT_VMCORE) + builder.send(:lambda, 0, VM_CALL_FCALL, lambda_iseq) + end + + def visit_lambda_var(node) + visit_block_var(node) + end + + def visit_massign(node) + visit(node.value) + builder.dup + visit(node.target) + end + def visit_method_add_block(node) visit_call( CommandCall.new( @@ -1600,6 +1753,21 @@ def visit_method_add_block(node) ) end + def visit_mlhs(node) + lookups = [] + + node.parts.each do |part| + case part + when VarField + lookups << visit(part) + end + end + + builder.expandarray(lookups.length, 0) + + lookups.each { |lookup| builder.setlocal(lookup.index, lookup.level) } + end + def visit_module(node) name = node.constant.constant.value.to_sym module_iseq = @@ -1630,6 +1798,15 @@ def visit_module(node) builder.defineclass(name, module_iseq, flags) end + def visit_mrhs(node) + if (compiled = RubyVisitor.compile(node)) + builder.duparray(compiled) + else + visit_all(node.parts) + builder.newarray(node.parts.length) + end + end + def visit_not(node) visit(node.statement) builder.send(:!, 0, VM_CALL_ARGS_SIMPLE) @@ -1781,7 +1958,22 @@ def visit_params(node) checkkeywords.each { |checkkeyword| checkkeyword[1] = lookup.index } end - visit(node.keyword_rest) if node.keyword_rest + if node.keyword_rest.is_a?(ArgsForward) + current_iseq.local_table.plain(:*) + current_iseq.local_table.plain(:&) + + current_iseq.argument_options[ + :rest_start + ] = current_iseq.argument_size + current_iseq.argument_options[ + :block_start + ] = current_iseq.argument_size + 1 + + current_iseq.argument_size += 2 + elsif node.keyword_rest + visit(node.keyword_rest) + end + visit(node.block) if node.block end @@ -1798,17 +1990,23 @@ def visit_program(node) end end - statements = - node.statements.body.select do |statement| - case statement - when Comment, EmbDoc, EndContent, VoidStmt - false - else - true - end + preexes = [] + statements = [] + + node.statements.body.each do |statement| + case statement + when Comment, EmbDoc, EndContent, VoidStmt + # ignore + when BEGINBlock + preexes << statement + else + statements << statement end + end with_instruction_sequence(:top, "", nil, node) do + visit_all(preexes) + if statements.empty? builder.putnil else @@ -1849,10 +2047,9 @@ def visit_rational(node) def visit_regexp_literal(node) builder.putobject(node.accept(RubyVisitor.new)) rescue RubyVisitor::CompilationError - visit_string_parts(node) - flags = RubyVisitor.new.visit_regexp_literal_flags(node) - builder.toregexp(flags, node.parts.length) + length = visit_string_parts(node) + builder.toregexp(flags, length) end def visit_rest_param(node) @@ -1861,6 +2058,28 @@ def visit_rest_param(node) current_iseq.argument_size += 1 end + def visit_sclass(node) + visit(node.target) + builder.putnil + + singleton_iseq = + with_instruction_sequence( + :class, + "singleton class", + current_iseq, + node + ) do + visit(node.bodystmt) + builder.leave + end + + builder.defineclass( + :singletonclass, + singleton_iseq, + VM_DEFINECLASS_TYPE_SINGLETON_CLASS + ) + end + def visit_statements(node) statements = node.body.select do |statement| @@ -1896,8 +2115,8 @@ def visit_string_literal(node) if node.parts.length == 1 && node.parts.first.is_a?(TStringContent) visit(node.parts.first) else - visit_string_parts(node) - builder.concatstrings(node.parts.length) + length = visit_string_parts(node) + builder.concatstrings(length) end end @@ -1924,13 +2143,7 @@ def visit_symbols(node) element.parts.first.is_a?(TStringContent) builder.putobject(element.parts.first.value.to_sym) else - length = element.parts.length - unless element.parts.first.is_a?(TStringContent) - builder.putobject("") - length += 1 - end - - visit_string_parts(element) + length = visit_string_parts(element) builder.concatstrings(length) builder.intern end @@ -1982,6 +2195,48 @@ def visit_undef(node) end end + def visit_unless(node) + visit(node.predicate) + branchunless = builder.branchunless(-1) + node.consequent ? visit(node.consequent) : builder.putnil + + if last_statement? + builder.leave + branchunless[1] = builder.label + + visit(node.statements) + else + builder.pop + + if node.consequent + jump = builder.jump(-1) + branchunless[1] = builder.label + visit(node.consequent) + jump[1] = builder.label + else + branchunless[1] = builder.label + end + end + end + + def visit_until(node) + jumps = [] + + jumps << builder.jump(-1) + builder.putnil + builder.pop + jumps << builder.jump(-1) + + label = builder.label + visit(node.statements) + builder.pop + jumps.each { |jump| jump[1] = builder.label } + + visit(node.predicate) + builder.branchunless(label) + builder.putnil if last_statement? + end + def visit_var_field(node) case node.value when CVar, IVar @@ -1989,8 +2244,13 @@ def visit_var_field(node) current_iseq.inline_storage_for(name) when Ident name = node.value.value.to_sym - current_iseq.local_table.plain(name) - current_iseq.local_variable(name) + + if (local_variable = current_iseq.local_variable(name)) + local_variable + else + current_iseq.local_table.plain(name) + current_iseq.local_variable(name) + end end end @@ -2007,8 +2267,8 @@ def visit_var_ref(node) lookup = current_iseq.local_variable(node.value.value.to_sym) case lookup.local - when LocalTable::BlockProxyLocal - builder.getblockparamproxy(lookup.index, lookup.level) + when LocalTable::BlockLocal + builder.getblockparam(lookup.index, lookup.level) when LocalTable::PlainLocal builder.getlocal(lookup.index, lookup.level) end @@ -2036,6 +2296,10 @@ def visit_vcall(node) builder.send(node.value.value.to_sym, 0, flag) end + def visit_when(node) + visit(node.statements) + end + def visit_while(node) jumps = [] @@ -2058,13 +2322,7 @@ def visit_word(node) if node.parts.length == 1 && node.parts.first.is_a?(TStringContent) visit(node.parts.first) else - length = node.parts.length - unless node.parts.first.is_a?(TStringContent) - builder.putobject("") - length += 1 - end - - visit_string_parts(node) + length = visit_string_parts(node) builder.concatstrings(length) end end @@ -2089,8 +2347,8 @@ def visit_words(node) def visit_xstring_literal(node) builder.putself - visit_string_parts(node) - builder.concatstrings(node.parts.length) if node.parts.length > 1 + length = visit_string_parts(node) + builder.concatstrings(node.parts.length) if length > 1 builder.send(:`, 1, VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE) end @@ -2122,7 +2380,11 @@ def argument_parts(node) when Args node.parts when ArgParen - node.arguments.parts + if node.arguments.is_a?(ArgsForward) + [node.arguments] + else + node.arguments.parts + end when Paren node.contents.parts end @@ -2248,6 +2510,13 @@ def push_interpolate # heredocs, etc. This method will visit all the parts of a string within # those containers. def visit_string_parts(node) + length = 0 + + unless node.parts.first.is_a?(TStringContent) + builder.putobject("") + length += 1 + end + node.parts.each do |part| case part when StringDVar @@ -2259,7 +2528,11 @@ def visit_string_parts(node) when TStringContent builder.putobject(part.accept(RubyVisitor.new)) end + + length += 1 end + + length end # The current instruction sequence that we're compiling is always stored @@ -2297,12 +2570,13 @@ def with_instruction_sequence(type, name, parent_iseq, node) # last statement of a scope and allow visit methods to query that # information. def with_last_statement + previous = @last_statement @last_statement = true begin yield ensure - @last_statement = false + @last_statement = previous end end diff --git a/test/compiler_test.rb b/test/compiler_test.rb index fe0bd1f6..632b3e55 100644 --- a/test/compiler_test.rb +++ b/test/compiler_test.rb @@ -2,9 +2,17 @@ return if !defined?(RubyVM::InstructionSequence) || RUBY_VERSION < "3.1" require_relative "test_helper" +require "fiddle" module SyntaxTree class CompilerTest < Minitest::Test + ISEQ_LOAD = + Fiddle::Function.new( + Fiddle::Handle::DEFAULT["rb_iseq_load"], + [Fiddle::TYPE_VOIDP] * 3, + Fiddle::TYPE_VOIDP + ) + CASES = [ # Various literals placed on the stack "true", @@ -130,6 +138,12 @@ class CompilerTest < Minitest::Test "foo ||= 1", "foo <<= 1", "foo ^= 1", + "foo, bar = 1, 2", + "foo, bar, = 1, 2", + "foo, bar, baz = 1, 2", + "foo, bar = 1, 2, 3", + "foo = 1, 2, 3", + "foo, * = 1, 2, 3", # Instance variables "@foo", "@foo = 1", @@ -253,15 +267,27 @@ class CompilerTest < Minitest::Test "Foo::Bar.baz = 1", "::Foo::Bar.baz = 1", # Control flow + "foo&.bar", + "foo&.bar(1)", + "foo&.bar 1, 2, 3", + "foo&.bar {}", "foo && bar", "foo || bar", "if foo then bar end", "if foo then bar else baz end", "if foo then bar elsif baz then qux end", "foo if bar", + "unless foo then bar end", + "unless foo then bar else baz end", + "foo unless bar", "foo while bar", + "while foo do bar end", + "foo until bar", + "until foo do bar end", "for i in [1, 2, 3] do i end", "foo ? bar : baz", + "case foo when bar then 1 end", + "case foo when bar then 1 else 2 end", # Constructed values "foo..bar", "foo...bar", @@ -336,6 +362,7 @@ class CompilerTest < Minitest::Test "def foo(bar, baz, *qux, quaz); end", "def foo(bar, baz, &qux); end", "def foo(bar, *baz, &qux); end", + "def foo(&qux); qux; end", "def foo(&qux); qux.call; end", "def foo(bar:); end", "def foo(bar:, baz:); end", @@ -354,6 +381,12 @@ class CompilerTest < Minitest::Test "def foo(bar: 1, baz: qux, **rest); end", "def foo(bar: qux, baz: 1, **rest); end", "def foo(bar: baz, qux: qaz, **rest); end", + "def foo(...); end", + "def foo(bar, ...); end", + "def foo(...); bar(...); end", + "def foo(bar, ...); baz(1, 2, 3, ...); end", + "def self.foo; end", + "def foo.bar(baz); end", # Class/module definitions "module Foo; end", "module ::Foo; end", @@ -371,12 +404,19 @@ class CompilerTest < Minitest::Test "class ::Foo::Bar < Baz; end", "class Foo; class Bar < Baz; end; end", "class Foo < baz; end", + "class << Object; end", + "class << ::String; end", # Block "foo do end", "foo {}", "foo do |bar| end", "foo { |bar| }", - "foo { |bar; baz| }" + "foo { |bar; baz| }", + "-> do end", + "-> {}", + "-> (bar) do end", + "-> (bar) {}", + "-> (bar; baz) { }" ] # These are the combinations of instructions that we're going to test. @@ -398,6 +438,11 @@ class CompilerTest < Minitest::Test end end + def test_evaluation + assert_evaluates 5, "2 + 3" + assert_evaluates 5, "a = 2; b = 3; a + b" + end + private def serialize_iseq(iseq) @@ -431,5 +476,17 @@ def assert_compiles(source, **options) serialize_iseq(program.accept(Visitor::Compiler.new(**options))) ) end + + def assert_evaluates(expected, source, **options) + program = SyntaxTree.parse(source) + compiled = program.accept(Visitor::Compiler.new(**options)).to_a + + # Temporary hack until we get these working. + compiled[4][:node_id] = 11 + compiled[4][:node_ids] = [1, 0, 3, 2, 6, 7, 9, -1] + + iseq = Fiddle.dlunwrap(ISEQ_LOAD.call(Fiddle.dlwrap(compiled), 0, nil)) + assert_equal expected, iseq.eval + end end end