diff --git a/test/ruby/test_thread.rb b/test/ruby/test_thread.rb index c8081b4..d56c8a0 100644 --- a/test/ruby/test_thread.rb +++ b/test/ruby/test_thread.rb @@ -27,6 +27,32 @@ class TestThread < Test::Unit::TestCase end end + def test_main_thread_local_in_enumerator + assert_equal Thread.main, Thread.current + + Thread.current.set_local :foo, "bar" + + thread, value = Fiber.new { + Fiber.yield [Thread.current, Thread.current.get_local(:foo)] + }.resume + + assert_equal Thread.current, thread + assert_equal Thread.current.get_local(:foo), value + end + + def test_thread_local_in_enumerator + Thread.new { + Thread.current.set_local :foo, "bar" + + thread, value = Fiber.new { + Fiber.yield [Thread.current, Thread.current.get_local(:foo)] + }.resume + + assert_equal Thread.current, thread + assert_equal Thread.current.get_local(:foo), value + }.join + end + def test_mutex_synchronize m = Mutex.new r = 0 diff --git a/thread.c b/thread.c index 8320ed2..85d74a9 100644 --- a/thread.c +++ b/thread.c @@ -2562,6 +2562,24 @@ rb_thread_aset(VALUE self, VALUE id, VALUE val) return rb_thread_local_aset(self, rb_to_id(id), val); } +static VALUE +rb_thread_get_local(VALUE thread, VALUE id) +{ + VALUE locals; + + locals = rb_iv_get(thread, "locals"); + return rb_hash_aref(locals, id); +} + +static VALUE +rb_thread_set_local(VALUE thread, VALUE id, VALUE val) +{ + VALUE locals; + + locals = rb_iv_get(thread, "locals"); + return rb_hash_aset(locals, id, val); +} + /* * call-seq: * thr.key?(sym) -> true or false @@ -4528,6 +4546,8 @@ Init_Thread(void) rb_define_method(rb_cThread, "wakeup", rb_thread_wakeup, 0); rb_define_method(rb_cThread, "[]", rb_thread_aref, 1); rb_define_method(rb_cThread, "[]=", rb_thread_aset, 2); + rb_define_method(rb_cThread, "get_local", rb_thread_get_local, 1); + rb_define_method(rb_cThread, "set_local", rb_thread_set_local, 2); rb_define_method(rb_cThread, "key?", rb_thread_key_p, 1); rb_define_method(rb_cThread, "keys", rb_thread_keys, 0); rb_define_method(rb_cThread, "priority", rb_thread_priority, 0); diff --git a/vm.c b/vm.c index 348d134..d799213 100644 --- a/vm.c +++ b/vm.c @@ -1832,6 +1832,7 @@ ruby_thread_init(VALUE self) GetThreadPtr(self, th); th_init(th, self); + rb_iv_set(self, "locals", rb_hash_new()); th->vm = vm; th->top_wrapper = 0; @@ -2165,6 +2166,7 @@ Init_VM(void) /* create main thread */ th_self = th->self = TypedData_Wrap_Struct(rb_cThread, &thread_data_type, th); + rb_iv_set(th_self, "locals", rb_hash_new()); vm->main_thread = th; vm->running_thread = th; th->vm = vm;