Equality and polymorphism

2019-03-14 11:56发布

问题:

With two immutable classes Base and Derived (which derives from Base) I want to define Equality so that

  • equality is always polymorphic - that is ((Base)derived1).Equals((Base)derived2) will call Derived.Equals

  • operators == and != will call Equals rather than ReferenceEquals (value equality)

What I did:

class Base : IEquatable<Base> {
  public readonly ImmutableType1 X;
  readonly ImmutableType2 Y;

  // Captures the two values that take part in equality and hashing.
  public Base(ImmutableType1 X, ImmutableType2 Y) {
    this.X = X;
    this.Y = Y;
  }

  // Value equality, restricted to objects of exactly the same runtime type.
  // Every other equality entry point of this class ends up here.
  public override bool Equals(object obj) {
    if (object.ReferenceEquals(this, obj)) {
      return true;
    }

    // Null, or a different runtime type, is never equal.
    if (obj is null || this.GetType() != obj.GetType()) {
      return false;
    }

    if (!(obj is Base other)) {
      return false;
    }

    return X.Equals(other.X) && Y.Equals(other.Y);
  }

  public override int GetHashCode() {
    return HashCode.Combine(X, Y);
  }

  // boilerplate: the typed overload and both operators go through the static
  // object.Equals, which in turn dispatches to the virtual Equals(object).
  public bool Equals(Base o) {
    return object.Equals(this, o);
  }

  public static bool operator ==(Base o1, Base o2) {
    return object.Equals(o1, o2);
  }

  public static bool operator !=(Base o1, Base o2) {
    return !object.Equals(o1, o2);
  }
}

Here everything ends up in Equals(object) which is always polymorphic so both targets are achieved.

I then derive like this:

class Derived : Base, IEquatable<Derived> {
  public readonly ImmutableType3 Z;
  readonly ImmutableType4 K;

  // Forwards X and Y to Base and stores the two values added by this class.
  public Derived(ImmutableType1 X, ImmutableType2 Y, ImmutableType3 Z, ImmutableType4 K)
    : base(X, Y) {
    this.Z = Z;
    this.K = K;
  }

  // Same shape as Base.Equals(object): identical reference, then exact
  // runtime type, then the base values, then the values declared here.
  public override bool Equals(object obj) {
    if (object.ReferenceEquals(this, obj)) {
      return true;
    }

    if (obj is null || this.GetType() != obj.GetType()) {
      return false;
    }

    if (!(obj is Derived other)) {
      return false;
    }

    // Pass the object-typed variable on purpose so that this binds to
    // base.Equals(object) and not to the typed Equals(Base) overload.
    if (!base.Equals(obj)) {
      return false;
    }

    return Z.Equals(other.Z) && K.Equals(other.K);
  }

  public override int GetHashCode() {
    return HashCode.Combine(base.GetHashCode(), Z, K);
  }

  // boilerplate
  public bool Equals(Derived o) {
    return object.Equals(this, o);
  }
}

Which is basically the same except for one gotcha - when calling base.Equals I call base.Equals(object) and not base.Equals(Derived) (which will cause an endless recursion).

Also Equals(C) will in this implementation do some boxing/unboxing but that is worth it for me.

My questions are -

First, is this correct? My (testing) seems to suggest it is, but with equality in C# being so difficult I'm just not sure anymore. Are there any cases where this is wrong?

And second, is this good? Are there better, cleaner ways to achieve this?

回答1:

Well, I guess there are two parts to your problem:

  1. executing equals at nested level
  2. restricting to the same type

Would this work? https://dotnetfiddle.net/eVLiMZ (I had to use some older syntax as it didn't compile in dotnetfiddle otherwise)

using System;


public class Program
{
    // Base type whose equality is defined by Name and is restricted to
    // objects of exactly the same runtime type.
    public class Base
    {
        public string Name { get; set; }
        public string VarName { get; set; }

