Self Training Algorithm

Posted 2019-03-11 08:48

Question:

I'd like to develop a self-training algorithm for a specific problem. To keep things simple, I'll boil it down to a simple example.

Update: I have added a working solution as an answer to this question below.

Let's say I have a huge list of entities coming from a database. All entities are of the same type, and each has 4 properties of type byte.

public class Entity
{
    public byte Prop1 { get; set; }
    public byte Prop2 { get; set; }
    public byte Prop3 { get; set; }
    public byte Prop4 { get; set; }
}

Now I'd like to dynamically test one or more properties of each entity against a simple condition. This basically means that I want to test every possible combination of the properties against this condition.

To get this done I have created a bit mask for the properties.

[Flags]
public enum EEntityValues
{
    Undefined = 0,
    Prop1 = 1,
    Prop2 = 2,
    Prop3 = 4,
    Prop4 = 8,
}

I also added a method to get the maximum value of the bit mask, which returns 15 (1 + 2 + 4 + 8) for this example.

public static int GetMaxValue<T>() where T : struct
{
    return Enum.GetValues( typeof(T) ).Cast<int>().Sum();
}

At this stage I'm able to iterate over all property combinations with a simple loop. For example, the first iteration tests Prop1, the second iteration tests Prop2, the third iteration tests Prop1 and Prop2, and so on.

for(int i = 1; i <= GetMaxValue<EEntityValues>(); i++)
{
     EEntityValues flags = (EEntityValues)i;

     if(flags.HasSet(EEntityValues.Prop1))
     {
         ....
     }
}
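
HasSet is not a framework method, and its definition is not shown in the question. A minimal sketch, assuming it simply checks that every bit of the given flag is set in the mask (the same semantics as the built-in Enum.HasFlag), could look like this. The class names EEntityValuesExtensions and HasSetDemo are assumptions made for illustration.

```csharp
using System;

[Flags]
public enum EEntityValues
{
    Undefined = 0,
    Prop1 = 1,
    Prop2 = 2,
    Prop3 = 4,
    Prop4 = 8,
}

// Hypothetical helper: the question calls HasSet without showing its definition.
public static class EEntityValuesExtensions
{
    // True when every bit of 'flag' is also set in 'mask'.
    public static bool HasSet(this EEntityValues mask, EEntityValues flag)
    {
        return (mask & flag) == flag;
    }
}

public static class HasSetDemo
{
    public static void Main()
    {
        EEntityValues flags = (EEntityValues)3; // Prop1 | Prop2

        Console.WriteLine(flags.HasSet(EEntityValues.Prop1)); // True
        Console.WriteLine(flags.HasSet(EEntityValues.Prop3)); // False
    }
}
```

If the real HasSet behaves differently, the loops in this question would need to be adapted accordingly.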

Now let's bring the entities into play.

List<Entity> entities = GetEntitiesFromDb();

for(int i = 1; i <= GetMaxValue<EEntityValues>(); i++)
{
     EEntityValues flags = (EEntityValues)i;
     byte minProp1Value = 10, minProp2Value = 20, minProp3Value = 30, minProp4Value = 40;

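     foreach(Entity entity in entities)
     {
         // Skip the entity as soon as one of the selected properties is below its minimum.
         // A property whose flag is not set must not cause the entity to be skipped.
         if(flags.HasSet(EEntityValues.Prop1))
         {
              if(entity.Prop1 < minProp1Value) { continue; }
              ....
         }

         if(flags.HasSet(EEntityValues.Prop2))
         {
              if(entity.Prop2 < minProp2Value) { continue; }
              ....
         }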
     }
}

Well, this works great if my minimum values are static.

Now let's get more complicated. As we remember, on the first iteration we are testing property Prop1 only, because the bit mask is 1. The value range for Prop1 is 0..255. I also defined a minimum value for this property which has a valid range of 1..255. This minimum value is just a filter to skip entities in the foreach loop.

Now I'd like to test the property Prop1 while raising the minimum level. The tests themselves are not part of the problem, so I don't include them in my samples.

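     // The counter is an int: a byte can never exceed 255, so with a byte
     // the loop condition would always be true and the loop would never end.
     int minProp1Value = 1;

     while(minProp1Value <= 255)
     {
         foreach(Entity entity in entities)
         {
              if(flags.HasSet(EEntityValues.Prop1) && entity.Prop1 >= minProp1Value)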
              {
                  .... // Testing the matching entity and storing the result
              } else { continue; }
         }

         minProp1Value++;
     }

This is simple for a single property. On the third iteration I have to deal with 2 properties, Prop1 and Prop2, because the bit mask is 3.

     int minProp1Value = 1, minProp2Value = 1; // int rather than byte, so both loops can terminate after 255

     while(minProp1Value <= 255)
     {
         while(minProp2Value <= 255)
         {
              foreach(Entity entity in entities)
              {
                   ....
              }

              minProp2Value++;
         }

         minProp1Value++;
         minProp2Value = 1;
     }

As you can see, at this stage I'm testing Prop1 and Prop2 of each entity against a rising minimum level.

Because I'm dealing with dynamically generated sets of multiple properties, I can't hardcode the while loops. That's why I'm looking for a smarter solution to test all possible combinations of minimum values for the given property set (bit mask).

Answer 1:

After taking a break I came up with a solution that seems to fit my requirements. The limitation is that all tested properties must be of the same type and have the same value range, which is fine in my case because all properties are abstract percentage values.

By the way, I'm not sure whether the title "self training algorithm" is a little misleading here. There are several ways to implement such a solution, but if you have no idea how your data behaves and what effect the values have, the simplest solution is to brute-force all possible combinations to identify the best-fitting result. That's what I'm showing here.

Anyway, for testing purposes I added a random number generator to my entity class.

public class Entity
{
    public byte Prop1 { get; set; }
    public byte Prop2 { get; set; }
    public byte Prop3 { get; set; }
    public byte Prop4 { get; set; }

    public Entity()
    {
        Random random = new Random( Guid.NewGuid().GetHashCode() );
        byte[] bytes = new byte[ 4 ];

        random.NextBytes( bytes );

        this.Prop1 = bytes[0];
        this.Prop2 = bytes[1];
        this.Prop3 = bytes[2];
        this.Prop4 = bytes[3];
    }
}

My bit mask stays essentially untouched; only the enum is now named EProperty.

[Flags]
public enum EProperty
{
    Undefined = 0,
    Prop1 = 1,
    Prop2 = 1 << 1,
    Prop3 = 1 << 2,
    Prop4 = 1 << 3
}

Then I added some new extension methods to deal with my bit mask.

public static class BitMask
{
    public static int GetMaxValue<T>() where T : struct
    {
        return Enum.GetValues(typeof (T)).Cast<int>().Sum();
    }

    public static int GetTotalCount<T>() where T : struct
    {
        return Enum.GetValues(typeof (T)).Cast<int>().Count(e => e > 0);
    }

    public static int GetFlagCount<T>(this T mask) where T : struct
    {
        int result = 0, value = (int) (object) mask;

        while (value != 0)
        {
            value = value & (value - 1);
            result++;
        }

        return result;
    }

    public static IEnumerable<T> Split<T>(this T mask)
    {
        int maskValue = (int) (object) mask;

        foreach (T flag in Enum.GetValues(typeof (T)))
        {
            int flagValue = (int) (object) flag;

            if (0 != (flagValue & maskValue))
            {
                yield return flag;
            }
        }
    }
}
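
To make the effect of Split more concrete, here is a small sketch that prints which properties each mask value selects. It uses a simplified, non-generic variant of Split written only for this illustration; the names BitMaskSketch and SplitDemo are assumptions, not part of the original solution.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

[Flags]
public enum EProperty
{
    Undefined = 0,
    Prop1 = 1,
    Prop2 = 1 << 1,
    Prop3 = 1 << 2,
    Prop4 = 1 << 3
}

// Hypothetical, non-generic variant of Split used only for this demo.
public static class BitMaskSketch
{
    // Yields every defined (non-zero) flag that is contained in the mask.
    public static IEnumerable<EProperty> Split(this EProperty mask)
    {
        foreach (EProperty flag in Enum.GetValues(typeof(EProperty)))
        {
            if (flag != EProperty.Undefined && (mask & flag) == flag)
            {
                yield return flag;
            }
        }
    }
}

public static class SplitDemo
{
    public static void Main()
    {
        // Masks 1..15 cover every non-empty subset of the four properties.
        for (int maskValue = 1; maskValue <= 15; maskValue++)
        {
            EProperty mask = (EProperty)maskValue;
            Console.WriteLine(maskValue + ": " + string.Join(", ", mask.Split()));
        }
    }
}
```

With this sketch, mask 3 prints "3: Prop1, Prop2" and mask 5 prints "5: Prop1, Prop3".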

Then I wrote a query builder.

public static class QueryBuilder
{
    public static IEnumerable<Entity> Where(this IEnumerable<Entity> entities, EProperty[] properties, int[] values)
    {
        IEnumerable<Entity> result = entities.Select(e => e);

        for (int index = 0; index <= properties.Length - 1; index++)
        {
            EProperty property = properties[index];
            int value = values[index];

            switch (property)
            {
                case EProperty.Prop1:
                    result = result.Where(e => Math.Abs(e.Prop1) >= value);
                    break;
                case EProperty.Prop2:
                    result = result.Where(e => Math.Abs(e.Prop2) >= value);
                    break;
                case EProperty.Prop3:
                    result = result.Where(e => Math.Abs(e.Prop3) >= value);
                    break;              
                case EProperty.Prop4:
                    result = result.Where(e => Math.Abs(e.Prop4) >= value);
                    break;   
            }
        }

        return result;
    }
}
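
As a possible alternative to the switch statement, the mapping from flag to property could be kept in a lookup table. The following is only a sketch under that assumption: the names FilterSketch, Selectors and AtLeast are hypothetical, and the entities are fixed values rather than random ones so the result is predictable.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

[Flags]
public enum EProperty { Undefined = 0, Prop1 = 1, Prop2 = 2, Prop3 = 4, Prop4 = 8 }

public class Entity
{
    public byte Prop1 { get; set; }
    public byte Prop2 { get; set; }
    public byte Prop3 { get; set; }
    public byte Prop4 { get; set; }
}

public static class FilterSketch
{
    // Hypothetical lookup table: maps each flag to the property it reads.
    private static readonly Dictionary<EProperty, Func<Entity, byte>> Selectors =
        new Dictionary<EProperty, Func<Entity, byte>>
        {
            { EProperty.Prop1, e => e.Prop1 },
            { EProperty.Prop2, e => e.Prop2 },
            { EProperty.Prop3, e => e.Prop3 },
            { EProperty.Prop4, e => e.Prop4 },
        };

    // Keeps the entities whose selected properties all reach their minimum value.
    public static IEnumerable<Entity> AtLeast(
        this IEnumerable<Entity> entities, EProperty[] properties, int[] values)
    {
        return entities.Where(e =>
            properties.Select((p, i) => Selectors[p](e) >= values[i]).All(ok => ok));
    }
}

public static class FilterDemo
{
    public static void Main()
    {
        var entities = new List<Entity>
        {
            new Entity { Prop1 = 10, Prop2 = 50 },
            new Entity { Prop1 = 60, Prop2 = 20 },
            new Entity { Prop1 = 70, Prop2 = 80 },
        };

        var properties = new[] { EProperty.Prop1, EProperty.Prop2 };
        var minimums = new[] { 50, 40 };

        // Only the third entity has Prop1 >= 50 and Prop2 >= 40.
        Console.WriteLine(entities.AtLeast(properties, minimums).Count()); // 1
    }
}
```

The table keeps the filter free of per-property branches, at the cost of one dictionary lookup per property and entity.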

And finally I'm ready to run the training.

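// The original snippet omits the class declaration; the name Program is assumed.
public static class Program
{
    private const int maxThreads = 10;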

    private const int minValue = 0;
    private const int maxValue = 100;

    private static IEnumerable<Entity> entities;

    public static void Main(string[] args)
    {
        Console.WriteLine(DateTime.Now.ToLongTimeString());

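        // Enumerable.Repeat(new Entity(), 10) would add the same instance ten times,
        // so ten separate entities are created instead.
        entities = Enumerable.Range(0, 10).Select(_ => new Entity()).ToList();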

        Action<EProperty[], int[]> testCase = RunTestCase;
        RunSelfTraining( testCase );

        Console.WriteLine(DateTime.Now.ToLongTimeString());
        Console.WriteLine("Done.");

        Console.Read();
    }

    private static void RunTestCase( EProperty[] properties, int[] values ) 
    {         
        foreach( Entity entity in entities.Where( properties, values ) )
        {

        }
    }

    private static void RunSelfTraining<T>( Action<T[], int[]> testCase ) where T : struct
    {
        ParallelOptions parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = maxThreads };

        for (int maskValue = 1; maskValue <= BitMask.GetMaxValue<T>(); maskValue++)
        {
            T mask = ( T ) (object)maskValue;
            T[] properties = mask.Split().ToArray();         

            int variations = (int) Math.Pow(maxValue - minValue + 1, properties.Length);

            // Parallel.For excludes its upper bound, so add 1 to include the last variation.
            Parallel.For(1, variations + 1, parallelOptions, variation =>
            {
                int[] values = GetVariation(variation, minValue, maxValue, properties.Length).ToArray();   
                testCase.Invoke(properties, values);        
            } );
        }
    }

    public static IEnumerable<int> GetVariation( int index, int minValue, int maxValue, int count )
    {
        index = index - 1; 
        int range = maxValue - minValue + 1;

        for( int j = 0; j < count; j++ )
        {
            yield return index % range + minValue;
            index = index / range;
        }
    }
}
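
To illustrate how GetVariation replaces the nested while loops, the sketch below decodes every index for a deliberately small case: two properties with minimum values from 0 to 2. The 1-based index is read as a number in base "range", least significant digit first. The class name VariationDemo is an assumption; the method body is the same as above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class VariationDemo
{
    // Same decoding as GetVariation above: the 1-based index is interpreted
    // as a number in base 'range', and its digits become the minimum values.
    public static IEnumerable<int> GetVariation(int index, int minValue, int maxValue, int count)
    {
        index = index - 1;
        int range = maxValue - minValue + 1;

        for (int j = 0; j < count; j++)
        {
            yield return index % range + minValue;
            index = index / range;
        }
    }

    public static void Main()
    {
        // Two properties with minimums 0..2 give 3 * 3 = 9 variations.
        for (int variation = 1; variation <= 9; variation++)
        {
            Console.WriteLine(string.Join(", ", GetVariation(variation, 0, 2, 2)));
        }
    }
}
```

The first value changes fastest: the output starts with "0, 0", "1, 0", "2, 0", "0, 1" and ends with "2, 2" for index 9. Since Parallel.For treats its upper bound as exclusive, index 9 is only reached if the loop's upper bound is 10.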