Feature #15144 » 0001-Implement-Enumerator-Chain-and-Enumerator-chain-Feat.patch
enumerator.c | ||
---|---|---|
************************************************/
|
||
#include "ruby/ruby.h"
|
||
#include "internal.h"
|
||
#include "id.h"
|
||
... | ... | |
static VALUE generator_allocate(VALUE klass);
static VALUE generator_init(VALUE obj, VALUE proc);

/* Class object for Enumerator::Chain, created with rb_define_class_under. */
static VALUE rb_cEnumChain;

/*
 * Instance data of Enumerator::Chain.
 *
 *   enums - Array of the chained enumerables (frozen by #initialize), or
 *           Qundef while the object is allocated but not yet initialized.
 *   pos   - index of the enumerable that #each most recently started on;
 *           #rewind walks back from here.  -1 when nothing has been started.
 */
struct enum_chain {
    VALUE enums;
    long pos;
};

static VALUE rb_cArithSeq;
|
||
/*
|
||
... | ... | |
return rb_attr_get(self, id_result);
|
||
}
|
||
/*
|
||
* Document-class: Enumerator::Chain
|
||
*
|
||
* Enumerator::Chain is a subclass of Enumerator, which represents a
|
||
* chain of enumerables that works as a single enumerator.
|
||
*/
|
||
static void
|
||
enum_chain_mark(void *p)
|
||
{
|
||
struct enum_chain *ptr = p;
|
||
rb_gc_mark(ptr->enums);
|
||
}
|
||
#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
|
||
static size_t
|
||
enum_chain_memsize(const void *p)
|
||
{
|
||
return sizeof(struct enum_chain);
|
||
}
|
||
static const rb_data_type_t enum_chain_data_type = {
|
||
"chain",
|
||
{
|
||
enum_chain_mark,
|
||
enum_chain_free,
|
||
enum_chain_memsize,
|
||
},
|
||
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
|
||
};
|
||
static struct enum_chain *
|
||
enum_chain_ptr(VALUE obj)
|
||
{
|
||
struct enum_chain *ptr;
|
||
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
|
||
if (!ptr || ptr->enums == Qundef) {
|
||
rb_raise(rb_eArgError, "uninitialized chain");
|
||
}
|
||
return ptr;
|
||
}
|
||
/* :nodoc: */
|
||
static VALUE
|
||
enum_chain_allocate(VALUE klass)
|
||
{
|
||
struct enum_chain *ptr;
|
||
VALUE obj;
|
||
obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr);
|
||
ptr->enums = Qundef;
|
||
ptr->pos = -1;
|
||
return obj;
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* Enumerator::Chain.new(*enums) -> enum
|
||
* Enumerator.chain(*enums) -> enum
|
||
*
|
||
* Generates an Enumerator::Chain object from the given
|
||
* enumerable objects.
|
||
*
|
||
* e = Enumerator.chain(1..3, [4, 5])
|
||
* e.to_a #=> [1, 2, 3, 4, 5]
|
||
*/
|
||
static VALUE
|
||
enum_chain_initialize(VALUE obj, VALUE enums)
|
||
{
|
||
struct enum_chain *ptr;
|
||
rb_check_frozen(obj);
|
||
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
|
||
if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
|
||
ptr->enums = rb_obj_freeze(enums);
|
||
ptr->pos = -1;
|
||
return obj;
|
||
}
|
||
/* :nodoc: */
|
||
static VALUE
|
||
enum_chain_init_copy(VALUE obj, VALUE orig)
|
||
{
|
||
struct enum_chain *ptr0, *ptr1;
|
||
if (!OBJ_INIT_COPY(obj, orig)) return obj;
|
||
ptr0 = enum_chain_ptr(orig);
|
||
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
|
||
if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
|
||
ptr1->enums = ptr0->enums;
|
||
ptr1->pos = ptr0->pos;
|
||
return obj;
|
||
}
|
||
static VALUE
|
||
enum_chain_total_size(VALUE enums)
|
||
{
|
||
VALUE total = INT2FIX(0);
|
||
RARRAY_PTR_USE(enums, ptr, {
|
||
long i;
|
||
for (i = 0; i < RARRAY_LEN(enums); i++) {
|
||
VALUE size = enum_size(ptr[i]);
|
||
if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
|
||
return size;
|
||
}
|
||
if (!RB_INTEGER_TYPE_P(size)) {
|
||
return Qnil;
|
||
}
|
||
total = rb_funcall(total, '+', 1, size);
|
||
}
|
||
});
|
||
return total;
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* obj.size -> integer
|
||
*
|
||
* Returns the total size of the enumerator chain calculated by
|
||
* summing up the size of each enumerable in the chain. If any of the
|
||
* enumerables reports its size as nil or Float::INFINITY, that value
|
||
* is returned as the total size.
|
||
*/
|
||
static VALUE
|
||
enum_chain_size(VALUE obj)
|
||
{
|
||
return enum_chain_total_size(enum_chain_ptr(obj)->enums);
|
||
}
|
||
static VALUE
|
||
enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
|
||
{
|
||
return enum_chain_size(obj);
|
||
}
|
||
static VALUE
|
||
enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
|
||
{
|
||
return rb_funcallv(block, rb_intern("call"), argc, argv);
|
||
}
|
||
static VALUE
|
||
enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
|
||
{
|
||
return Qnil;
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* obj.each(*args) { |...| ... } -> obj
|
||
* obj.each(*args) -> enumerator
|
||
*
|
||
* Iterates over the first enumerable by calling the "each" method on
|
||
* it with the given arguments until it is exhausted, then proceeds to
|
||
* the next enumerable, until all of the enumerables are exhausted.
|
||
*
|
||
* If no block is given, returns an enumerator.
|
||
*/
|
||
static VALUE
|
||
enum_chain_each(int argc, VALUE *argv, VALUE obj)
|
||
{
|
||
VALUE enums, block;
|
||
struct enum_chain *objptr;
|
||
RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
|
||
objptr = enum_chain_ptr(obj);
|
||
enums = objptr->enums;
|
||
block = rb_block_proc();
|
||
RARRAY_PTR_USE(enums, ptr, {
|
||
long i;
|
||
for (i = 0; i < RARRAY_LEN(enums); i++) {
|
||
objptr->pos = i;
|
||
rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block);
|
||
}
|
||
});
|
||
return obj;
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* obj.rewind -> obj
|
||
*
|
||
* Rewinds the enumerator chain by calling the "rewind" method on each
|
||
* enumerable in reverse order. Each call is performed only if the
|
||
* enumerable responds to the method.
|
||
*/
|
||
static VALUE
|
||
enum_chain_rewind(VALUE obj)
|
||
{
|
||
struct enum_chain *objptr = enum_chain_ptr(obj);
|
||
VALUE enums = objptr->enums;
|
||
RARRAY_PTR_USE(enums, ptr, {
|
||
long i;
|
||
for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
|
||
rb_check_funcall(ptr[i], id_rewind, 0, 0);
|
||
}
|
||
});
|
||
return obj;
|
||
}
|
||
static VALUE
|
||
inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
|
||
{
|
||
VALUE klass = rb_obj_class(obj);
|
||
struct enum_chain *ptr;
|
||
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
|
||
if (!ptr || ptr->enums == Qundef) {
|
||
return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
|
||
}
|
||
if (recur) {
|
||
return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
|
||
}
|
||
return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* obj.inspect -> string
|
||
*
|
||
* Returns a printable version of the enumerator chain.
|
||
*/
|
||
static VALUE
|
||
enum_chain_inspect(VALUE obj)
|
||
{
|
||
return rb_exec_recursive(inspect_enum_chain, obj, 0);
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* e.chain(*enums) -> enumerator
|
||
*
|
||
* Returns an Enumerator::Chain object generated from this enumerator
|
||
* and given enumerables.
|
||
*
|
||
* e = (1..3).each.chain([4, 5])
|
||
* e.to_a #=> [1, 2, 3, 4, 5]
|
||
*/
|
||
static VALUE
|
||
enumerator_s_chain(int argc, VALUE enums)
|
||
{
|
||
return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* e.chain(*enums) -> enumerator
|
||
*
|
||
* Returns an Enumerator::Chain object generated from this enumerator
|
||
* and given enumerables.
|
||
*
|
||
* e = (1..3).each.chain([4, 5])
|
||
* e.to_a #=> [1, 2, 3, 4, 5]
|
||
*/
|
||
static VALUE
|
||
enumerator_chain(int argc, VALUE *argv, VALUE obj)
|
||
{
|
||
VALUE enums = rb_ary_new_from_values(1, &obj);
|
||
rb_ary_cat(enums, argv, argc);
|
||
return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
|
||
}
|
||
/*
|
||
* call-seq:
|
||
* e + enum -> enumerator
|
||
*
|
||
* Returns an Enumerator::Chain object generated from this enumerator
|
||
* and a given enumerable.
|
||
*
|
||
* e = (1..3).each + [4, 5]
|
||
* e.to_a #=> [1, 2, 3, 4, 5]
|
||
*/
|
||
static VALUE
|
||
enumerator_plus(VALUE obj, VALUE eobj)
|
||
{
|
||
VALUE enums = rb_ary_new_from_args(2, obj, eobj);
|
||
return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
|
||
}
|
||
/*
|
||
* Document-class: Enumerator::ArithmeticSequence
|
||
*
|
||
... | ... | |
rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
|
||
rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
|
||
rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
|
||
rb_define_method(rb_cEnumerator, "chain", enumerator_chain, -1);
|
||
rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
|
||
/* Lazy */
|
||
rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
|
||
... | ... | |
rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
|
||
rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
|
||
/* Chain */
|
||
rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
|
||
rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
|
||
rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
|
||
rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
|
||
rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
|
||
rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
|
||
rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
|
||
rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
|
||
rb_define_singleton_method(rb_cEnumerator, "chain", enumerator_s_chain, -2);
|
||
/* ArithmeticSequence */
|
||
rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
|
||
rb_undef_alloc_func(rb_cArithSeq);
|
test/ruby/test_enumerator.rb | ||
---|---|---|
assert_equal([0, 1], u.force)
|
||
assert_equal([0, 1], u.force)
|
||
end
|
||
def test_chain_and_plus
  source = (1..5).each

  # Enumerator#chain with no arguments wraps just the receiver.
  only_source = source.chain()
  assert_kind_of(Enumerator::Chain, only_source)
  assert_equal(5, only_source.size)
  seen = []
  only_source.each { |item| seen << item }
  assert_equal([1, 2, 3, 4, 5], seen)

  # Enumerator#+ appends a single enumerable.
  with_array = source + [6, 7, 8]
  assert_kind_of(Enumerator::Chain, with_array)
  assert_equal(8, with_array.size)
  seen = []
  with_array.each { |item| seen << item }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8], seen)

  # An infinite member makes the reported size infinite.
  endless = source.chain([6, 7], 8.step)
  assert_kind_of(Enumerator::Chain, endless)
  assert_equal(Float::INFINITY, endless.size)
  seen = []
  endless.take(10).each { |item| seen << item }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], seen)

  # `a + b + c` should not return `Enumerator.chain(a, b, c)`
  # because it is expected that `(a + b).each` be called.
  spy = with_array.dup
  class << spy
    attr_reader :each_is_called

    def each
      super
      @each_is_called = true
    end
  end

  nested = spy + 9.step
  assert_kind_of(Enumerator::Chain, nested)
  assert_equal(Float::INFINITY, nested.size)
  seen = []
  nested.take(10).each { |item| seen << item }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], seen)
  assert_equal(true, spy.each_is_called)
