diff --git a/thread.c b/thread.c index 55ba49d..a8a70ee 100644 --- a/thread.c +++ b/thread.c @@ -420,6 +420,8 @@ ruby_thread_init_stack(rb_thread_t *th) native_thread_init_stack(th); } +static void thread_run_at_exit(rb_thread_t *th); + static int thread_start_func_2(rb_thread_t *th, VALUE *stack_start, VALUE *register_stack_start) { @@ -488,6 +490,8 @@ thread_start_func_2(rb_thread_t *th, VALUE *stack_start, VALUE *register_stack_s th->value = Qnil; } + thread_run_at_exit(th); + th->status = THREAD_KILLED; thread_debug("thread end: %p\n", (void *)th); @@ -2035,6 +2039,79 @@ rb_thread_inspect(VALUE thread) return str; } +static void +thread_run_at_exit(rb_thread_t *th) +{ + int i; + VALUE args, at_exit; + + if(th->at_exit){ + at_exit = th->at_exit; + args = rb_ary_new3(1, GET_THREAD()->self); + for (i=RARRAY_LEN(at_exit)-1; i>=0; i--) { + rb_proc_call(RARRAY_PTR(at_exit)[i], args); + } + } + return; +} + +void +rb_call_at_exit_proc(VALUE data) +{ + VALUE args; + args = rb_ary_new3(1, GET_THREAD()->self); + rb_proc_call(data, args); +} + +struct thread_define_at_exit_arg { + rb_thread_t *th; + VALUE proc; +}; + +static VALUE +thread_define_at_exit(VALUE arg) +{ + rb_thread_t *th = ((struct thread_define_at_exit_arg *)arg)->th; + VALUE proc = ((struct thread_define_at_exit_arg *)arg)->proc; + + if (th == th->vm->main_thread) { + rb_set_end_proc(rb_call_at_exit_proc, proc); + } else { + if (th->at_exit) { + rb_ary_push(th->at_exit, proc); + } else { + th->at_exit = rb_ary_new3(1, proc); + } + } + return th->self; +} + +VALUE +rb_thread_define_at_exit(VALUE thread) +{ + rb_thread_t *th; + struct thread_define_at_exit_arg arg; + + if (!rb_block_given_p()) { + rb_raise(rb_eArgError, "called without a block"); + } + + GetThreadPtr(thread, th); + + if (rb_threadptr_dead(th)) { + return Qnil; + } + + arg.th = th; + arg.proc = rb_block_proc(); + + if (!th->at_exit_lock) { + th->at_exit_lock = rb_mutex_new(); + } + + return rb_mutex_synchronize(th->at_exit_lock, thread_define_at_exit, (VALUE)&arg); +} + VALUE rb_thread_local_aref(VALUE thread, ID id) { @@ -4609,6 +4686,7 @@ Init_Thread(void) rb_define_method(rb_cThread, "safe_level", rb_thread_safe_level, 0); rb_define_method(rb_cThread, "group", rb_thread_group, 0); rb_define_method(rb_cThread, "backtrace", rb_thread_backtrace_m, 0); + rb_define_method(rb_cThread, "at_exit", rb_thread_define_at_exit, 0); rb_define_method(rb_cThread, "inspect", rb_thread_inspect, 0); diff --git a/vm.c b/vm.c index 8621709..dd9bd24 100644 --- a/vm.c +++ b/vm.c @@ -1748,6 +1748,8 @@ rb_thread_mark(void *ptr) RUBY_MARK_UNLESS_NULL(th->root_fiber); RUBY_MARK_UNLESS_NULL(th->stat_insn_usage); RUBY_MARK_UNLESS_NULL(th->last_status); + RUBY_MARK_UNLESS_NULL(th->at_exit); + RUBY_MARK_UNLESS_NULL(th->at_exit_lock); RUBY_MARK_UNLESS_NULL(th->locking_mutex); diff --git a/vm_core.h b/vm_core.h index 9b72839..93452b6 100644 --- a/vm_core.h +++ b/vm_core.h @@ -455,6 +455,9 @@ typedef struct rb_thread_struct { /* storage */ st_table *local_storage; + VALUE at_exit; + VALUE at_exit_lock; + struct rb_thread_struct *join_list_next; struct rb_thread_struct *join_list_head;