From 8625d01adeffe015d7e29e0d77861747de94de63 Mon Sep 17 00:00:00 2001 From: Akinori MUSHA Date: Fri, 16 Nov 2018 04:57:44 +0900 Subject: [PATCH] Implement Enumerator::Chain and Enumerator#{+,chain} [Feature #15144] Enumerator::Chain is a subclass of Enumerator, which represents a chain of enumerables that works as a single enumerator. ```ruby e = Enumerator.chain(1..3, [4, 5]) e.to_a #=> [1, 2, 3, 4, 5] e = (1..3).each + [4, 5] e.to_a #=> [1, 2, 3, 4, 5] ``` --- enumerator.c | 327 +++++++++++++++++++++++++++++++++++ test/ruby/test_enumerator.rb | 116 ++++++++++++- 2 files changed, 442 insertions(+), 1 deletion(-) diff --git a/enumerator.c b/enumerator.c index 274583a3de..c6be9ee837 100644 --- a/enumerator.c +++ b/enumerator.c @@ -12,6 +12,7 @@ ************************************************/ +#include "ruby/ruby.h" #include "internal.h" #include "id.h" @@ -161,6 +162,13 @@ struct proc_entry { static VALUE generator_allocate(VALUE klass); static VALUE generator_init(VALUE obj, VALUE proc); +static VALUE rb_cEnumChain; + +struct enum_chain { + VALUE enums; + long pos; +}; + static VALUE rb_cArithSeq; /* @@ -2411,6 +2419,312 @@ stop_result(VALUE self) return rb_attr_get(self, id_result); } +/* + * Document-class: Enumerator::Chain + * + * Enumerator::Chain is a subclass of Enumerator, which represents a + * chain of enumerables that works as a single enumerator. + */ + +static void +enum_chain_mark(void *p) +{ + struct enum_chain *ptr = p; + rb_gc_mark(ptr->enums); +} + +#define enum_chain_free RUBY_TYPED_DEFAULT_FREE + +static size_t +enum_chain_memsize(const void *p) +{ + return sizeof(struct enum_chain); +} + +static const rb_data_type_t enum_chain_data_type = { + "chain", + { + enum_chain_mark, + enum_chain_free, + enum_chain_memsize, + }, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY +}; + +static struct enum_chain * +enum_chain_ptr(VALUE obj) +{ + struct enum_chain *ptr; + + TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr); + if (!ptr || ptr->enums == Qundef) { + rb_raise(rb_eArgError, "uninitialized chain"); + } + return ptr; +} + +/* :nodoc: */ +static VALUE +enum_chain_allocate(VALUE klass) +{ + struct enum_chain *ptr; + VALUE obj; + + obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr); + ptr->enums = Qundef; + ptr->pos = -1; + + return obj; +} + +/* + * call-seq: + * Enumerator::Chain.new(*enums) -> enum + * Enumerator.chain(*enums) -> enum + * + * Generates an Enumerator::Chain object from the given + * enumerable objects. + * + * e = Enumerator.chain(1..3, [4, 5]) + * e.to_a #=> [1, 2, 3, 4, 5] + */ +static VALUE +enum_chain_initialize(VALUE obj, VALUE enums) +{ + struct enum_chain *ptr; + + rb_check_frozen(obj); + TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr); + + if (!ptr) rb_raise(rb_eArgError, "unallocated chain"); + + ptr->enums = rb_obj_freeze(enums); + ptr->pos = -1; + + return obj; +} + +/* :nodoc: */ +static VALUE +enum_chain_init_copy(VALUE obj, VALUE orig) +{ + struct enum_chain *ptr0, *ptr1; + + if (!OBJ_INIT_COPY(obj, orig)) return obj; + ptr0 = enum_chain_ptr(orig); + + TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1); + + if (!ptr1) rb_raise(rb_eArgError, "unallocated chain"); + + ptr1->enums = ptr0->enums; + ptr1->pos = ptr0->pos; + + return obj; +} + +static VALUE +enum_chain_total_size(VALUE enums) +{ + VALUE total = INT2FIX(0); + + RARRAY_PTR_USE(enums, ptr, { + long i; + + for (i = 0; i < RARRAY_LEN(enums); i++) { + VALUE size = enum_size(ptr[i]); + + if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) { + return size; + } + if (!RB_INTEGER_TYPE_P(size)) { + return Qnil; + } + + total = rb_funcall(total, '+', 1, size); + } + }); + + return total; +} + +/* + * call-seq: + * obj.size -> integer + * + * Returns the total size of the enumerator chain calculated by + * summing up the size of each enumerable in the chain. If any of the + * enumerables reports its size as nil or Float::INFINITY, that value + * is returned as the total size. + */ +static VALUE +enum_chain_size(VALUE obj) +{ + return enum_chain_total_size(enum_chain_ptr(obj)->enums); +} + +static VALUE +enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj) +{ + return enum_chain_size(obj); +} + +static VALUE +enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv) +{ + return rb_funcallv(block, rb_intern("call"), argc, argv); +} + +static VALUE +enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj) +{ + return Qnil; +} + +/* + * call-seq: + * obj.each(*args) { |...| ... } -> obj + * obj.each(*args) -> enumerator + * + * Iterates over the first enumerable by calling the "each" method on + * it with the given arguments until it is exhausted, then proceeds to + * the next enumerable, until all of the enumerables are exhausted. + * + * If no block is given, returns an enumerator. + */ +static VALUE +enum_chain_each(int argc, VALUE *argv, VALUE obj) +{ + VALUE enums, block; + struct enum_chain *objptr; + + RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size); + + objptr = enum_chain_ptr(obj); + enums = objptr->enums; + block = rb_block_proc(); + + RARRAY_PTR_USE(enums, ptr, { + long i; + + for (i = 0; i < RARRAY_LEN(enums); i++) { + objptr->pos = i; + rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block); + } + }); + + return obj; +} + +/* + * call-seq: + * obj.rewind -> obj + * + * Rewinds the enumerator chain by calling the "rewind" method on each + * enumerable in reverse order. Each call is performed only if the + * enumerable responds to the method. + */ +static VALUE +enum_chain_rewind(VALUE obj) +{ + struct enum_chain *objptr = enum_chain_ptr(obj); + VALUE enums = objptr->enums; + + RARRAY_PTR_USE(enums, ptr, { + long i; + + for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) { + rb_check_funcall(ptr[i], id_rewind, 0, 0); + } + }); + + return obj; +} + +static VALUE +inspect_enum_chain(VALUE obj, VALUE dummy, int recur) +{ + VALUE klass = rb_obj_class(obj); + struct enum_chain *ptr; + + TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr); + + if (!ptr || ptr->enums == Qundef) { + return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass)); + } + + if (recur) { + return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass)); + } + + return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums); +} + +/* + * call-seq: + * obj.inspect -> string + * + * Returns a printable version of the enumerator chain. + */ +static VALUE +enum_chain_inspect(VALUE obj) +{ + return rb_exec_recursive(inspect_enum_chain, obj, 0); +} + +/* + * call-seq: + * e.chain(*enums) -> enumerator + * + * Returns an Enumerator::Chain object generated from this enumerator + * and given enumerables. + * + * e = (1..3).each.chain([4, 5]) + * e.to_a #=> [1, 2, 3, 4, 5] + */ +static VALUE +enumerator_s_chain(int argc, VALUE enums) +{ + return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); +} + +/* + * call-seq: + * e.chain(*enums) -> enumerator + * + * Returns an Enumerator::Chain object generated from this enumerator + * and given enumerables. + * + * e = (1..3).each.chain([4, 5]) + * e.to_a #=> [1, 2, 3, 4, 5] + */ +static VALUE +enumerator_chain(int argc, VALUE *argv, VALUE obj) +{ + VALUE enums = rb_ary_new_from_values(1, &obj); + rb_ary_cat(enums, argv, argc); + + return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); +} + +/* + * call-seq: + * e + enum -> enumerator + * + * Returns an Enumerator::Chain object generated from this enumerator + * and a given enumerable. + * + * e = (1..3).each + [4, 5] + * e.to_a #=> [1, 2, 3, 4, 5] + */ +static VALUE +enumerator_plus(VALUE obj, VALUE eobj) +{ + VALUE enums = rb_ary_new_from_args(2, obj, eobj); + + return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); +} + /* * Document-class: Enumerator::ArithmeticSequence * @@ -2907,6 +3221,8 @@ InitVM_Enumerator(void) rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0); rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0); rb_define_method(rb_cEnumerator, "size", enumerator_size, 0); + rb_define_method(rb_cEnumerator, "chain", enumerator_chain, -1); + rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1); /* Lazy */ rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator); @@ -2960,6 +3276,17 @@ InitVM_Enumerator(void) rb_define_method(rb_cYielder, "yield", yielder_yield, -2); rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1); + /* Chain */ + rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator); + rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate); + rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2); + rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1); + rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1); + rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0); + rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0); + rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0); + rb_define_singleton_method(rb_cEnumerator, "chain", enumerator_s_chain, -2); + /* ArithmeticSequence */ rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator); rb_undef_alloc_func(rb_cArithSeq); diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb index 0839c2c3dd..3223646af7 100644 --- a/test/ruby/test_enumerator.rb +++ b/test/ruby/test_enumerator.rb @@ -670,5 +670,119 @@ def test_uniq assert_equal([0, 1], u.force) assert_equal([0, 1], u.force) end + + def test_chain_and_plus + a = (1..5).each + + e1 = a.chain() + assert_kind_of(Enumerator::Chain, e1) + assert_equal(5, e1.size) + ary = [] + e1.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5], ary) + + e2 = a + [6, 7, 8] + assert_kind_of(Enumerator::Chain, e2) + assert_equal(8, e2.size) + ary = [] + e2.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary) + + e3 = a.chain([6, 7], 8.step) + assert_kind_of(Enumerator::Chain, e3) + assert_equal(Float::INFINITY, e3.size) + ary = [] + e3.take(10).each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary) + + # `a + b + c` should not return `Enumerator.chain(a, b, c)` + # because it is expected that `(a + b).each` be called. + e4 = e2.dup + class << e4 + attr_reader :each_is_called + def each + super + @each_is_called = true + end + end + e5 = e4 + 9.step + assert_kind_of(Enumerator::Chain, e5) + assert_equal(Float::INFINITY, e5.size) + ary = [] + e5.take(10).each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary) + assert_equal(true, e4.each_is_called) + end + + def test_chained_enums + a = (1..5).each + + e0 = Enumerator::Chain.new() + assert_kind_of(Enumerator::Chain, e0) + assert_equal(0, e0.size) + ary = [] + e0.each { |x| ary << x } + assert_equal([], ary) + + e1 = Enumerator::Chain.new(a) + assert_kind_of(Enumerator::Chain, e1) + assert_equal(5, e1.size) + ary = [] + e1.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5], ary) + + e2 = Enumerator.chain(a, [6, 7, 8]) + assert_kind_of(Enumerator::Chain, e2) + assert_equal(8, e2.size) + ary = [] + e2.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary) + + e3 = Enumerator.chain(a, [6, 7], 8.step) + assert_kind_of(Enumerator::Chain, e3) + assert_equal(Float::INFINITY, e3.size) + ary = [] + e3.take(10).each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary) + + e4 = Enumerator.chain(a, Enumerator.new { |y| y << 6 << 7 << 8 }) + assert_kind_of(Enumerator::Chain, e4) + assert_equal(nil, e4.size) + ary = [] + e4.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary) + + e5 = Enumerator.chain(e1, e2) + assert_kind_of(Enumerator::Chain, e5) + assert_equal(13, e5.size) + ary = [] + e5.each { |x| ary << x } + assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary) + + rewound = [] + e1.define_singleton_method(:rewind) { rewound << object_id } + e2.define_singleton_method(:rewind) { rewound << object_id } + e5.rewind + assert_equal(rewound, [e2.object_id, e1.object_id]) + + rewound = [] + a = [1] + e6 = Enumerator.chain(a) + a.define_singleton_method(:rewind) { rewound << object_id } + e6.rewind + assert_equal(rewound, []) + + assert_equal( + '#' + + ']>, ' + + '#, ' + + '[6, 7, 8]' + + ']>' + + ']>', + e5.inspect + ) + end end - -- 2.19.1