diff --git a/array.c b/array.c index c2a4e85..485f617 100644 --- a/array.c +++ b/array.c @@ -3775,6 +3775,35 @@ rb_ary_shuffle(VALUE ary) } +static VALUE +ary_sample_with_replace(VALUE const ary, long const n) +{ + VALUE result; + VALUE* ptr_result; + long i; + + VALUE const* const ptr = RARRAY_PTR(ary); + long const len = RARRAY_LEN(ary); + + switch (n) { + case 0: + return rb_ary_new2(0); + case 1: + return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]); + default: + break; + } + result = rb_ary_new2(n); + ptr_result = RARRAY_PTR(result); + RB_GC_GUARD(ary); + for (i = 0; i < n; ++i) { + long const j = (long)(rb_genrand_real()*len); + ptr_result[i] = ptr[j]; + } + ARY_SET_LEN(result, n); + return result; +} + /* * call-seq: * ary.sample -> obj @@ -3792,7 +3821,7 @@ rb_ary_shuffle(VALUE ary) static VALUE rb_ary_sample(int argc, VALUE *argv, VALUE ary) { - VALUE nv, result, *ptr; + VALUE nv, replace, result, *ptr; long n, len, i, j, k, idx[10]; len = RARRAY_LEN(ary); @@ -3801,9 +3830,12 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) i = len == 1 ? 0 : (long)(rb_genrand_real()*len); return RARRAY_PTR(ary)[i]; } - rb_scan_args(argc, argv, "1", &nv); + rb_scan_args(argc, argv, "12", &nv, &replace); n = NUM2LONG(nv); if (n < 0) rb_raise(rb_eArgError, "negative sample number"); + if (RTEST(replace)) { + return ary_sample_with_replace(ary, n); + } ptr = RARRAY_PTR(ary); len = RARRAY_LEN(ary); if (n > len) n = len; diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index e8edcc2..837ce7b 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -1894,7 +1894,7 @@ class TestArray < Test::Unit::TestCase (0..20).each do |n| 100.times do b = a.sample(n) - assert_equal([n, 18].min, b.uniq.size) + assert_equal([n, 18].min, b.size) assert_equal(a, (a | b).sort) assert_equal(b.sort, (a & b).sort) end @@ -1909,6 +1909,62 @@ class TestArray < Test::Unit::TestCase assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1)} end + def test_sample_without_replace + 100.times do + samples = [2, 1, 0].sample(2, false) + samples.each{|sample| + assert([0, 1, 2].include?(sample)) + } + end + + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + b = a.sample(n, false) + assert_equal([n, 18].min, b.size) + assert_equal(a, (a | b).sort) + assert_equal(b.sort, (a & b).sort) + end + + h = Hash.new(0) + 1000.times do + a.sample(n, false).each {|x| h[x] += 1 } + end + assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 + end + + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, false)} + end + + def test_sample_with_replace + 100.times do + samples = [2, 1, 0].sample(2, true) + samples.each{|sample| + assert([0, 1, 2].include?(sample)) + } + end + + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + b = a.sample(n, true) + assert_equal(n, b.size) + assert_equal(a, (a | b).sort) + assert_equal(b.sort.uniq, (a & b).sort) + end + + h = Hash.new(0) + 1000.times do + a.sample(n, true).each {|x| h[x] += 1 } + end + assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 + end + + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, true)} + end + def test_cycle a = [] [0, 1, 2].cycle do |i|