Count an IOrderedEnumerable without consuming it

Posted 2020-07-10 11:22

Question:

What I want to do, short version:

var source = new[]{2,4,6,1,9}.OrderBy(x=>x);
int count = source.Count; // <-- get the number of elements without performing the sort

Long version:

To determine the number of elements in an IEnumerable, it is necessary to iterate over all elements. This could potentially be a very expensive operation.

If the IEnumerable can be cast to ICollection, then the count can be determined quickly without iterating. The LINQ Count() method does this automatically.
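The shortcut described above can be observed with a small spy collection. This is only a sketch: SpyCollection<T> and CountDemo are hypothetical names made up for illustration, not framework types. The wrapper records every call to GetEnumerator, so it shows whether Enumerable.Count walked the elements or simply read ICollection<T>.Count.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical wrapper (not part of the question): a read-only ICollection<T>
// that records how often it has been enumerated.
public sealed class SpyCollection<T> : ICollection<T>
{
    private readonly List<T> _items;

    public int Enumerations { get; private set; }

    public SpyCollection(IEnumerable<T> items) { _items = new List<T>(items); }

    public int Count => _items.Count;
    public bool IsReadOnly => true;
    public bool Contains(T item) => _items.Contains(item);
    public void CopyTo(T[] array, int arrayIndex) => _items.CopyTo(array, arrayIndex);
    public void Add(T item) => throw new NotSupportedException();
    public void Clear() => throw new NotSupportedException();
    public bool Remove(T item) => throw new NotSupportedException();

    public IEnumerator<T> GetEnumerator()
    {
        Enumerations++; // every enumeration attempt is recorded here
        return _items.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public static class CountDemo
{
    // Counts the spy through LINQ and reports how often it was enumerated.
    public static (int Count, int Enumerations) CountDirectly()
    {
        var spy = new SpyCollection<int>(new[] { 2, 4, 6, 1, 9 });
        int count = Enumerable.Count(spy); // takes the ICollection<T> shortcut
        return (count, spy.Enumerations);
    }
}
```

CountDemo.CountDirectly() returns a count of 5 with zero enumerations. Running the same spy through OrderBy(...).Count() is the case the question is about, and the number of enumerations there depends on the runtime's internal implementation.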

The function myEnumerable.OrderBy() returns an IOrderedEnumerable. An IOrderedEnumerable can obviously not be cast to ICollection, so calling Count() will consume the whole thing.

But sorting does not change the number of elements, and an IOrderedEnumerable has to keep a reference to its source. So if that source is an ICollection, it should be possible to determine the count from the IOrderedEnumerable without consuming it.

My goal is to have a library method that takes an IEnumerable with n elements and then, for example, retrieves the element at position n/2.

I want to avoid iterating over the IEnumerable twice just to get its count, but I also want to avoid creating an unnecessary copy if at all possible.


Here is a skeleton of the function I want to create

public void DoSomething<T>(IEnumerable<T> source)
{
    int count; // What we do with the source depends on its length

    if (source is ICollection<T>)
    {
        count = source.Count(); // Great, LINQ uses ICollection<T>.Count here
    }
    else if (source is IOrderedEnumerable<T>)
    {
        // TODO: Find out whether this is based on an ICollection<T>,
        // TODO: then determine the count of that ICollection<T>
    }
    else
    {
        // Iterating over the source may be expensive;
        // to avoid iterating twice, make a copy of the source
        source = source.ToList();
        count = source.Count();
    }

    // do some stuff

}

Answer 1:

Let's think about what this code actually does:

var source = new[]{ 2, 4, 6, 1, 9 }.OrderBy(x => x);
int count = source.Count();

It is the same as

int count = Enumerable.Count(Enumerable.OrderBy(new[]{ 2, 4, 6, 1, 9 }, x => x));

The result of Enumerable.OrderBy(new[]{ 2, 4, 6, 1, 9 }, x => x) is passed to the Count extension method, so you cannot avoid executing OrderBy. And because OrderBy is a non-streaming operator, it consumes the entire source before it yields anything for Count to count.

So the only way to avoid iterating over the whole collection is to avoid OrderBy: count the items before sorting.
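A minimal sketch of "count before sorting", assuming the caller still holds the original array. CountFirstDemo and MiddleOfSorted are hypothetical names chosen for illustration; the position n/2 comes from the question.

```csharp
using System.Linq;

public static class CountFirstDemo
{
    // Hypothetical helper: returns the element at index n/2 of the sorted input.
    // Throws ArgumentOutOfRangeException for an empty array.
    public static int MiddleOfSorted(int[] items)
    {
        int count = items.Length;            // known up front, no enumeration and no sort
        var sorted = items.OrderBy(x => x);  // deferred: nothing has been sorted yet
        return sorted.ElementAt(count / 2);  // the sort runs here, exactly once
    }
}
```

For the array from the question, CountFirstDemo.MiddleOfSorted(new[] { 2, 4, 6, 1, 9 }) returns 4: the sorted order is 1, 2, 4, 6, 9 and index 5/2 = 2 holds 4.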


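UPDATE: You can call this extension method on any IOrderedEnumerable<T>. It uses reflection to read the non-public source field of OrderedEnumerable<T>, which holds the source sequence. It then checks whether that sequence is a collection and, if so, returns its Count without executing the ordering. The field name is an implementation detail, so this may break on other runtimes or framework versions; in that case the method falls back to a normal count: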

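using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public static class Extensions
{
    public static int Count<T>(this IOrderedEnumerable<T> ordered)
    {
        // "source" is a non-public field of the framework's OrderedEnumerable<T>;
        // if it is not found, field is null and we fall through to a normal count
        var flags = BindingFlags.NonPublic | BindingFlags.Instance;
        FieldInfo field = ordered.GetType().GetField("source", flags);
        var source = field?.GetValue(ordered) as ICollection<T>;
        if (source != null)
            return source.Count;

        // Call Enumerable.Count explicitly: ordered.Count() would bind to
        // this extension method again and recurse until the stack overflows
        return Enumerable.Count(ordered);
    }
}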

Usage:

var source = new[]{ 2, 4, 6, 1, 9 }.OrderBy(x => x);
int count = source.Count();


Answer 2:

If you're looking to create a performant solution, I'd consider creating overloads that take either a collection or an IOrderedEnumerable, etc. All that "is" and "as" type checking and casting can't be good for the kind of thing you are creating.

You are reinventing the wheel. LINQ's Count() method does pretty much what you want.

Also, add the this keyword and make this into a nifty extension method, to please yourself and others using the code.

DoSomething<T>(this ICollection<T> source);
DoSomething<T>(this List<T> source);
DoSomething<T>(this IOrderedEnumerable<T> source);

etc...
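The overload idea could look roughly like the sketch below. MiddleExtensions and Middle are hypothetical names, and returning the element at index n/2 is borrowed from the question rather than from this answer. The compiler picks the ICollection<T> overload for arrays and lists, and the IEnumerable<T> overload for everything else, so no runtime type checks are needed.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class MiddleExtensions
{
    // Chosen for arrays, List<T> and other collections: Count is already known.
    public static T Middle<T>(this ICollection<T> source)
    {
        return source.ElementAt(source.Count / 2);
    }

    // Fallback for any other sequence, including IOrderedEnumerable<T>:
    // materialize once so the source is enumerated a single time.
    public static T Middle<T>(this IEnumerable<T> source)
    {
        List<T> list = source.ToList();
        return list[list.Count / 2];
    }
}
```

With this in place, new[] { 2, 4, 6, 1, 9 }.Middle() binds to the collection overload and returns 6, while new[] { 2, 4, 6, 1, 9 }.OrderBy(x => x).Middle() binds to the fallback and returns 4. Both overloads throw on an empty sequence; the sketch does not handle that case.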



Answer 3:

Another approach is to write a class that implements IOrderedEnumerable<T>. You can then add instance members that short-circuit the usual LINQ extension methods, and provide a Count method that looks at the original enumeration.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

public class MyOrderedEnumerable<T> : IOrderedEnumerable<T>
{
    private readonly IEnumerable<T> Original;
    private IOrderedEnumerable<T> Sorted;

    public MyOrderedEnumerable(IEnumerable<T> orig)
    {
        Original = orig;
        Sorted = null;
    }

    // Starts a new primary ordering; any previous ordering is replaced
    private void ApplyOrder<TKey>(Func<T, TKey> keySelector, IComparer<TKey> comparer, bool descending)
    {
        IEnumerable<T> before = Sorted != null ? Sorted : Original;
        if (descending)
            Sorted = before.OrderByDescending(keySelector, comparer);
        else
            Sorted = before.OrderBy(keySelector, comparer);
    }

    #region Interface Implementations

    public IEnumerator<T> GetEnumerator()
    {
        return Sorted != null ? Sorted.GetEnumerator() : Original.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    // Called by ThenBy/ThenByDescending, so an existing ordering must be kept
    // as the primary sort instead of being discarded
    public IOrderedEnumerable<T> CreateOrderedEnumerable<TKey>(
        Func<T, TKey> keySelector,
        IComparer<TKey> comparer,
        bool descending)
    {
        var newSorted = new MyOrderedEnumerable<T>(Original);
        if (Sorted != null)
            newSorted.Sorted = Sorted.CreateOrderedEnumerable(keySelector, comparer, descending);
        else
            newSorted.ApplyOrder(keySelector, comparer, descending);
        return newSorted;
    }

    #endregion Interface Implementations


    // Ensure that OrderBy returns the right type.
    // There are other variants of the OrderBy extension methods you'll have to short-circuit
    public MyOrderedEnumerable<T> OrderBy<TKey>(Func<T, TKey> keySelector)
    {
        Console.WriteLine("Ordering");
        var newSorted = new MyOrderedEnumerable<T>(Original);
        newSorted.Sorted = (Sorted != null ? Sorted : Original).OrderBy(keySelector);
        return newSorted;
    }

    // Instance method, so it takes precedence over the Enumerable.Count extension
    public int Count()
    {
        Console.WriteLine("Fast counting..");
        var collection = Original as ICollection<T>;
        return collection == null ? Original.Count() : collection.Count;
    }

    public static void Test()
    {
        var nums = new MyOrderedEnumerable<int>(Enumerable.Range(0, 10).ToList());
        var nums2 = nums.OrderBy(x => -x);
        var z = nums.Count() + nums2.Count(); // 10 + 10; neither call sorts anything
    }
}