From 9058f2e49d03a79dc46d74cadbf26939cad69a1a Mon Sep 17 00:00:00 2001 From: Jeremy Evans Date: Tue, 27 Aug 2019 12:53:30 -0700 Subject: [PATCH] Make Set call to_st on arguments if reponding to it to convert to set This allows duck typing to work. This affects: * flatten * flatten! * proper_superset? * proper_subset? * superset? * subset? * intersect? * == * eql? Implements [Feature #15240] --- lib/set.rb | 26 +++++++++++++++++--------- test/test_set.rb | 28 ++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/lib/set.rb b/lib/set.rb index 921f18f97b..448eb9f820 100644 --- a/lib/set.rb +++ b/lib/set.rb @@ -212,10 +212,18 @@ def to_set(klass = Set, *args, &block) return self if instance_of?(Set) && klass == Set && block.nil? && args.empty? klass.new(self, *args, &block) end + alias to_st to_set + + private def _to_st?(e) + if e.respond_to?(:to_st) && (s = e.to_st) && s.is_a?(Set) + s + end + end def flatten_merge(set, seen = Set.new) # :nodoc: set.each { |e| - if e.is_a?(Set) + if s = _to_st?(e) + e = s if seen.include?(e_id = e.object_id) raise ArgumentError, "tried to flatten recursive Set" end @@ -241,7 +249,7 @@ def flatten # Equivalent to Set#flatten, but replaces the receiver with the # result in place. Returns nil if no modifications were made. def flatten! - replace(flatten()) if any? { |e| e.is_a?(Set) } + replace(flatten()) if any? { |e| _to_st?(e) } end # Returns true if the set contains the given object. @@ -260,7 +268,7 @@ def superset?(set) case when set.instance_of?(self.class) && @hash.respond_to?(:>=) @hash >= set.instance_variable_get(:@hash) - when set.is_a?(Set) + when set = _to_st?(set) size >= set.size && set.all? { |o| include?(o) } else raise ArgumentError, "value must be a set" @@ -273,7 +281,7 @@ def proper_superset?(set) case when set.instance_of?(self.class) && @hash.respond_to?(:>) @hash > set.instance_variable_get(:@hash) - when set.is_a?(Set) + when set = _to_st?(set) size > set.size && set.all? { |o| include?(o) } else raise ArgumentError, "value must be a set" @@ -286,7 +294,7 @@ def subset?(set) case when set.instance_of?(self.class) && @hash.respond_to?(:<=) @hash <= set.instance_variable_get(:@hash) - when set.is_a?(Set) + when set = _to_st?(set) size <= set.size && all? { |o| set.include?(o) } else raise ArgumentError, "value must be a set" @@ -299,7 +307,7 @@ def proper_subset?(set) case when set.instance_of?(self.class) && @hash.respond_to?(:<) @hash < set.instance_variable_get(:@hash) - when set.is_a?(Set) + when set = _to_st?(set) size < set.size && all? { |o| set.include?(o) } else raise ArgumentError, "value must be a set" @@ -313,7 +321,7 @@ def proper_subset?(set) # Set[1, 2, 3].intersect? Set[4, 5] #=> false # Set[1, 2, 3].intersect? Set[3, 4] #=> true def intersect?(set) - set.is_a?(Set) or raise ArgumentError, "value must be a set" + set = _to_st?(set) or raise ArgumentError, "value must be a set" if size < set.size any? { |o| set.include?(o) } else @@ -503,7 +511,7 @@ def ==(other) true elsif other.instance_of?(self.class) @hash == other.instance_variable_get(:@hash) - elsif other.is_a?(Set) && self.size == other.size + elsif (other = _to_st?(other)) && self.size == other.size other.all? { |o| @hash.include?(o) } else false @@ -515,7 +523,7 @@ def hash # :nodoc: end def eql?(o) # :nodoc: - return false unless o.is_a?(Set) + return false unless o = _to_st?(o) @hash.eql?(o.instance_variable_get(:@hash)) end diff --git a/test/test_set.rb b/test/test_set.rb index b20920e63e..c8f1833516 100644 --- a/test/test_set.rb +++ b/test/test_set.rb @@ -5,6 +5,12 @@ class TC_Set < Test::Unit::TestCase class Set2 < Set end + class ArraySet < Array + def self.[](*v) + new.concat(v) + end + alias to_st to_set + end def test_aref assert_nothing_raised { @@ -232,6 +238,8 @@ def test_superset? set.superset?([2]) } + assert_equal(true, set.superset?(ArraySet[2])) + [Set, Set2].each { |klass| assert_equal(true, set.superset?(klass[]), klass.name) assert_equal(true, set.superset?(klass[1,2]), klass.name) @@ -261,6 +269,8 @@ def test_proper_superset? set.proper_superset?([2]) } + assert_equal(true, set.proper_superset?(ArraySet[2])) + [Set, Set2].each { |klass| assert_equal(true, set.proper_superset?(klass[]), klass.name) assert_equal(true, set.proper_superset?(klass[1,2]), klass.name) @@ -290,6 +300,8 @@ def test_subset? set.subset?([2]) } + assert_equal(false, set.subset?(ArraySet[2])) + [Set, Set2].each { |klass| assert_equal(true, set.subset?(klass[1,2,3,4]), klass.name) assert_equal(true, set.subset?(klass[1,2,3]), klass.name) @@ -319,6 +331,8 @@ def test_proper_subset? set.proper_subset?([2]) } + assert_equal(false, set.proper_subset?(ArraySet[2])) + [Set, Set2].each { |klass| assert_equal(true, set.proper_subset?(klass[1,2,3,4]), klass.name) assert_equal(false, set.proper_subset?(klass[1,2,3]), klass.name) @@ -336,14 +350,14 @@ def assert_intersect(expected, set, other) case expected when true assert_send([set, :intersect?, other]) - assert_send([other, :intersect?, set]) + assert_send([other, :intersect?, set]) if other.is_a?(Set) assert_not_send([set, :disjoint?, other]) - assert_not_send([other, :disjoint?, set]) + assert_not_send([other, :disjoint?, set]) if other.is_a?(Set) when false assert_not_send([set, :intersect?, other]) - assert_not_send([other, :intersect?, set]) + assert_not_send([other, :intersect?, set]) if other.is_a?(Set) assert_send([set, :disjoint?, other]) - assert_send([other, :disjoint?, set]) + assert_send([other, :disjoint?, set]) if other.is_a?(Set) when Class assert_raise(expected) { set.intersect?(other) @@ -361,6 +375,7 @@ def test_intersect? assert_intersect(ArgumentError, set, 3) assert_intersect(ArgumentError, set, [2,4,6]) + assert_intersect(true, set, ArraySet[2,4,6]) assert_intersect(true, set, set) assert_intersect(true, set, Set[2,4]) @@ -617,6 +632,11 @@ def test_eq assert_equal(set1, set1) assert_equal(set1, set2) assert_not_equal(Set[1], [1]) + assert_equal(Set[1], ArraySet[1]) + assert_equal(false, Set[1].eql?([1])) + assert_equal(true, Set[1].eql?(ArraySet[1])) + assert_equal(false, Set[1] == [1]) + assert_equal(true, Set[1] == ArraySet[1]) set1 = Class.new(Set)["a", "b"] set2 = Set["a", "b", set1] -- 2.22.0