How to optimize/simplify heapsorting Django objects

Published 2019-07-07 02:44

Question:

I need some help with an assignment I was given as a test for a Django internship. I had to build an imaginary API with rabbits and their carrots. Each rabbit was supposed to have a number of carrots, but the API had to be designed so that other kinds of vegetables could be added easily. I rejected an integer field per vegetable and instead went for a Vegetable object with a type and a value.

The problem is that the assignment also required listing the rabbits sorted by carrots, descending. They wanted me to implement heapsort; no database sorting and no external libraries were allowed. While I had no problem with that, I am having trouble with the time constraints they set: 20,000 rabbits must be sorted in under 30 seconds, ideally 5 seconds. It already takes 5 seconds with 200 rabbits (just sorting and serializing to JSON).

I make a queryset that contains only rabbits with "carrots" vegetables. Then I force it into a normal list and run the heapsort function on it.

How would I need to change it to be faster? Is it even possible? I would be very happy if someone could help even a bit. Thank you in advance!

My models:

class Bunny(models.Model):
    """Bunny model for bunny usage"""
    def __str__(self):
        return self.name + " " + str(list(self.vegetables.all()))

    name = models.CharField("Name", max_length=50)
    userAccount = models.ForeignKey(User, on_delete=models.CASCADE)

    def getVegetable(self, vegetableType):
        for vegetable in self.vegetables.all():
            if vegetable.vegetableType == vegetableType:
                return vegetable
        return False


class Vegetable(models.Model):
    """Vegetable model for storing vegetable counts"""
    def __str__(self):
        return self.vegetableType + ":" + str(self.value)

    vegetableType = models.CharField(max_length=30, choices=vegetableChoices)
    value = models.PositiveIntegerField(default=0, validators=[MinValueValidator(0)])
    bunny = models.ForeignKey(Bunny, related_name="vegetables", on_delete=models.CASCADE)

My heapsort function:

def heapsort(bunnies, vegetableType):
    """Heapsort function for bunnies, works in place, descending"""

    for start in range((len(bunnies) - 2) // 2, -1, -1):
        siftdown(bunnies, start, len(bunnies) - 1, vegetableType)

    for end in range(len(bunnies) - 1, 0, -1):
        bunnies[end], bunnies[0] = bunnies[0], bunnies[end]
        siftdown(bunnies, 0, end - 1, vegetableType)
    return bunnies


def siftdown(bunnies, start, end, vegetableType):
    """helper function for heapsort"""
    root = start
    while True:
        child = root * 2 + 1
        if child > end:
            break
        if child + 1 <= end and bunnies[child].vegetables.get(vegetableType=vegetableType).value > bunnies[
                    child + 1].vegetables.get(vegetableType=vegetableType).value:
            child += 1
        if bunnies[root].vegetables.get(vegetableType=vegetableType).value > bunnies[child].vegetables.get(
                vegetableType=vegetableType).value:
            bunnies[root], bunnies[child] = bunnies[child], bunnies[root]
            root = child
        else:
            break

And also the performance test they asked for (I do not know of a better way; just creating the bunnies takes a long time):

def test_20000_rabbits_performance(self):
    print("Creating bunnies")
    register20000Bunnies()

    print("Created bunnies")
    timestart = time()

    url = reverse("api:list", args=["carrots"])

    response = self.client.get(url)
    timeMeasured = time() - timestart
    print("Sorted. Took: " + str(timeMeasured))

    self.assertEqual(response.status_code, status.HTTP_200_OK)

My view:

@api_view(["GET"])
def bunnyList(request, vegetableType):
    """Displays heap-sorted list of bunnies, in decreasing order.
    Takes word after list ("/list/xxx") as argument to determine
    which vegetable list to display"""
    if vegetableType in vegetablesChoices:
        bunnies = Bunny.objects.filter(vegetables__vegetableType=vegetableType)
        bunnies = list(bunnies)  # force into normal list

        if len(bunnies) == 0:
            return Response({"No bunnies": "there is %d bunnies with this vegetable" % len(bunnies)},
                        status=status.HTTP_204_NO_CONTENT)

        heapsort(bunnies, vegetableType)
        serialized = BunnySerializerPartial(bunnies, many=True)
        return Response(serialized.data, status=status.HTTP_200_OK)
    else:
        raise serializers.ValidationError("No such vegetable. Available are: " + ", ".join(vegetablesChoices))

Edit: just checked now, currently it takes 1202 seconds to sort... My machine is 2 core 1.86GHz, but still.

Edit2, new code:

@api_view(["GET"])
def bunnyList(request, vegetableType):
    """Displays heap-sorted list of bunnies, in decreasing order.
    Takes word after list ("/list/xxx") as argument to determine
    which vegetable list to display"""
    if vegetableType in vegetablesChoices:
        vegetables = Vegetable.objects.filter(vegetableType=vegetableType).select_related('bunny')
        vegetables = list(vegetables)

        if len(vegetables) == 0:
            return Response({"No bunnies": "there is 0 bunnies with this vegetable"},
                            status=status.HTTP_204_NO_CONTENT)

        heapsort(vegetables)

        bunnies = [vegetable.bunny for vegetable in vegetables]
        serialized = BunnySerializerPartial(bunnies, many=True)
        return Response(serialized.data, status=status.HTTP_200_OK)
    else:
        raise serializers.ValidationError("No such vegetable. Available are: " + ", ".join(vegetablesChoices))

Updated heapsort:

def heapsort(vegetables):
    """Heapsort function for vegetables, works in place, descending"""

    for start in range((len(vegetables) - 2) // 2, -1, -1):
        siftdown(vegetables, start, len(vegetables) - 1)

    for end in range(len(vegetables) - 1, 0, -1):
        vegetables[end], vegetables[0] = vegetables[0], vegetables[end]
        siftdown(vegetables, 0, end - 1)
    return vegetables


def siftdown(vegetables, start, end):
    """helper function for heapsort"""
    root = start
    while True:
        child = root * 2 + 1
        if child > end:
            break
        if child + 1 <= end and vegetables[child].value > vegetables[child + 1].value:
            child += 1
        if vegetables[root].value > vegetables[child].value:
            vegetables[root], vegetables[child] = vegetables[child], vegetables[root]
            root = child
        else:
            break

My serializers:

class VegetableSerializer(serializers.ModelSerializer):
    """Used for displaying vegetables, for example in list view"""
    class Meta:
        model = Vegetable
        fields = ("vegetableType", "value")


class BunnySerializerPartial(serializers.ModelSerializer):
    """Used in list view, mirrors BunnySerializerFull but without account details"""
    vegetables = VegetableSerializer(many=True)

    class Meta:
        model = Bunny
        fields = ("name", "vegetables")

And queries from the toolbar:

SELECT ••• FROM "zajaczkowskiBoardApi_vegetable" INNER JOIN "zajaczkowskiBoardApi_bunny" ON ("zajaczkowskiBoardApi_vegetable"."bunny_id" = "zajaczkowskiBoardApi_bunny"."id") WHERE "zajaczkowskiBoardApi_vegetable"."vegetableType" = '''carrots'''


SELECT ••• FROM "zajaczkowskiBoardApi_vegetable" WHERE "zajaczkowskiBoardApi_vegetable"."bunny_id" = '141'

The second one is duplicated 20,000 times.

Answer 1:

This is the classic N+1 queries problem. You perform a single query to fetch all the bunnies, but then you go on to do bunnies[child].vegetables.get(vegetableType=vegetableType) for each bunny, which performs an additional query, and thus an additional database roundtrip, for each bunny. So you perform 1 query for N bunnies, plus around N queries to get all the vegetables (hence the N+1).

Database roundtrips are one of the most expensive resources available to web developers. While comparisons take on the order of nanoseconds, a database roundtrip takes on the order of milliseconds. Do ~20K queries and this will soon add up to several minutes.
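To get a feel for how many roundtrips the original heapsort triggers, here is a rough sketch that runs the same sift-down logic on plain integers and counts every value lookup. It assumes each lookup stands in for one `.vegetables.get(...)` call, i.e. one query; `count_key_lookups` and the random test data are made up for illustration. Because heapsort compares each element several times, the count grows faster than the number of bunnies.

```python
import random


def count_key_lookups(n, seed=0):
    """Run the question's heapsort on n random integers, counting value lookups."""
    rng = random.Random(seed)
    values = [rng.randrange(1000) for _ in range(n)]
    lookups = 0

    def value(i):
        # Stands in for bunnies[i].vegetables.get(...).value (assumed: one query each).
        nonlocal lookups
        lookups += 1
        return values[i]

    def siftdown(start, end):
        root = start
        while True:
            child = root * 2 + 1
            if child > end:
                break
            # Pick the smaller child (min-heap), as in the question's code.
            if child + 1 <= end and value(child) > value(child + 1):
                child += 1
            if value(root) > value(child):
                values[root], values[child] = values[child], values[root]
                root = child
            else:
                break

    # Build the heap, then repeatedly move the minimum to the end.
    for start in range((n - 2) // 2, -1, -1):
        siftdown(start, n - 1)
    for end in range(n - 1, 0, -1):
        values[end], values[0] = values[0], values[end]
        siftdown(0, end - 1)

    assert values == sorted(values, reverse=True)  # sanity check: descending
    return lookups


if __name__ == "__main__":
    for n in (200, 20000):
        print(n, "items ->", count_key_lookups(n), "lookups")
```

Multiplying the printed lookup count by a per-query latency measured on your own machine gives a ballpark for the total sorting time.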

The quick solution is to use prefetch_related('vegetables') and exclusively use bunny.getVegetable('carrots') to get the carrots. prefetch_related() will perform a single query to get all the vegetables for all bunnies and cache them, so iterating self.vegetables.all() in getVegetable() won't perform any additional queries.


There are better solutions, though. In this case, it seems that each bunny should have at most one Vegetable object of a specific vegetableType. If you enforce this at the database level, you won't have to worry about errors in your sorting algorithm when someone decides to add a second Vegetable of type 'carrots' to a bunny. Instead, the database will stop them from doing that in the first place. To do this, you need a unique_together constraint:

class Vegetable(models.Model):
    ...
    class Meta:
        unique_together = [
            ('vegetableType', 'bunny'),
        ]

Then, rather than fetching all bunnies and prefetching all related vegetables, you can fetch all vegetables of type "carrots" and join the related bunnies. Now you will only have a single query:

carrots = Vegetable.objects.filter(vegetableType='carrots').select_related('bunny')

Since the combination of vegetableType and bunny is unique, you won't get any duplicate bunnies, and you will still get all bunnies that have some carrots.

Of course you'd have to adapt your algorithm to work with the vegetables rather than the bunnies.
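One possible way to adapt it, as a minimal sketch rather than the required solution: a heapsort that takes a `key` function and reads each key exactly once up front, so the sort itself never touches the ORM. The name `heapsort_desc` and the `Carrot` namedtuple (a stand-in for `Vegetable` rows fetched with `select_related('bunny')`) are assumptions for illustration only.

```python
from collections import namedtuple


def heapsort_desc(items, key):
    """Sort items in place, largest key first, using a min-heap."""
    # Read every key once; the sort then works on plain Python values only.
    keys = [key(item) for item in items]

    def swap(i, j):
        keys[i], keys[j] = keys[j], keys[i]
        items[i], items[j] = items[j], items[i]

    def siftdown(root, end):
        while True:
            child = 2 * root + 1
            if child > end:
                return
            # Choose the smaller of the two children.
            if child + 1 <= end and keys[child + 1] < keys[child]:
                child += 1
            if keys[root] > keys[child]:
                swap(root, child)
                root = child
            else:
                return

    last = len(items) - 1
    # Build a min-heap, then move the minimum to the end repeatedly.
    for start in range((last - 1) // 2, -1, -1):
        siftdown(start, last)
    for end in range(last, 0, -1):
        swap(0, end)
        siftdown(0, end - 1)
    return items


# Hypothetical stand-in for Vegetable rows with their related bunny.
Carrot = namedtuple("Carrot", ["bunny", "value"])

if __name__ == "__main__":
    rows = [Carrot("Fluffy", 3), Carrot("Thumper", 10), Carrot("Hazel", 7)]
    heapsort_desc(rows, key=lambda row: row.value)
    print([row.bunny for row in rows])  # ['Thumper', 'Hazel', 'Fluffy']
```

In the view, the list to sort would be the `carrots` queryset forced into a list, with `key=lambda v: v.value`, and the bunnies can then be read off the sorted rows in order.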