diff --git a/NEWS b/NEWS index 0e4508e..b20cebb 100644 --- a/NEWS +++ b/NEWS @@ -81,6 +81,16 @@ with all sufficient information, see the ChangeLog file. * added Struct#to_h returning values with keys corresponding to the instance variable names. + * Thread + * added method: + * added Thread#thread_variable_get for getting thread local variables + (these are different than Fiber local variables). + * added Thread#thread_variable_set for setting thread local variables. + * added Thread#thread_variables for getting a list of the thread local + variable keys. + * added Thread#thread_variable? for testing to see if a particular thread + variable has been set. + * Time * change return value: * Time#to_s returned encoding defaults to US-ASCII but automatically diff --git a/test/ruby/test_thread.rb b/test/ruby/test_thread.rb index c8081b4..145e021 100644 --- a/test/ruby/test_thread.rb +++ b/test/ruby/test_thread.rb @@ -27,6 +27,79 @@ class TestThread < Test::Unit::TestCase end end + def test_main_thread_variable_in_enumerator + assert_equal Thread.main, Thread.current + + Thread.current.thread_variable_set :foo, "bar" + + thread, value = Fiber.new { + Fiber.yield [Thread.current, Thread.current.thread_variable_get(:foo)] + }.resume + + assert_equal Thread.current, thread + assert_equal Thread.current.thread_variable_get(:foo), value + end + + def test_thread_variable_in_enumerator + Thread.new { + Thread.current.thread_variable_set :foo, "bar" + + thread, value = Fiber.new { + Fiber.yield [Thread.current, Thread.current.thread_variable_get(:foo)] + }.resume + + assert_equal Thread.current, thread + assert_equal Thread.current.thread_variable_get(:foo), value + }.join + end + + def test_thread_variables + assert_equal [], Thread.new { Thread.current.thread_variables }.join.value + + t = Thread.new { + Thread.current.thread_variable_set(:foo, "bar") + Thread.current.thread_variables + } + assert_equal [:foo], t.join.value + end + + def test_thread_variable? + refute Thread.new { Thread.current.thread_variable?("foo") }.join.value + t = Thread.new { + Thread.current.thread_variable_set("foo", "bar") + }.join + + assert t.thread_variable?("foo") + assert t.thread_variable?(:foo) + refute t.thread_variable?(:bar) + end + + def test_thread_variable_strings_and_symbols_are_the_same_key + t = Thread.new {}.join + t.thread_variable_set("foo", "bar") + assert_equal "bar", t.thread_variable_get(:foo) + end + + def test_thread_variable_frozen + t = Thread.new { }.join + t.freeze + assert_raises(RuntimeError) do + t.thread_variable_set(:foo, "bar") + end + end + + def test_thread_variable_security + t = Thread.new { sleep } + + assert_raises(SecurityError) do + Thread.new { $SAFE = 4; t.thread_variable_get(:foo) }.join + end + + assert_raises(SecurityError) do + Thread.new { $SAFE = 4; t.thread_variable_set(:foo, :baz) }.join + end + end + def test_mutex_synchronize m = Mutex.new r = 0 diff --git a/thread.c b/thread.c index b512566..b21a088 100644 --- a/thread.c +++ b/thread.c @@ -2525,6 +2525,9 @@ rb_thread_local_aref(VALUE thread, ID id) * #=> nil if fiber-local * #=> 2 if thread-local (The value 2 is leaked to outside of meth method.) * + * For thread-local variables, please see Thread#thread_local_get + * and Thread#thread_local_set. + * */ static VALUE @@ -2561,7 +2564,9 @@ rb_thread_local_aset(VALUE thread, ID id, VALUE val) * thr[sym] = obj -> obj * * Attribute Assignment---Sets or creates the value of a fiber-local variable, - * using either a symbol or a string. See also Thread#[]. + * using either a symbol or a string. See also Thread#[]. For + * thread-local variables, please see Thread#thread_variable_set + * and Thread#thread_variable_get. */ static VALUE @@ -2572,6 +2577,80 @@ rb_thread_aset(VALUE self, VALUE id, VALUE val) /* * call-seq: + * thr.thread_variable_get(key) -> obj or nil + * + * Returns the value of a thread local variable that has been set. Note that + * these are different than fiber local values. For fiber local values, + * please see Thread#[] and Thread#[]=. + * + * Thread local values are carried along with threads, and do not respect + * fibers. For example: + * + * Thread.new { + * Thread.current.thread_variable_set("foo", "bar") # set a thread local + * Thread.current["foo"] = "bar" # set a fiber local + * + * Fiber.new { + * Fiber.yield [ + * Thread.current.thread_variable_get("foo"), # get the thread local + * Thread.current["foo"], # get the fiber local + * ] + * }.resume + * }.join.value # => ['bar', nil] + * + * The value "bar" is returned for the thread local, where nil is returned + * for the fiber local. The fiber is executed in the same thread, so the + * thread local values are available. + * + * See also Thread#[] + */ + +static VALUE +rb_thread_variable_get(VALUE thread, VALUE id) +{ + VALUE locals; + rb_thread_t *th; + + GetThreadPtr(thread, th); + + if (rb_safe_level() >= 4 && th != GET_THREAD()) { + rb_raise(rb_eSecurityError, "Insecure: can't modify thread locals"); + } + + locals = rb_iv_get(thread, "locals"); + return rb_hash_aref(locals, ID2SYM(rb_to_id(id))); +} + +/* + * call-seq: + * thr.thread_variable_set(key, value) + * + * Sets a thread local with +key+ to +value+. Note that these are local to + * threads, and not to fibers. Please see Thread#thread_variable_get and + * Thread#[] for more information. + */ + +static VALUE +rb_thread_variable_set(VALUE thread, VALUE id, VALUE val) +{ + VALUE locals; + rb_thread_t *th; + + GetThreadPtr(thread, th); + + if (rb_safe_level() >= 4 && th != GET_THREAD()) { + rb_raise(rb_eSecurityError, "Insecure: can't modify thread locals"); + } + if (OBJ_FROZEN(thread)) { + rb_error_frozen("thread locals"); + } + + locals = rb_iv_get(thread, "locals"); + return rb_hash_aset(locals, ID2SYM(rb_to_id(id)), val); +} + +/* + * call-seq: * thr.key?(sym) -> true or false * * Returns true if the given string (or symbol) exists as a @@ -2651,6 +2730,76 @@ rb_thread_keys(VALUE self) return ary; } +static int +keys_i(VALUE key, VALUE value, VALUE ary) +{ + rb_ary_push(ary, key); + return ST_CONTINUE; +} + +/* + * call-seq: + * thr.thread_variables -> array + * + * Returns an an array of the names of the thread-local variables (as Symbols). + * + * thr = Thread.new do + * Thread.current.thread_variable_set(:cat, 'meow') + * Thread.current.thread_variable_set("dog", 'woof') + * end + * thr.join #=> # + * thr.thread_variables #=> [:dog, :cat] + * + * Note that these are not fiber local variables. Please see Thread#[] and + * Thread#thread_variable_get for more details. + */ + +static VALUE +rb_thread_variables(VALUE thread) +{ + VALUE locals; + VALUE ary; + + locals = rb_iv_get(thread, "locals"); + ary = rb_ary_new(); + rb_hash_foreach(locals, keys_i, ary); + + return ary; +} + +/* + * call-seq: + * thr.thread_variable?(key) -> true or false + * + * Returns true if the given string (or symbol) exists as a + * thread-local variable. + * + * me = Thread.current + * me.thread_variable_set(:oliver, "a") + * me.thread_variable?(:oliver) #=> true + * me.thread_variable?(:stanley) #=> false + * + * Note that these are not fiber local variables. Please see Thread#[] and + * Thread#thread_variable_get for more details. + */ + +static VALUE +rb_thread_variable_p(VALUE thread, VALUE key) +{ + VALUE locals; + + locals = rb_iv_get(thread, "locals"); + + if (!RHASH(locals)->ntbl) + return Qfalse; + + if (st_lookup(RHASH(locals)->ntbl, ID2SYM(rb_to_id(key)), 0)) { + return Qtrue; + } + + return Qfalse; +} + /* * call-seq: * thr.priority -> integer @@ -4541,6 +4690,10 @@ Init_Thread(void) rb_define_method(rb_cThread, "priority", rb_thread_priority, 0); rb_define_method(rb_cThread, "priority=", rb_thread_priority_set, 1); rb_define_method(rb_cThread, "status", rb_thread_status, 0); + rb_define_method(rb_cThread, "thread_variable_get", rb_thread_variable_get, 1); + rb_define_method(rb_cThread, "thread_variable_set", rb_thread_variable_set, 2); + rb_define_method(rb_cThread, "thread_variables", rb_thread_variables, 0); + rb_define_method(rb_cThread, "thread_variable?", rb_thread_variable_p, 1); rb_define_method(rb_cThread, "alive?", rb_thread_alive_p, 0); rb_define_method(rb_cThread, "stop?", rb_thread_stop_p, 0); rb_define_method(rb_cThread, "abort_on_exception", rb_thread_abort_exc, 0); diff --git a/vm.c b/vm.c index c47b592..5269a50 100644 --- a/vm.c +++ b/vm.c @@ -1838,6 +1838,7 @@ ruby_thread_init(VALUE self) GetThreadPtr(self, th); th_init(th, self); + rb_iv_set(self, "locals", rb_hash_new()); th->vm = vm; th->top_wrapper = 0; @@ -2185,6 +2186,7 @@ Init_VM(void) /* create main thread */ th_self = th->self = TypedData_Wrap_Struct(rb_cThread, &thread_data_type, th); + rb_iv_set(th_self, "locals", rb_hash_new()); vm->main_thread = th; vm->running_thread = th; th->vm = vm;