Feature #4195
Updated by mame (Yusuke Endoh) about 12 years ago
=begin なかだです。 Socket#recvmsgは scm_rights: true を指定するだけでメインのデータだけで なく簡単にIOを受け取ることができますが、一方でSocket#sendmsg側には対応 する指定ができません。以下のようなオプションを追加するのはどうでしょう か。 s.sendmsg("foo", scm_rights: STDIN) s.sendmsg("foo", scm_rights: [STDIN, STDOUT]) diff --git i/ext/socket/ancdata.c w/ext/socket/ancdata.c index abaf19d..c329e0a 100644 --- i/ext/socket/ancdata.c +++ w/ext/socket/ancdata.c @@ -2,6 +2,8 @@ #include <time.h> +static ID sym_scm_rights; + #if defined(HAVE_ST_MSG_CONTROL) static VALUE rb_cAncillaryData; @@ -1126,17 +1128,63 @@ rb_sendmsg(int fd, const struct msghdr *msg, int flags) return rb_thread_blocking_region(nogvl_sendmsg_func, &args, RUBY_UBF_IO, 0); } +#if defined(HAVE_ST_MSG_CONTROL) +static size_t +io_to_fd(VALUE io) +{ + VALUE fnum = rb_check_to_integer(io, "to_int"); + if (NIL_P(fnum)) + fnum = rb_convert_type(io, T_FIXNUM, "Fixnum", "fileno"); + return NUM2UINT(fnum); +} + +static char * +prepare_msghdr(VALUE controls_str, int level, int type, long clen) +{ + struct cmsghdr cmh; + char *cmsg; + size_t cspace; + long oldlen = RSTRING_LEN(controls_str); + cspace = CMSG_SPACE(clen); + rb_str_resize(controls_str, oldlen + cspace); + cmsg = RSTRING_PTR(controls_str)+oldlen; + memset((char *)cmsg, 0, cspace); + memset((char *)&cmh, 0, sizeof(cmh)); + cmh.cmsg_level = level; + cmh.cmsg_type = type; + cmh.cmsg_len = (socklen_t)CMSG_LEN(clen); + MEMCPY(cmsg, &cmh, char, sizeof(cmh)); + return cmsg+((char*)CMSG_DATA(&cmh)-(char*)&cmh); +} + +# if defined(__NetBSD__) +# define TRIM_PADDING 1 +# endif +# if TRIM_PADDING +# define prepare_msghdr(controls_str, level, type, clen) \ + (last_pad = CMSG_SPACE(clen) - CMSG_LEN(clen), \ + prepare_msghdr((controls_str), \ + last_level = (level), last_type = (type), \ + (clen))) +# endif +#endif + static VALUE bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) { rb_io_t *fptr; - VALUE data, vflags, dest_sockaddr; + VALUE data, vflags, dest_sockaddr, vopts = Qnil; VALUE *controls_ptr; int controls_num; struct msghdr mh; struct iovec iov; #if defined(HAVE_ST_MSG_CONTROL) volatile VALUE controls_str = 0; +# if TRIM_PADDING + size_t last_pad = 0; + int last_level = 0; + int last_type = 0; +# endif #endif int flags; ssize_t ss; @@ -1152,6 +1200,8 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) if (argc == 0) rb_raise(rb_eArgError, "mesg argument required"); + if (1 < argc && RB_TYPE_P(argv[argc-1], T_HASH)) + vopts = argv[--argc]; data = argv[0]; if (1 < argc) vflags = argv[1]; if (2 < argc) dest_sockaddr = argv[2]; @@ -1162,19 +1212,13 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) if (controls_num) { #if defined(HAVE_ST_MSG_CONTROL) int i; - size_t last_pad = 0; - int last_level = 0; - int last_type = 0; controls_str = rb_str_tmp_new(0); for (i = 0; i < controls_num; i++) { VALUE elt = controls_ptr[i], v; VALUE vlevel, vtype; int level, type; VALUE cdata; - long oldlen; - struct cmsghdr cmh; char *cmsg; - size_t cspace; v = rb_check_convert_type(elt, T_ARRAY, "Array", "to_ary"); if (!NIL_P(v)) { elt = v; @@ -1192,21 +1236,46 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) level = rsock_level_arg(family, vlevel); type = rsock_cmsg_type_arg(family, level, vtype); StringValue(cdata); - oldlen = RSTRING_LEN(controls_str); - cspace = CMSG_SPACE(RSTRING_LEN(cdata)); - rb_str_resize(controls_str, oldlen + cspace); - cmsg = RSTRING_PTR(controls_str)+oldlen; - memset((char *)cmsg, 0, cspace); - memset((char *)&cmh, 0, sizeof(cmh)); - cmh.cmsg_level = level; - cmh.cmsg_type = type; - cmh.cmsg_len = (socklen_t)CMSG_LEN(RSTRING_LEN(cdata)); - MEMCPY(cmsg, &cmh, char, sizeof(cmh)); - MEMCPY(cmsg+((char*)CMSG_DATA(&cmh)-(char*)&cmh), RSTRING_PTR(cdata), char, RSTRING_LEN(cdata)); - last_level = cmh.cmsg_level; - last_type = cmh.cmsg_type; - last_pad = cspace - cmh.cmsg_len; + cmsg = prepare_msghdr(controls_str, level, type, RSTRING_LEN(cdata)); + MEMCPY(cmsg, RSTRING_PTR(cdata), char, RSTRING_LEN(cdata)); } +#else + no_msg_control: + rb_raise(rb_eNotImpError, "control message for sendmsg is unimplemented"); +#endif + } + if (!NIL_P(vopts)) { + VALUE rights = rb_hash_aref(vopts, sym_scm_rights); + if (!NIL_P(rights)) { +#if defined(HAVE_ST_MSG_CONTROL) + VALUE tmp = rb_check_array_type(rights); + long count = NIL_P(tmp) ? 1 : RARRAY_LEN(tmp); + char *cmsg; + int fd; + if (!controls_str) controls_str = rb_str_tmp_new(0); + cmsg = prepare_msghdr(controls_str, SOL_SOCKET, SCM_RIGHTS, + count * sizeof(int)); + if (NIL_P(tmp)) { + fd = io_to_fd(rights); + MEMCPY(cmsg, &fd, int, 1); + } + else { + long i; + rights = tmp; + for (i = 0; i < count && i < RARRAY_LEN(rights); ++i) { + fd = io_to_fd(RARRAY_PTR(rights)[i]); + MEMCPY(cmsg, &fd, int, 1); + cmsg += sizeof(int); + } + } +#else + goto no_msg_control; +#endif + } + } +#if defined(HAVE_ST_MSG_CONTROL) + { +# if TRIM_PADDING if (last_pad) { /* * This code removes the last padding from msg_controllen. @@ -1228,15 +1297,12 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) * Basically, msg_controllen should contains the padding. * So the padding is removed only if a problem really exists. */ -#if defined(__NetBSD__) if (last_level == SOL_SOCKET && last_type == SCM_RIGHTS) rb_str_set_len(controls_str, RSTRING_LEN(controls_str)-last_pad); -#endif } -#else - rb_raise(rb_eNotImpError, "control message for sendmsg is unimplemented"); -#endif +# endif } +#endif flags = NIL_P(vflags) ? 0 : NUM2INT(vflags); #ifdef MSG_DONTWAIT @@ -1492,7 +1558,7 @@ bsock_recvmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) grow_buffer = NIL_P(vmaxdatlen) || NIL_P(vmaxctllen); request_scm_rights = 0; - if (!NIL_P(vopts) && RTEST(rb_hash_aref(vopts, ID2SYM(rb_intern("scm_rights"))))) + if (!NIL_P(vopts) && RTEST(rb_hash_aref(vopts, sym_scm_rights))) request_scm_rights = 1; GetOpenFile(sock, fptr); @@ -1795,5 +1861,7 @@ rsock_init_ancdata(void) rb_define_method(rb_cAncillaryData, "ipv6_pktinfo", ancillary_ipv6_pktinfo, 0); rb_define_method(rb_cAncillaryData, "ipv6_pktinfo_addr", ancillary_ipv6_pktinfo_addr, 0); rb_define_method(rb_cAncillaryData, "ipv6_pktinfo_ifindex", ancillary_ipv6_pktinfo_ifindex, 0); + + sym_scm_rights = ID2SYM(rb_intern("scm_rights")); #endif } diff --git i/test/socket/test_unix.rb w/test/socket/test_unix.rb index bde17cf..e9db22e 100644 --- i/test/socket/test_unix.rb +++ w/test/socket/test_unix.rb @@ -31,7 +31,7 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase end end - def test_fd_passing_n + def fd_passing_test io_ary = [] return if !defined?(Socket::SCM_RIGHTS) io_ary.concat IO.pipe @@ -42,8 +42,7 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase send_io_ary << io UNIXSocket.pair {|s1, s2| begin - ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, - send_io_ary.map {|io2| io2.fileno }.pack("i!*")]) + ret = yield(s1, send_io_ary) rescue NotImplementedError return end @@ -66,48 +65,38 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase io_ary.each {|io| io.close if !io.closed? } end + def test_fd_passing_n + fd_passing_test do |s, ios| + s.sendmsg("\0", 0, nil, + [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, ios.map(&:fileno).pack("i!*")]) + end + end + def test_fd_passing_n2 - io_ary = [] - return if !defined?(Socket::SCM_RIGHTS) - return if !defined?(Socket::AncillaryData) - io_ary.concat IO.pipe - io_ary.concat IO.pipe - io_ary.concat IO.pipe - send_io_ary = [] - io_ary.each {|io| - send_io_ary << io - UNIXSocket.pair {|s1, s2| - begin - ancdata = Socket::AncillaryData.unix_rights(*send_io_ary) - ret = s1.sendmsg("\0", 0, nil, ancdata) - rescue NotImplementedError - return - end - assert_equal(1, ret) - ret = s2.recvmsg(:scm_rights=>true) - data, srcaddr, flags, *ctls = ret - recv_io_ary = [] - ctls.each {|ctl| - next if ctl.level != Socket::SOL_SOCKET || ctl.type != Socket::SCM_RIGHTS - recv_io_ary.concat ctl.unix_rights - } - assert_equal(send_io_ary.length, recv_io_ary.length) - send_io_ary.length.times {|i| - assert_not_equal(send_io_ary[i].fileno, recv_io_ary[i].fileno) - assert(File.identical?(send_io_ary[i], recv_io_ary[i])) - } - } - } - ensure - io_ary.each {|io| io.close if !io.closed? } + fd_passing_test do |s, ios| + ancdata = Socket::AncillaryData.unix_rights(*ios) + s.sendmsg("\0", 0, nil, ancdata) + end + end + + def test_fd_passing_n3 + fd_passing_test do |s, ios| + s.sendmsg("\0", 0, nil, scm_rights: ios.map(&:fileno)) + end + end + + def test_fd_passing_n4 + fd_passing_test do |s, ios| + s.sendmsg("\0", 0, nil, scm_rights: ios) + end end - def test_sendmsg + def sendmsg_test return if !defined?(Socket::SCM_RIGHTS) IO.pipe {|r1, w| UNIXSocket.pair {|s1, s2| begin - ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, [r1.fileno].pack("i!")]) + ret = yield(s1, r1) rescue NotImplementedError return end @@ -122,6 +111,24 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase } end + def test_sendmsg_1 + sendmsg_test do |s, r| + s.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, [r.fileno].pack("i!")]) + end + end + + def test_sendmsg_2 + sendmsg_test do |s, r| + s.sendmsg("\0", 0, nil, scm_rights: r.fileno) + end + end + + def test_sendmsg_3 + sendmsg_test do |s, r| + s.sendmsg("\0", 0, nil, scm_rights: r) + end + end + def test_sendmsg_ancillarydata_int return if !defined?(Socket::SCM_RIGHTS) return if !defined?(Socket::AncillaryData) -- --- 僕の前にBugはない。 --- 僕の後ろにBugはできる。 中田 伸悦 =end