        // Returns true for the same reference, or for a non-null object of the
        // same runtime type whose values match according to ThisEquals.
        public override bool Equals(object o)
        {
            if (object.ReferenceEquals(this, o))
            {
                return true;
            }

            // Check for null before calling GetType(): Equals(null) has to
            // return false instead of throwing NullReferenceException.
            return o != null
                && o.GetType() == this.GetType()
                && ThisEquals(o);
        }

        // Compares only the state declared at this level of the hierarchy.
        // Derived classes override this, call base.ThisEquals and add their
        // own members.
        protected virtual bool ThisEquals(object o)
        {
            Base b = o as Base;
            return b != null
                && (Name == b.Name);
        }

        public override string ToString()
        {
            return string.Format("[{0}@{1} Name:{2}]", GetType(), VarName, Name);
        }

        public override int GetHashCode()
        {
            // Name is a settable property and may still be null; hash that as 0
            // instead of throwing.
            return Name == null ? 0 : Name.GetHashCode();
        }
    }

    // Adds Age to the state that takes part in equality and hashing.
    public class Derived : Base
    {
        public int Age { get; set; }

        protected override bool ThisEquals(object o)
        {
            var d = o as Derived;
            return base.ThisEquals(o)
                && d != null
                && (d.Age == Age);
        }

        public override string ToString()
        {
            return string.Format("[{0}@{1} Name:{2} Age:{3}]", GetType(), VarName, Name, Age);
        }

        public override int GetHashCode()
        {
            return base.GetHashCode() ^ Age.GetHashCode();
        }
    }

    // Prints the result of Equals for every ordered pair of sample objects.
    public static void Main()
    {
        var b1 = new Base { Name = "anna", VarName = "b1" };
        var b2 = new Base { Name = "leo", VarName = "b2" };
        var b3 = new Base { Name = "anna", VarName = "b3" };
        var d1 = new Derived { Name = "anna", Age = 21, VarName = "d1" };
        var d2 = new Derived { Name = "anna", Age = 12, VarName = "d2" };
        var d3 = new Derived { Name = "anna", Age = 21, VarName = "d3" };

        var all = new object [] { b1, b2, b3, d1, d2, d3 };

        foreach(var a in all) 
        {
            foreach(var b in all)
            {
                Console.WriteLine("{0}.Equals({1}) => {2}", a, b, a.Equals(b));
            }
        }
    }
}



回答2:

This method of comparison uses Reflection and, apart from the extension methods, is simpler. It also keeps private members private.

All of the logic is in the IImmutableExtensions class. It simply looks at what fields are readonly and uses them for the comparison.

You don't need methods in the base or derived classes for the object comparison. Just call the extension method ImmutableEquals when you are overriding ==, !=, and Equals(). Same with the hashcode.

public class Base : IEquatable<Base>, IImmutable
{
    public readonly ImmutableType1 X;
    readonly ImmutableType2 Y;

    // Filled on the first GetHashCode call. It is not readonly, so the
    // readonly-field scan in IImmutableExtensions does not pick it up.
    private int? _hashCache;

    public Base(ImmutableType1 X, ImmutableType2 Y)
    {
        this.X = X;
        this.Y = Y;
    }

    // boilerplate: everything is delegated to the IImmutable extension methods.
    public override int GetHashCode()
    {
        return this.ImmutableHash(ref _hashCache);
    }

    public override bool Equals(object obj)
    {
        return this.ImmutableEquals(obj);
    }

    public bool Equals(Base o)
    {
        return this.ImmutableEquals(o);
    }

    public static bool operator ==(Base o1, Base o2)
    {
        return o1.ImmutableEquals(o2);
    }

    public static bool operator !=(Base o1, Base o2)
    {
        return !o1.ImmutableEquals(o2);
    }
}

public class Derived : Base, IEquatable<Derived>, IImmutable
{
    public readonly ImmutableType3 Z;
    readonly ImmutableType4 K;

    // Forwards X and Y to Base and stores the two readonly values added here.
    public Derived(ImmutableType1 X, ImmutableType2 Y, ImmutableType3 Z, ImmutableType4 K)
        : base(X, Y)
    {
        this.Z = Z;
        this.K = K;
    }

