diff --git a/marshal.c b/marshal.c index e05b9f5..c00ddd2 100644 --- a/marshal.c +++ b/marshal.c @@ -81,7 +81,7 @@ shortlen(long len, BDIGIT *ds) static ID s_dump, s_load, s_mdump, s_mload; static ID s_dump_data, s_load_data, s_alloc, s_call; -static ID s_getbyte, s_read, s_write, s_binmode; +static ID s_getbyte, s_read, s_readpartial, s_write, s_binmode; typedef struct { VALUE newclass; @@ -958,7 +958,10 @@ marshal_dump(int argc, VALUE *argv) struct load_arg { VALUE src; + char *buf; + long buflen; long offset; + int partial; st_table *symbols; st_table *data; VALUE proc; @@ -1030,15 +1033,29 @@ r_byte(struct load_arg *arg) c = (unsigned char)RSTRING_PTR(arg->src)[arg->offset++]; } else { + too_short: rb_raise(rb_eArgError, "marshal data too short"); } } else { - VALUE src = arg->src; - VALUE v = rb_funcall2(src, s_getbyte, 0, 0); - check_load_arg(arg, s_getbyte); - if (NIL_P(v)) rb_eof_error(); - c = (unsigned char)NUM2CHR(v); + if (arg->buflen == 0) { + VALUE str, n = LONG2NUM(BUFSIZ); + + if (arg->partial) + str = rb_funcall2(arg->src, s_readpartial, 1, &n); + else + str = rb_funcall2(arg->src, s_read, 1, &n); + + check_load_arg(arg, s_read); + if (NIL_P(str)) goto too_short; + StringValue(str); + arg->infection |= (int)FL_TEST(str, MARSHAL_INFECTION); + memcpy(arg->buf, RSTRING_PTR(str), RSTRING_LEN(str)); + arg->offset = 0; + arg->buflen = RSTRING_LEN(str); + } + c = (unsigned char)arg->buf[arg->offset++]; + arg->buflen--; } return c; } @@ -1091,6 +1108,74 @@ r_long(struct load_arg *arg) return x; } +static VALUE +r_bytes1(long len, struct load_arg *arg) +{ + VALUE str, n = LONG2NUM(len); + + str = rb_funcall2(arg->src, s_read, 1, &n); + check_load_arg(arg, s_read); + + if (NIL_P(str)) { + too_short: + rb_raise(rb_eArgError, "marshal data too short"); + } + StringValue(str); + if (RSTRING_LEN(str) < len) goto too_short; + + arg->infection |= (int)FL_TEST(str, MARSHAL_INFECTION); + + return str; +} + +static VALUE +r_bytes1_partial(long len, struct load_arg *arg) +{ + long buflen = arg->buflen; + long tmp_len, need = len - buflen; + VALUE n = LONG2NUM(need > BUFSIZ ? need : BUFSIZ); + VALUE str, tmp; + const char *tmp_ptr; + + tmp = rb_funcall2(arg->src, s_readpartial, 1, &n); + + check_load_arg(arg, s_read); + if (NIL_P(tmp)) { + too_short: + rb_raise(rb_eArgError, "marshal data too short"); + } + StringValue(tmp); + + tmp_ptr = RSTRING_PTR(tmp); + tmp_len = RSTRING_LEN(tmp); + + if (tmp_len < need) { + VALUE fill; + + /* retry */ + n = LONG2NUM(need-tmp_len); + fill = rb_funcall2(arg->src, s_read, 1, &n); + + if (NIL_P(fill)) goto too_short; + StringValue(fill); + if (RSTRING_LEN(fill) < need-tmp_len) goto too_short; + + rb_str_concat(tmp, fill); + tmp_len = RSTRING_LEN(tmp); + } + + arg->infection |= (int)FL_TEST(tmp, MARSHAL_INFECTION); + str = rb_str_new(arg->buf+arg->offset, buflen); + rb_str_cat(str, tmp_ptr, need); + if (tmp_len-need > 0) + memcpy(arg->buf, tmp_ptr+need, tmp_len-need); + + arg->offset = 0; + arg->buflen = tmp_len - need; + + return str; +} + #define r_bytes(arg) r_bytes0(r_long(arg), (arg)) static VALUE @@ -1105,19 +1190,21 @@ r_bytes0(long len, struct load_arg *arg) arg->offset += len; } else { - too_short: rb_raise(rb_eArgError, "marshal data too short"); } } else { - VALUE src = arg->src; - VALUE n = LONG2NUM(len); - str = rb_funcall2(src, s_read, 1, &n); - check_load_arg(arg, s_read); - if (NIL_P(str)) goto too_short; - StringValue(str); - if (RSTRING_LEN(str) != len) goto too_short; - arg->infection |= (int)FL_TEST(str, MARSHAL_INFECTION); + if (len <= arg->buflen) { + str = rb_str_new(arg->buf+arg->offset, len); + arg->offset += len; + arg->buflen -= len; + } + else { + if (arg->partial) + str = r_bytes1_partial(len, arg); + else + str = r_bytes1(len, arg); + } } return str; } @@ -1784,6 +1871,12 @@ marshal_load(int argc, VALUE *argv) arg->data = st_init_numtable(); arg->compat_tbl = st_init_numtable(); arg->proc = 0; + arg->partial = 0; + + if(NIL_P(v)) { + if (rb_respond_to(port, s_readpartial)) arg->partial = 1; + arg->buf = xmalloc(BUFSIZ); + } major = r_byte(arg); minor = r_byte(arg); @@ -1921,6 +2014,7 @@ Init_marshal(void) s_call = rb_intern("call"); s_getbyte = rb_intern("getbyte"); s_read = rb_intern("read"); + s_readpartial = rb_intern("readpartial"); s_write = rb_intern("write"); s_binmode = rb_intern("binmode");