From a1df4a7b3aff940164ec88e5cbc27c031bafdcb8 Mon Sep 17 00:00:00 2001 From: Dylan Thacker-Smith Date: Thu, 22 Aug 2019 14:08:35 -0400 Subject: [PATCH] Optimize Array#flatten and flatten! for already flattened arrays --- array.c | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/array.c b/array.c index bc5c719267..d1947e97fa 100644 --- a/array.c +++ b/array.c @@ -5112,18 +5112,40 @@ rb_ary_count(int argc, VALUE *argv, VALUE ary) } static VALUE -flatten(VALUE ary, int level, int *modified) +flatten(VALUE ary, int level) { - long i = 0; + long i; VALUE stack, result, tmp, elt; st_table *memo; st_data_t id; - stack = ary_new(0, ARY_DEFAULT_SIZE); + for (i = 0; i < RARRAY_LEN(ary); i++) { + elt = RARRAY_AREF(ary, i); + tmp = rb_check_array_type(elt); + if (!NIL_P(tmp)) { + break; + } + } + if (i == RARRAY_LEN(ary)) { + return ary; + } else if (tmp == ary) { + rb_raise(rb_eArgError, "tried to flatten recursive array"); + } + result = ary_new(0, RARRAY_LEN(ary)); + ary_memcpy(result, 0, i, RARRAY_CONST_PTR_TRANSIENT(ary)); + ARY_SET_LEN(result, i); + + stack = ary_new(0, ARY_DEFAULT_SIZE); + rb_ary_push(stack, ary); + rb_ary_push(stack, LONG2NUM(i + 1)); + memo = st_init_numtable(); st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue); - *modified = 0; + st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue); + + ary = tmp; + i = 0; while (1) { while (i < RARRAY_LEN(ary)) { @@ -5140,7 +5162,6 @@ flatten(VALUE ary, int level, int *modified) rb_ary_push(result, elt); } else { - *modified = 1; id = (st_data_t)tmp; if (st_lookup(memo, id, 0)) { st_free_table(memo); @@ -5200,9 +5221,8 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary) if (!NIL_P(lv)) level = NUM2INT(lv); if (level == 0) return Qnil; - result = flatten(ary, level, &mod); - if (mod == 0) { - ary_discard(result); + result = flatten(ary, level); + if (result == ary) { return Qnil; } if (!(mod = ARY_EMBED_P(result))) rb_obj_freeze(result); @@ -5237,7 +5257,7 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary) static VALUE rb_ary_flatten(int argc, VALUE *argv, VALUE ary) { - int mod = 0, level = -1; + int level = -1; VALUE result; if (rb_check_arity(argc, 0, 1) && !NIL_P(argv[0])) { @@ -5245,7 +5265,10 @@ rb_ary_flatten(int argc, VALUE *argv, VALUE ary) if (level == 0) return ary_make_shared_copy(ary); } - result = flatten(ary, level, &mod); + result = flatten(ary, level); + if (result == ary) { + result = ary_make_shared_copy(ary); + } OBJ_INFECT(result, ary); return result; -- 2.21.0