    // Typed overload; the comparison itself is done by the extension method.
    public bool Equals(Derived other)
    {
        return this.ImmutableEquals(other);
    }
}

And the IImmutableExtensions class:

// Reflection-based equality and hashing for IImmutable types. Only instance
// fields declared readonly take part, at every level of the class hierarchy.
public static class IImmutableExtensions
{
    // Returns true when o1 and o2 are the same reference (including both
    // null), or are non-null objects of the same runtime type whose readonly
    // fields are all equal.
    public static bool ImmutableEquals(this IImmutable o1, object o2)
    {
        if (ReferenceEquals(o1, o2)) return true;

        // o1 can be null because extension methods (and the == operators that
        // call this one) accept a null receiver.
        if (o1 is null || o2 is null) return false;

        // Different hash codes mean the objects cannot be equal, so the
        // field-by-field comparison can be skipped.
        if (o1.GetType() != o2.GetType() || o1.GetHashCode() != o2.GetHashCode()) return false;

        foreach (var field in GetImmutableFields(o1))
        {
            // object.Equals handles null on either side: null equals only null.
            if (!object.Equals(field.GetValue(o1), field.GetValue(o2))) return false;
        }
        return true;
    }

    // Returns the hash stored in hashCache, computing and storing it first
    // when hashCache is still null. A null field contributes 0.
    public static int ImmutableHash(this IImmutable o, ref int? hashCache)
    {
        if (hashCache is null)
        {
            var hash = 0;

            foreach (var field in GetImmutableFields(o))
            {
                hash = HashCode.Combine(hash, field.GetValue(o)?.GetHashCode() ?? 0);
            }

            hashCache = hash;
        }
        return hashCache.Value;
    }

    // Yields the readonly instance fields declared on the runtime type of o
    // and on each of its base types, stopping before System.Object.
    private static IEnumerable<FieldInfo> GetImmutableFields(object o)
    {
        const BindingFlags declaredInstance =
            BindingFlags.DeclaredOnly | BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public;

        // The null check keeps the walk safe even if it starts at System.Object,
        // whose BaseType is null.
        for (var t = o.GetType(); t != null && t != typeof(object); t = t.BaseType)
        {
            foreach (var field in t.GetFields(declaredInstance).Where(field => field.IsInitOnly))
            {
                yield return field;
            }
        }
    }
}

Old answer: (I will leave this for reference)

Based on what you were saying about having to cast to object it occurred to me that the methods Equals(object) and Equals(Base) were too ambiguous when calling them from a derived class.

This said to me that the logic should be moved out of both of the classes, to a method that would better describe our intentions.

Equality will remain polymorphic as ImmutableEquals in the base class will call the overridden ValuesEqual. This is where you can decide in each derived class how to compare equality.

This is your code refactored with that goal.

Revised answer:

It occurred to me that all of our logic in IsEqual() and GetHashCode() would work if we simply supplied a tuple that contained the immutable fields that we wanted to compare. This avoids duplicating so much code in every class.

It is up to the developer that creates the derived class to override GetImmutableTuple(). Without using reflection (see other answer), I feel this is the least of all evils.

public class Base : IEquatable<Base>, IImmutable
{
    public readonly ImmutableType1 X;
    readonly ImmutableType2 Y;

    public Base(ImmutableType1 X, ImmutableType2 Y) => 
      (this.X, this.Y) = (X, Y);

    // The values that define equality and the hash code of this instance.
    // Derived classes override this and include the base tuple in their own.
    protected virtual IStructuralEquatable GetImmutableTuple() => (X, Y);

    // boilerplate
    public override bool Equals(object o) => IsEqual(o as Base);
    public bool Equals(Base o) => IsEqual(o);

    // The operators must not call an instance method on a null left operand.
    public static bool operator ==(Base o1, Base o2) => o1 is null ? o2 is null : o1.IsEqual(o2);
    public static bool operator !=(Base o1, Base o2) => !(o1 == o2);

