Index: array.c =================================================================== --- array.c (revision 26907) +++ array.c (working copy) @@ -1408,7 +1408,7 @@ rb_ary_each(VALUE ary) { long i; - RETURN_ENUMERATOR(ary, 0, 0); + RETURN_ENUMERATOR_WITH_LEN(ary, 0, 0, ary); for (i=0; i array @@ -3921,9 +3934,9 @@ rb_ary_permutation(int argc, VALUE *argv, VALUE ary) long r, n, i; n = RARRAY_LEN(ary); /* Array length */ - RETURN_ENUMERATOR(ary, argc, argv); /* Return enumerator if no block */ rb_scan_args(argc, argv, "01", &num); r = NIL_P(num) ? n : NUM2LONG(num); /* Permutation size from argument */ + RETURN_ENUMERATOR_WITH_LEN(ary, argc, argv, INT2NUM(permu_len(n, r))); if (r < 0 || n < r) { /* no permutations: yield nothing */ @@ -4004,8 +4017,8 @@ rb_ary_combination(VALUE ary, VALUE num) long n, i, len; n = NUM2LONG(num); - RETURN_ENUMERATOR(ary, 1, &num); len = RARRAY_LEN(ary); + RETURN_ENUMERATOR_WITH_LEN(ary, 1, &num, INT2NUM(combi_len(len, n))); if (n < 0 || len < n) { /* yield nothing */ } Index: enumerator.c =================================================================== --- enumerator.c (revision 26907) +++ enumerator.c (working copy) @@ -302,6 +302,26 @@ rb_enumeratorize(VALUE obj, VALUE meth, int argc, VALUE *argv) return enumerator_init(enumerator_allocate(rb_cEnumerator), obj, meth, argc, argv); } +VALUE +rb_enumerator_set_length(VALUE self, VALUE length) +{ + rb_ivar_set(self, rb_intern("length"), length); + return self; +} + +static VALUE +enumerator_length(VALUE self) +{ + VALUE len = rb_attr_get(self, rb_intern("length")); + switch (TYPE(len)) { + case T_ARRAY: + return INT2NUM(RARRAY_LEN(len)); + case T_STRING: + return INT2NUM(RSTRING_LEN(len)); + } + return len; +} + static VALUE enumerator_block_call(VALUE obj, rb_block_call_func *func, VALUE arg) { @@ -1070,6 +1090,7 @@ Init_Enumerator(void) rb_define_method(rb_cEnumerator, "feed", enumerator_feed, 1); rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0); rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0); + rb_define_method(rb_cEnumerator, "length", enumerator_length, 0); rb_eStopIteration = rb_define_class("StopIteration", rb_eIndexError); rb_define_method(rb_eStopIteration, "result", stop_result, 0); Index: include/ruby/intern.h =================================================================== --- include/ruby/intern.h (revision 26907) +++ include/ruby/intern.h (working copy) @@ -185,11 +185,19 @@ VALUE rb_fiber_alive_p(VALUE); /* enum.c */ /* enumerator.c */ VALUE rb_enumeratorize(VALUE, VALUE, int, VALUE *); +VALUE rb_enumerator_set_length(VALUE, VALUE); #define RETURN_ENUMERATOR(obj, argc, argv) do { \ if (!rb_block_given_p()) \ return rb_enumeratorize(obj, ID2SYM(rb_frame_this_func()), \ argc, argv); \ } while (0) +#define RETURN_ENUMERATOR_WITH_LEN(obj, argc, argv, len) do { \ + if (!rb_block_given_p()) { \ + VALUE _e = rb_enumeratorize(obj, ID2SYM(rb_frame_this_func()), \ + argc, argv); \ + return rb_enumerator_set_length(_e, len); \ + } \ +} while (0) /* error.c */ VALUE rb_exc_new(VALUE, const char*, long); VALUE rb_exc_new2(VALUE, const char*); Index: test/ruby/test_array.rb =================================================================== --- test/ruby/test_array.rb (revision 26907) +++ test/ruby/test_array.rb (working copy) @@ -1370,6 +1370,11 @@ class TestArray < Test::Unit::TestCase assert_equal(@cls[], @cls[1,2,3,4].combination(5).to_a) end + def test_combination_length + assert_equal(6, @cls[1,2,3,4].combination(2).length) + assert_equal(4, @cls[1,2,3,4].combination(3).length) + end + def test_product assert_equal(@cls[[1,4],[1,5],[2,4],[2,5],[3,4],[3,5]], @cls[1,2,3].product([4,5])) @@ -1403,6 +1408,15 @@ class TestArray < Test::Unit::TestCase assert_equal(@cls[1, 2, 3, 4].permutation.to_a, b) end + def test_permutation_length + assert_equal(6, @cls[1, 2, 3].permutation.length) + assert_equal(1, @cls[1, 2, 3].permutation(0).length) + assert_equal(3, @cls[1, 2, 3].permutation(1).length) + assert_equal(6, @cls[1, 2, 3].permutation(3).length) + assert_equal(0, @cls[1, 2, 3].permutation(4).length) + assert_equal(24, @cls[1, 2, 3, 4].permutation(3).length) + end + def test_take assert_equal([1,2,3], [1,2,3,4,5,0].take(3)) assert_raise(ArgumentError, '[ruby-dev:34123]') { [1,2].take(-1) }