From a1df4a7b3aff940164ec88e5cbc27c031bafdcb8 Mon Sep 17 00:00:00 2001
From: Dylan Thacker-Smith <Dylan.Smith@shopify.com>
Date: Thu, 22 Aug 2019 14:08:35 -0400
Subject: [PATCH] Optimize Array#flatten and flatten! for already flattened
 arrays

---
 array.c | 43 +++++++++++++++++++++++++++++++++----------
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/array.c b/array.c
index bc5c719267..d1947e97fa 100644
--- a/array.c
+++ b/array.c
@@ -5112,18 +5112,40 @@ rb_ary_count(int argc, VALUE *argv, VALUE ary)
 }
 
 static VALUE
-flatten(VALUE ary, int level, int *modified)
+flatten(VALUE ary, int level)
 {
-    long i = 0;
+    long i;
     VALUE stack, result, tmp, elt;
     st_table *memo;
     st_data_t id;
 
-    stack = ary_new(0, ARY_DEFAULT_SIZE);
+    for (i = 0; i < RARRAY_LEN(ary); i++) {
+        elt = RARRAY_AREF(ary, i);
+        tmp = rb_check_array_type(elt);
+        if (!NIL_P(tmp)) {
+            break;
+        }
+    }
+    if (i == RARRAY_LEN(ary)) {
+        return ary;
+    } else if (tmp == ary) {
+        rb_raise(rb_eArgError, "tried to flatten recursive array");
+    }
+
     result = ary_new(0, RARRAY_LEN(ary));
+    ary_memcpy(result, 0, i, RARRAY_CONST_PTR_TRANSIENT(ary));
+    ARY_SET_LEN(result, i);
+
+    stack = ary_new(0, ARY_DEFAULT_SIZE);
+    rb_ary_push(stack, ary);
+    rb_ary_push(stack, LONG2NUM(i + 1));
+
     memo = st_init_numtable();
     st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue);
-    *modified = 0;
+    st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue);
+
+    ary = tmp;
+    i = 0;
 
     while (1) {
 	while (i < RARRAY_LEN(ary)) {
@@ -5140,7 +5162,6 @@ flatten(VALUE ary, int level, int *modified)
 		rb_ary_push(result, elt);
 	    }
 	    else {
-		*modified = 1;
 		id = (st_data_t)tmp;
 		if (st_lookup(memo, id, 0)) {
 		    st_free_table(memo);
@@ -5200,9 +5221,8 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary)
     if (!NIL_P(lv)) level = NUM2INT(lv);
     if (level == 0) return Qnil;
 
-    result = flatten(ary, level, &mod);
-    if (mod == 0) {
-	ary_discard(result);
+    result = flatten(ary, level);
+    if (result == ary) {
 	return Qnil;
     }
     if (!(mod = ARY_EMBED_P(result))) rb_obj_freeze(result);
@@ -5237,7 +5257,7 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary)
 static VALUE
 rb_ary_flatten(int argc, VALUE *argv, VALUE ary)
 {
-    int mod = 0, level = -1;
+    int level = -1;
     VALUE result;
 
     if (rb_check_arity(argc, 0, 1) && !NIL_P(argv[0])) {
@@ -5245,7 +5265,10 @@ rb_ary_flatten(int argc, VALUE *argv, VALUE ary)
         if (level == 0) return ary_make_shared_copy(ary);
     }
 
-    result = flatten(ary, level, &mod);
+    result = flatten(ary, level);
+    if (result == ary) {
+        result = ary_make_shared_copy(ary);
+    }
     OBJ_INFECT(result, ary);
 
     return result;
-- 
2.21.0