    // Computed from the tuple on first use and then served from hashCache.
    public override int GetHashCode() => hashCache ?? (hashCache = GetImmutableTuple().GetHashCode()).Value;

    // Equal when it is the same reference, or a non-null object of the same
    // runtime type with the same hash code and an equal tuple. The tuples are
    // compared with Equals: they are boxed behind IStructuralEquatable, so the
    // == / != operators would only compare references.
    protected bool IsEqual(Base obj) =>
        ReferenceEquals(this, obj)
        || !(obj is null)
        && GetType() == obj.GetType()
        && GetHashCode() == obj.GetHashCode()
        && GetImmutableTuple().Equals(obj.GetImmutableTuple());

    protected int? hashCache;
}

public class Derived : Base, IEquatable<Derived>, IImmutable
{
    public readonly ImmutableType3 Z;
    readonly ImmutableType4 K;

    public Derived(ImmutableType1 X, ImmutableType2 Y, ImmutableType3 Z, ImmutableType4 K)
        : base(X, Y)
    {
        this.Z = Z;
        this.K = K;
    }

    // Nests the base tuple and appends the values declared by this class.
    protected override IStructuralEquatable GetImmutableTuple()
    {
        return (base.GetImmutableTuple(), K, Z);
    }

    // boilerplate
    public bool Equals(Derived o)
    {
        return IsEqual(o);
    }
}


回答3:

The code can be simplified using a combination of an extension method and some boilerplate code. This takes almost all of the pain away and leaves classes focused on comparing their instances without having to deal with all the special edge cases:

namespace System {
  public static partial class ExtensionMethods {
    // Shared Equals(object) skeleton: handles the reference, null and
    // runtime-type checks, then hands the typed object to thisEquals.
    public static bool Equals<T>(this T inst, object obj, Func<T, bool> thisEquals) where T : IEquatable<T> {
      // same reference -> equal
      if (object.ReferenceEquals(inst, obj)) {
        return true;
      }

      // this is not null but obj is -> not equal
      if (obj is null) {
        return false;
      }

      // obj has a different runtime type than this -> not equal
      if (obj.GetType() != inst.GetType()) {
        return false;
      }

      // obj cannot be cast to this type -> not equal
      return obj is T o && thisEquals(o);
    }
  }
}

I can now do:

class Base : IEquatable<Base> {
    public SomeType1 X;
    SomeType2 Y;
    public Base(int X, int Y) => (this.X, this.Y) = (X, Y);

    // Compares only the state declared by this class. The reference, null and
    // runtime-type checks are done by the Equals extension method before this
    // is called.
    public bool ThisEquals(Base o) => (X, Y) == (o.X, o.Y);

    // boilerplate
    public override bool Equals(object obj) => this.Equals(obj, ThisEquals);

    // Equals is overridden, so GetHashCode has to be overridden as well:
    // instances that compare equal have equal X and Y and must hash alike.
    // NOTE(review): X and Y are not readonly here, so the hash changes if they
    // are reassigned - confirm that instances are not mutated while used as keys.
    public override int GetHashCode() => (X, Y).GetHashCode();

    public bool Equals(Base o) => object.Equals(this, o);
    public static bool operator ==(Base o1, Base o2) => object.Equals(o1, o2);
    public static bool operator !=(Base o1, Base o2) => !object.Equals(o1, o2);
}


class Derived : Base, IEquatable<Derived> {
    public SomeType3 Z;
    SomeType4 K;

    public Derived(int X, int Y, int Z, int K) : base(X, Y) {
        this.Z = Z;
        this.K = K;
    }

    // Base state first, then the two members declared by this class.
    public bool ThisEquals(Derived o) {
        if (!base.ThisEquals(o)) {
            return false;
        }

        return (Z, K) == (o.Z, o.K);
    }

    // boilerplate
    public override bool Equals(object obj) {
        return this.Equals(obj, ThisEquals);
    }

    public bool Equals(Derived o) {
        return object.Equals(this, o);
    }
}

(testing)


For immutable classes it is possible to optimize further by caching the hashcode and using it in Equals to shortcut equality if the hashcodes are different:

