diff --git a/io.c b/io.c index 491442b..34465ae 100644 --- a/io.c +++ b/io.c @@ -1482,16 +1482,220 @@ io_write(VALUE io, VALUE str, int nosync) return LONG2FIX(n); } +#ifdef HAVE_WRITEV +struct binwritev_arg { + rb_io_t *fptr; + const struct iovec *iov; + int iovcnt; +}; + +static VALUE +call_writev_internal(VALUE arg) +{ + struct binwritev_arg *p = (struct binwritev_arg *)arg; + return rb_writev_internal(p->fptr->fd, p->iov, p->iovcnt); +} + +static long +io_binwritev(struct iovec *iov, int iovcnt, rb_io_t *fptr) +{ + int i; + long r, total = 0, written_len = 0; + + /* don't write anything if current thread has a pending interrupt. */ + rb_thread_check_ints(); + + if (iovcnt == 0) return 0; + for (i = 1; i < iovcnt; i++) total += iov[i].iov_len; + + if (fptr->wbuf.ptr == NULL && !(fptr->mode & FMODE_SYNC)) { + fptr->wbuf.off = 0; + fptr->wbuf.len = 0; + fptr->wbuf.capa = IO_WBUF_CAPA_MIN; + fptr->wbuf.ptr = ALLOC_N(char, fptr->wbuf.capa); + fptr->write_lock = rb_mutex_new(); + rb_mutex_allow_trap(fptr->write_lock, 1); + } + + if (fptr->wbuf.ptr && fptr->wbuf.len) { + if (fptr->wbuf.off + fptr->wbuf.len + total <= fptr->wbuf.capa) { + long offset = fptr->wbuf.off; + for (i = 1; i < iovcnt; i++) { + memcpy(fptr->wbuf.ptr+offset, iov[i].iov_base, iov[i].iov_len); + offset += iov[i].iov_len; + } + fptr->wbuf.len += total; + return total; + } + else { + iov[0].iov_base = fptr->wbuf.ptr + fptr->wbuf.off; + iov[0].iov_len = fptr->wbuf.len; + } + } + else { + iov++; + iovcnt--; + } + + retry: + if (fptr->write_lock) { + struct binwritev_arg arg; + arg.fptr = fptr; + arg.iov = iov; + arg.iovcnt = iovcnt; + r = rb_mutex_synchronize(fptr->write_lock, call_writev_internal, (VALUE)&arg); + } + else { + r = rb_writev_internal(fptr->fd, iov, iovcnt); + } + + if (r >= 0) { + written_len += r; + if (fptr->wbuf.ptr && fptr->wbuf.len) { + if (written_len < fptr->wbuf.len) { + fptr->wbuf.off += r; + fptr->wbuf.len -= r; + } + else { + fptr->wbuf.off = 0; + fptr->wbuf.len = 0; + } + } + if (written_len == total) return written_len; + + for (i = 0; i < iovcnt; i++) { + if (r > (ssize_t)iov[i].iov_len) { + r -= iov[i].iov_len; + iov[i].iov_len = 0; + } + else { + iov[i].iov_base = (char *)iov[i].iov_base + r; + iov[i].iov_len -= r; + break; + } + } + + errno = EAGAIN; + } + if (rb_io_wait_writable(fptr->fd)) { + rb_io_check_closed(fptr); + goto retry; + } + + return -1L; +} + +static long +io_fwritev(int argc, VALUE *argv, rb_io_t *fptr) +{ + int i, converted, iovcnt = argc + 1; + long n; + VALUE v1, v2, str, tmp, *tmp_array; + struct iovec *iov; + + if (iovcnt > IOV_MAX) { + rb_raise(rb_eArgError, "too many items (IOV_MAX: %d)", IOV_MAX); + } + + iov = ALLOCV_N(struct iovec, v1, iovcnt); + tmp_array = ALLOCV_N(VALUE, v2, argc); + + for (i = 0; i < argc; i++) { + str = argv[i]; + converted = 0; + str = do_writeconv(str, fptr, &converted); + if (converted) + OBJ_FREEZE(str); + + tmp = rb_str_tmp_frozen_acquire(str); + tmp_array[i] = tmp; + /* iov[0] is reserved for buffer of fptr */ + iov[i+1].iov_base = RSTRING_PTR(tmp); + iov[i+1].iov_len = RSTRING_LEN(tmp); + } + + n = io_binwritev(iov, iovcnt, fptr); + if (v1) ALLOCV_END(v1); + + for (i = 0; i < argc; i++) { + rb_str_tmp_frozen_release(argv[i], tmp_array[i]); + } + + if (v2) ALLOCV_END(v2); + + return n; +} + +static VALUE +io_writev(int argc, VALUE *argv, VALUE io) +{ + rb_io_t *fptr; + long n; + VALUE tmp; + + io = GetWriteIO(io); + tmp = rb_io_check_io(io); + if (NIL_P(tmp)) { + /* port is not IO, call writev method for it. */ + return rb_funcallv(io, id_write, argc, argv); + } + io = tmp; + + GetOpenFile(io, fptr); + rb_io_check_writable(fptr); + + n = io_fwritev(argc, argv, fptr); + if (n == -1L) rb_sys_fail_path(fptr->pathv); + + return LONG2FIX(n); +} +#else +static VALUE +io_writev(int argc, VALUE *argv, VALUE io) +{ + rb_io_t *fptr; + long n, total; + VALUE str, tmp, total = INT2FIX(0); + int nosync; + + io = GetWriteIO(io); + tmp = rb_io_check_io(io); + if (NIL_P(tmp)) { + /* port is not IO, call writev method for it. */ + return rb_funcallv(io, id_write, argc, argv); + } + io = tmp; + + GetOpenFile(io, fptr); + rb_io_check_writable(fptr); + + for (i = 0; i < argc; i++) { + /* sync at last item */ + if (i == argc-1) + nosync = 0; + else + nosync = 1; + + str = argv[i]; + n = io_fwrite(str, fptr, nosync); + if (n == -1L) rb_sys_fail_path(fptr->pathv); + total = rb_fix_plus_fix(LONG2FIX(n), total); + } + + return total; +} +#endif /* HAVE_WRITEV */ + /* * call-seq: - * ios.write(string) -> integer + * ios.write(string, ...) -> integer * - * Writes the given string to ios. The stream must be opened - * for writing. If the argument is not a string, it will be converted - * to a string using to_s. Returns the number of bytes - * written. + * Writes the given strings to ios. The stream must be opened + * for writing. If the arguments are not a string, they will be converted + * to String using to_s. Returns the number of bytes + * written in total. * - * count = $stdout.write("This is a test\n") + * count = $stdout.write("This is", "a test\n") * puts "That was #{count} bytes of data" * * produces: @@ -1501,9 +1705,21 @@ io_write(VALUE io, VALUE str, int nosync) */ static VALUE -io_write_m(VALUE io, VALUE str) +io_write_m(int argc, VALUE *argv, VALUE io) { - return io_write(io, str, 0); +#ifdef HAVE_WRITEV + rb_check_arity(argc, 1, IOV_MAX-1); +#else + rb_check_arity(argc, 1, UNLIMITED_ARGUMENTS); +#endif + + if (argc > 1) { + return io_writev(argc, argv, io); + } + else { + VALUE str = argv[0]; + return io_write(io, str, 0); + } } VALUE @@ -12674,7 +12890,7 @@ Init_IO(void) rb_define_method(rb_cIO, "readpartial", io_readpartial, -1); rb_define_method(rb_cIO, "read", io_read, -1); - rb_define_method(rb_cIO, "write", io_write_m, 1); + rb_define_method(rb_cIO, "write", io_write_m, -1); rb_define_method(rb_cIO, "gets", rb_io_gets_m, -1); rb_define_method(rb_cIO, "readline", rb_io_readline, -1); rb_define_method(rb_cIO, "getc", rb_io_getc, 0); diff --git a/test/ruby/test_io.rb b/test/ruby/test_io.rb index 402d497..6268563 100644 --- a/test/ruby/test_io.rb +++ b/test/ruby/test_io.rb @@ -1216,6 +1216,15 @@ def test_ungetc2 end) end + def test_write_with_multiple_arguments + pipe(proc do |w| + w.write("foo", "bar") + w.close + end, proc do |r| + assert_equal("foobar", r.read) + end) + end + def test_write_non_writable with_pipe do |r, w| assert_raise(IOError) do