end
|
||
def test_chained_enums
  a = (1..5).each

  # An empty chain has size 0 and yields nothing.
  e0 = Enumerator::Chain.new()
  assert_kind_of(Enumerator::Chain, e0)
  assert_equal(0, e0.size)
  ary = []
  e0.each { |x| ary << x }
  assert_equal([], ary)

  e1 = Enumerator::Chain.new(a)
  assert_kind_of(Enumerator::Chain, e1)
  assert_equal(5, e1.size)
  ary = []
  e1.each { |x| ary << x }
  assert_equal([1, 2, 3, 4, 5], ary)

  e2 = Enumerator.chain(a, [6, 7, 8])
  assert_kind_of(Enumerator::Chain, e2)
  assert_equal(8, e2.size)
  ary = []
  e2.each { |x| ary << x }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)

  # An infinite member makes the reported size infinite.
  e3 = Enumerator.chain(a, [6, 7], 8.step)
  assert_kind_of(Enumerator::Chain, e3)
  assert_equal(Float::INFINITY, e3.size)
  ary = []
  e3.take(10).each { |x| ary << x }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)

  # A member of unknown size makes the reported size nil.
  e4 = Enumerator.chain(a, Enumerator.new { |y| y << 6 << 7 << 8 })
  assert_kind_of(Enumerator::Chain, e4)
  assert_equal(nil, e4.size)
  ary = []
  e4.each { |x| ary << x }
  assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)

  # Chains nest.
  e5 = Enumerator.chain(e1, e2)
  assert_kind_of(Enumerator::Chain, e5)
  assert_equal(13, e5.size)
  ary = []
  e5.each { |x| ary << x }
  assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)

  # #rewind walks back over the enumerables that #each started on,
  # last one first.
  rewound = []
  e1.define_singleton_method(:rewind) { rewound << object_id }
  e2.define_singleton_method(:rewind) { rewound << object_id }
  e5.rewind
  assert_equal([e2.object_id, e1.object_id], rewound)

  # Once rewound, there is nothing left to rewind until #each runs again.
  rewound.clear
  e5.rewind
  assert_equal([], rewound)

  # A chain that was never iterated rewinds nothing.
  rewound = []
  a = [1]
  e6 = Enumerator.chain(a)
  a.define_singleton_method(:rewind) { rewound << object_id }
  e6.rewind
  assert_equal([], rewound)

  assert_equal(
    '#<Enumerator::Chain: [' +
      '#<Enumerator::Chain: [' +
        '#<Enumerator: 1..5:each>' +
      ']>, ' +
      '#<Enumerator::Chain: [' +
        '#<Enumerator: 1..5:each>, ' +
        '[6, 7, 8]' +
      ']>' +
    ']>',
    e5.inspect
  )
end
|
||
end
|
||