namespace System.Immutable {
  // Constraint type for the extension methods below; adds no members of its own.
  public interface IImmutableEquatable<T> : IEquatable<T> { };

  public static partial class ExtensionMethods {
    // Equals(object) skeleton with an extra hash-code comparison before the
    // typed comparison in thisEquals is run.
    public static bool ImmutableEquals<T>(this T inst, object obj, Func<T, bool> thisEquals) where T : IImmutableEquatable<T> {
      // same reference -> equal
      if (object.ReferenceEquals(inst, obj)) {
        return true;
      }

      // this is not null but obj is -> not equal
      if (obj is null) {
        return false;
      }

      // obj has a different runtime type than this -> not equal
      if (obj.GetType() != inst.GetType()) {
        return false;
      }

      // optimization, hash codes are different -> not equal
      if (inst.GetHashCode() != obj.GetHashCode()) {
        return false;
      }

      // obj cannot be cast to this type -> not equal
      return obj is T o && thisEquals(o);
    }

    // Returns the cached hash, calling thisHashCode to fill the cache only
    // while hashCache is still null.
    public static int GetHashCode<T>(this T inst, ref int? hashCache, Func<int> thisHashCode) where T : IImmutableEquatable<T> {
      return hashCache ?? (hashCache = thisHashCode()).Value;
    }
  }
}


I can now do:

class Base : IImmutableEquatable<Base> {
    public readonly SomeImmutableType1 X;
    readonly SomeImmutableType2 Y;

    public Base(int X, int Y) {
        this.X = X;
        this.Y = Y;
    }

    // The class-specific logic: compare and hash the state declared here.
    public bool ThisEquals(Base o) {
        return (X, Y) == (o.X, o.Y);
    }

    public int ThisHashCode() {
        return (X, Y).GetHashCode();
    }


    // boilerplate
    protected int? hashCache;

    public override int GetHashCode() {
        return this.GetHashCode(ref hashCache, ThisHashCode);
    }

    public override bool Equals(object obj) {
        return this.ImmutableEquals(obj, ThisEquals);
    }

    public bool Equals(Base o) {
        return object.Equals(this, o);
    }

    public static bool operator ==(Base o1, Base o2) {
        return object.Equals(o1, o2);
    }

    public static bool operator !=(Base o1, Base o2) {
        return !object.Equals(o1, o2);
    }
}


class Derived : Base, IImmutableEquatable<Derived> {
    public readonly SomeImmutableType3 Z;
    readonly SomeImmutableType4 K;

    public Derived(int X, int Y, int Z, int K) : base(X, Y) {
        this.Z = Z;
        this.K = K;
    }

    // Base state first, then the two members declared by this class.
    public bool ThisEquals(Derived o) {
        if (!base.ThisEquals(o)) {
            return false;
        }

        return (Z, K) == (o.Z, o.K);
    }

    // Hides Base.ThisHashCode and folds the base hash into this class's hash.
    public new int ThisHashCode() {
        return (base.ThisHashCode(), Z, K).GetHashCode();
    }


    // boilerplate
    public override bool Equals(object obj) {
        return this.ImmutableEquals(obj, ThisEquals);
    }

    public bool Equals(Derived o) {
        return object.Equals(this, o);
    }

    public override int GetHashCode() {
        return this.GetHashCode(ref hashCache, ThisHashCode);
    }
}

Which is not too bad - there is more complexity but it is all just boilerplate which I just cut&paste .. the logic is clearly separated in ThisEquals and ThisHashCode

(testing)



回答4:

Another method would be to use Reflection to automatically compare all of your fields and properties. You just have to decorate them with the Immutable attribute and AutoCompare() will take care of the rest.

This will also use Reflection to build a HashCode based on your fields and properties decorated with Immutable, and then cache it to optimize the object comparison.

public class Base : ComparableImmutable, IEquatable<Base>, IImmutable
{
    // NOTE(review): X has a public setter although it is tagged [Immutable]
    // and ComparableImmutable caches the hash code - confirm this is intended.
    [Immutable]
    public ImmutableType1 X { get; set; }

