diff --git a/io.c b/io.c index 43f3688..80a10a8 100644 --- a/io.c +++ b/io.c @@ -1445,6 +1445,33 @@ rb_io_seek(VALUE io, VALUE offset, int whence) return INT2FIX(0); } +static VALUE sym_set, sym_cur, sym_current, sym_end; + +static int +sym_to_whence(VALUE ptrname) +{ + retry: + if (ptrname == sym_set){ + return SEEK_SET; + } + else if (ptrname == sym_cur || ptrname == sym_current) { + return SEEK_CUR; + } + else if (ptrname == sym_end) { + return SEEK_END; + } + else { + VALUE lower_ptrname; + + lower_ptrname = rb_funcall2(rb_sym_to_s(ptrname), rb_intern("downcase!"), 0, 0); + if (!NIL_P(lower_ptrname)) { + ptrname = rb_str_intern(lower_ptrname); + goto retry; + } + } + rb_raise(rb_eArgError, "unknown whence: %s", rb_id2name(SYM2ID(ptrname))); +} + /* * call-seq: * ios.seek(amount, whence=IO::SEEK_SET) -> 0 @@ -1473,7 +1500,16 @@ rb_io_seek_m(int argc, VALUE *argv, VALUE io) int whence = SEEK_SET; if (rb_scan_args(argc, argv, "11", &offset, &ptrname) == 2) { - whence = NUM2INT(ptrname); + switch (TYPE(ptrname)) { + case T_FIXNUM: + whence = NUM2INT(ptrname); + break; + case T_SYMBOL: + whence = sym_to_whence(ptrname); + break; + default: + rb_raise(rb_eTypeError, "type mismatch: %s given", rb_obj_classname(ptrname)); + } } return rb_io_seek(io, offset, whence); @@ -11694,4 +11730,8 @@ Init_IO(void) sym_willneed = ID2SYM(rb_intern("willneed")); sym_dontneed = ID2SYM(rb_intern("dontneed")); sym_noreuse = ID2SYM(rb_intern("noreuse")); + sym_set = ID2SYM(rb_intern("set")); + sym_cur = ID2SYM(rb_intern("cur")); + sym_current = ID2SYM(rb_intern("current")); + sym_end = ID2SYM(rb_intern("end")); } diff --git a/test/ruby/test_io.rb b/test/ruby/test_io.rb index ddc5f8c..73a16bf 100644 --- a/test/ruby/test_io.rb +++ b/test/ruby/test_io.rb @@ -1482,6 +1482,72 @@ class TestIO < Test::Unit::TestCase } end + def test_seek_sym + make_tempfile {|t| + open(t.path) { |f| + f.seek(9, :set) + assert_equal("az\n", f.read) + } + + open(t.path) { |f| + f.seek(-4, :end) + assert_equal("baz\n", f.read) + } + + open(t.path) { |f| + assert_equal("foo\n", f.gets) + f.seek(2, :cur) + assert_equal("r\nbaz\n", f.read) + } + + open(t.path) { |f| + assert_equal("foo\n", f.gets) + f.seek(2, :current) + assert_equal("r\nbaz\n", f.read) + } + + open(t.path) { |f| + f.seek(9, :SET) + assert_equal("az\n", f.read) + } + + open(t.path) { |f| + f.seek(-4, :END) + assert_equal("baz\n", f.read) + } + + open(t.path) { |f| + assert_equal("foo\n", f.gets) + f.seek(2, :CUR) + assert_equal("r\nbaz\n", f.read) + } + + open(t.path) { |f| + assert_equal("foo\n", f.gets) + f.seek(2, :CURRENT) + assert_equal("r\nbaz\n", f.read) + } + + open(t.path) { |f| + assert_raise(ArgumentError) do + f.seek(42, :hoge) + end + } + + open(t.path) { |f| + assert_raise(ArgumentError) do + f.seek(42, :HOGE) + end + } + + open(t.path) { |f| + assert_raise(TypeError) do + f.seek(42, "hoge") + end + } + } + end + def test_sysseek make_tempfile {|t| open(t.path) do |f|