Feature #3647 » array_sample_with_replace_hash.patch
array.c | ||
---|---|---|
24 | 24 | |
25 | 25 |
static ID id_cmp; |
26 | 26 | |
27 |
static VALUE sym_replace; |
|
28 | ||
27 | 29 |
#define ARY_DEFAULT_SIZE 16 |
28 | 30 |
#define ARY_MAX_SIZE (LONG_MAX / (int)sizeof(VALUE)) |
29 | 31 | |
... | ... | |
3787 | 3789 | |
3788 | 3790 |
switch (n) { |
3789 | 3791 |
case 0: |
3790 |
return rb_ary_new2(0);
|
|
3792 |
return rb_ary_new2(0);
|
|
3791 | 3793 |
case 1: |
3792 |
return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
|
|
3794 |
return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
|
|
3793 | 3795 |
default: |
3794 |
break;
|
|
3796 |
break;
|
|
3795 | 3797 |
} |
3796 | 3798 |
result = rb_ary_new2(n); |
3797 | 3799 |
ptr_result = RARRAY_PTR(result); |
3798 | 3800 |
RB_GC_GUARD(ary); |
3799 | 3801 |
for (i = 0; i < n; ++i) { |
3800 |
long const j = (long)(rb_genrand_real()*len);
|
|
3801 |
ptr_result[i] = ptr[j];
|
|
3802 |
long const j = (long)(rb_genrand_real()*len);
|
|
3803 |
ptr_result[i] = ptr[j];
|
|
3802 | 3804 |
} |
3803 | 3805 |
ARY_SET_LEN(result, n); |
3804 | 3806 |
return result; |
... | ... | |
3821 | 3823 |
static VALUE |
3822 | 3824 |
rb_ary_sample(int argc, VALUE *argv, VALUE ary) |
3823 | 3825 |
{ |
3824 |
VALUE nv, replace, result, *ptr;
|
|
3826 |
VALUE nv, opts, replace=Qfalse, result, *ptr;
|
|
3825 | 3827 |
long n, len, i, j, k, idx[10]; |
3826 | 3828 | |
3827 | 3829 |
len = RARRAY_LEN(ary); |
... | ... | |
3830 | 3832 |
i = len == 1 ? 0 : (long)(rb_genrand_real()*len); |
3831 | 3833 |
return RARRAY_PTR(ary)[i]; |
3832 | 3834 |
} |
3833 |
rb_scan_args(argc, argv, "12", &nv, &replace);
|
|
3835 |
rb_scan_args(argc, argv, "12", &nv, &opts);
|
|
3834 | 3836 |
n = NUM2LONG(nv); |
3835 | 3837 |
if (n < 0) rb_raise(rb_eArgError, "negative sample number"); |
3836 |
if (RTEST(replace)) {
|
|
3837 |
return ary_sample_with_replace(ary, n);
|
|
3838 |
if (!NIL_P(opts) && TYPE(opts) == T_HASH) {
|
|
3839 |
replace = rb_hash_aref(opts, sym_replace);
|
|
3838 | 3840 |
} |
3841 |
if (RTEST(replace)) |
|
3842 |
return ary_sample_with_replace(ary, n); |
|
3839 | 3843 |
ptr = RARRAY_PTR(ary); |
3840 | 3844 |
len = RARRAY_LEN(ary); |
3841 | 3845 |
if (n > len) n = len; |
... | ... | |
4641 | 4645 |
rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0); |
4642 | 4646 | |
4643 | 4647 |
id_cmp = rb_intern("<=>"); |
4648 |
sym_replace = ID2SYM(rb_intern("replace")); |
|
4644 | 4649 |
} |
test/ruby/test_array.rb | ||
---|---|---|
1911 | 1911 | |
1912 | 1912 |
def test_sample_without_replace |
1913 | 1913 |
100.times do |
1914 |
samples = [2, 1, 0].sample(2, false) |
|
1914 |
samples = [2, 1, 0].sample(2, replace: false)
|
|
1915 | 1915 |
samples.each{|sample| |
1916 | 1916 |
assert([0, 1, 2].include?(sample)) |
1917 | 1917 |
} |
... | ... | |
1921 | 1921 |
a = (1..18).to_a |
1922 | 1922 |
(0..20).each do |n| |
1923 | 1923 |
100.times do |
1924 |
b = a.sample(n, false) |
|
1924 |
b = a.sample(n, replace: false)
|
|
1925 | 1925 |
assert_equal([n, 18].min, b.size) |
1926 | 1926 |
assert_equal(a, (a | b).sort) |
1927 | 1927 |
assert_equal(b.sort, (a & b).sort) |
... | ... | |
1929 | 1929 | |
1930 | 1930 |
h = Hash.new(0) |
1931 | 1931 |
1000.times do |
1932 |
a.sample(n, false).each {|x| h[x] += 1 } |
|
1932 |
a.sample(n, replace: false).each {|x| h[x] += 1 }
|
|
1933 | 1933 |
end |
1934 | 1934 |
assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 |
1935 | 1935 |
end |
1936 | 1936 | |
1937 |
assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, false)} |
|
1937 |
assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: false)}
|
|
1938 | 1938 |
end |
1939 | 1939 | |
1940 | 1940 |
def test_sample_with_replace |
1941 | 1941 |
100.times do |
1942 |
samples = [2, 1, 0].sample(2, true) |
|
1942 |
samples = [2, 1, 0].sample(2, replace: true)
|
|
1943 | 1943 |
samples.each{|sample| |
1944 | 1944 |
assert([0, 1, 2].include?(sample)) |
1945 | 1945 |
} |
... | ... | |
1949 | 1949 |
a = (1..18).to_a |
1950 | 1950 |
(0..20).each do |n| |
1951 | 1951 |
100.times do |
1952 |
b = a.sample(n, true) |
|
1952 |
b = a.sample(n, replace: true)
|
|
1953 | 1953 |
assert_equal(n, b.size) |
1954 | 1954 |
assert_equal(a, (a | b).sort) |
1955 | 1955 |
assert_equal(b.sort.uniq, (a & b).sort) |
... | ... | |
1957 | 1957 | |
1958 | 1958 |
h = Hash.new(0) |
1959 | 1959 |
1000.times do |
1960 |
a.sample(n, true).each {|x| h[x] += 1 } |
|
1960 |
a.sample(n, replace: true).each {|x| h[x] += 1 }
|
|
1961 | 1961 |
end |
1962 | 1962 |
assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 |
1963 | 1963 |
end |
1964 | 1964 | |
1965 |
assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, true)} |
|
1965 |
assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: true)}
|
|
1966 | 1966 |
end |
1967 | 1967 | |
1968 | 1968 |
def test_cycle |