    [Immutable]
    readonly ImmutableType2 Y;

    public Base(ImmutableType1 X, ImmutableType2 Y)
    {
        this.X = X;
        this.Y = Y;
    }

    // Typed overload; the comparison is done by ComparableImmutable.AutoCompare.
    public bool Equals(Base o)
    {
        return AutoCompare(o);
    }
}

public class Derived : Base, IEquatable<Derived>, IImmutable
{
    [Immutable]
    public readonly ImmutableType3 Z;

    [Immutable]
    readonly ImmutableType4 K;

    // Forwards X and Y to Base and stores the two tagged values added here.
    public Derived(ImmutableType1 X, ImmutableType2 Y, ImmutableType3 Z, ImmutableType4 K) : base(X, Y)
    {
        this.Z = Z;
        this.K = K;
    }

    // Typed overload; the comparison is done by the inherited AutoCompare.
    public bool Equals(Derived o)
    {
        return AutoCompare(o);
    }
}

// Marker attribute, applicable to fields and properties only.
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
public class ImmutableAttribute : Attribute
{
}

// Base class that derives Equals, GetHashCode and the == / != operators from
// every field and property tagged with [Immutable], at each level of the
// class hierarchy.
public abstract class ComparableImmutable
{
    // DeclaredOnly: each type is scanned for its own members only; the loops
    // below walk BaseType to reach the inherited ones.
    static readonly BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.DeclaredOnly;

    // Hash code computed on the first GetHashCode call and reused afterwards.
    protected int? hashCache;

    // True when the member carries the [Immutable] attribute.
    static bool IsMarkedImmutable(MemberInfo member) => Attribute.IsDefined(member, typeof(ImmutableAttribute));

    // Combines the values of all tagged fields and properties. The result is
    // built in a local and stored once complete.
    public override int GetHashCode()
    {
        if (hashCache is null)
        {
            var hash = 0;

            for (var type = GetType(); type != null; type = type.BaseType)
            {
                foreach (var field in type.GetFields(flags).Where(f => IsMarkedImmutable(f)))
                {
                    hash = HashCode.Combine(hash, field.GetValue(this));
                }

                foreach (var property in type.GetProperties(flags).Where(p => IsMarkedImmutable(p)))
                {
                    hash = HashCode.Combine(hash, property.GetValue(this));
                }
            }

            hashCache = hash;
        }

        return hashCache.Value;
    }

    // Returns true when obj2 is this instance, or a non-null object of the
    // same runtime type with the same hash code whose tagged fields and
    // properties all have equal values.
    protected bool AutoCompare(object obj2)
    {
        if (ReferenceEquals(this, obj2)) return true;

        if (obj2 is null
            || GetType() != obj2.GetType()
            || GetHashCode() != obj2.GetHashCode())
            return false;

        for (var type = GetType(); type != null; type = type.BaseType)
        {
            foreach (var field in type.GetFields(flags).Where(f => IsMarkedImmutable(f)))
            {
                // GetValue returns object, so == / != would compare references
                // (and boxed value types never match); object.Equals compares
                // values and accepts null on either side.
                if (!object.Equals(field.GetValue(this), field.GetValue(obj2)))
                {
                    return false;
                }
            }

            foreach (var property in type.GetProperties(flags).Where(p => IsMarkedImmutable(p)))
            {
                if (!object.Equals(property.GetValue(this), property.GetValue(obj2)))
                {
                    return false;
                }
            }
        }

        return true;
    }

    public override bool Equals(object o) => AutoCompare(o);

    // The operators are declared on the containing type and must not call an
    // instance method on a null left operand.
    public static bool operator ==(ComparableImmutable o1, ComparableImmutable o2) => o1 is null ? o2 is null : o1.AutoCompare(o2);
    public static bool operator !=(ComparableImmutable o1, ComparableImmutable o2) => !(o1 == o2);
}


标签: c# c#-7.0 c#-7.3