Implementing Text Justification with Dynamic Programming

Posted 2019-01-22 14:07

Question:

I'm trying to understand the concept of Dynamic Programming via the course on MIT OCW here. The explanation in the OCW video is great, but I feel like I don't really understand it until I implement it in code. While implementing, I refer to the lecture notes here, particularly page 3 of the notes.

The problem is, I have no idea how to translate some of the mathematical notation into code. Here's part of the solution I've implemented (and think is implemented right):

import math

paragraph = "Some long lorem ipsum text."
words = paragraph.split()

# Count the total length of a list of strings, including the single
# spaces between them. Used by the badness function below.
def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) - 1 # spaces: one per gap between words
    return total

# Calculate the badness score for a line (a slice of words).
# str_arr is assumed to be passed as words[i:j], as in the notes.
# i and j are not parameters because that would require
# global variables.
def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('inf') # a line that does not fit is infinitely bad
    else:
        return math.pow(page_width - line_len, 3)

Now the part I don't understand is points 3 to 5 in the lecture notes. I really don't understand them and don't know where to start implementing them. So far, I've tried iterating over the list of words and computing the badness of each possible end of line, like this:

def justifier(str_arr, page_width):
    paragraph = str_arr
    par_len = len(paragraph)
    result = [] # stores each line as list of strings
    for i in range(0, par_len):
        if i == (par_len - 1):
            result.append(paragraph)
        else:
            dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 
            # Should I do a min(dag), get the index, and declares it as end of line?

But then, I don't know how I can continue the function, and to be honest, I don't understand this line:

dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 

I also don't see how justifier can return an int, since I already decided to store the return value in result, which is a list. Should I make another function and recurse from there? Should there be any recursion at all?

Could you please show me what to do next, and explain how this is dynamic programming? I really can't see where the recursion is, and what the subproblem is.

Thanks in advance.

Answer 1:

In case you have trouble understanding the core idea of dynamic programming itself, here is my take on it:

Dynamic programming essentially trades space for time: you store the result of each recursive call as you go (e.g. in an array or a dictionary), so that when the same call comes up again in another branch of the recursion tree, you look the value up instead of computing it a second time. The extra space is usually small compared to the time saved, which makes dynamic programming well worth it when implemented correctly.
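As a minimal sketch of that idea applied to this problem (top-down, with recursion), assuming the lecture's cubed-slack badness with infinite badness for lines that do not fit: `functools.lru_cache` plays the role of the dictionary, and the names `justify_cost`, `line_badness` and `dp` are hypothetical, not taken from the lecture notes.

```python
from functools import lru_cache


def justify_cost(words, width):
    """Minimum total badness of splitting `words` into lines of at most `width` chars."""
    n = len(words)

    def line_badness(i, j):
        # Width of words[i:j] joined by single spaces.
        line_len = sum(len(w) for w in words[i:j]) + (j - i - 1)
        if line_len > width:
            return float("inf")  # the line does not fit
        return (width - line_len) ** 3

    @lru_cache(maxsize=None)
    def dp(i):
        # dp(i): best cost for the suffix words[i:]. Each value is computed once;
        # when another branch asks for the same suffix, it comes from the cache.
        if i == n:
            return 0  # no words left
        return min(line_badness(i, j) + dp(j) for j in range(i + 1, n + 1))

    return dp(0)


print(justify_cost(["aaa", "bb", "cc", "ddddd"], 6))
```

This prints 29. The recursion depth grows with the number of words, so for long paragraphs Python's default recursion limit can become a problem, which is one reason to prefer a loop-based version.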

And no, you do not have to use recursion. Here is my implementation of the problem you were working on, using only loops. I followed the TextAlignment.pdf linked by AlexSilva very closely. Hopefully you find this helpful.

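# words are numbered 1..n, as in the notes
def length(wordLengths, i, j):
    # length of a line holding words i..j: letters plus j - i single spaces
    return sum(wordLengths[i - 1:j]) + j - i

def breakLine(text, L):
    # wl = lengths of words
    wl = [len(word) for word in text.split()]

    # n = number of words in the text
    n = len(wl)

    # m[i] = minimum total badness of a text consisting of words 1 ... i
    m = dict()
    # initialization: the empty text has badness 0
    m[0] = 0

    # auxiliary array: s[i] = first word of the line that ends at word i
    s = dict()

    # the actual algorithm
    for i in range(1, n + 1):
        sums = dict()
        k = i
        # try every k such that words k..i still fit on one line
        while k > 0 and length(wl, k, i) <= L:
            sums[(L - length(wl, k, i))**3 + m[k - 1]] = k
            k -= 1
        m[i] = min(sums)
        s[i] = sums[min(sums)]

    # actually do the splitting by working backwards
    # (lines are printed from the last one to the first one)
    line = 1
    while n > 0:
        print("line " + str(line) + ": " + str(s[n]) + "->" + str(n))
        n = s[n] - 1
        line += 1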



Answer 2:

For anyone else still interested in this: the key is to move backwards from the end of the text (as mentioned here). If you do so, you only ever compare values that have already been memoized.

Say words is a list of strings to be wrapped to textwidth. Then, in the notation of the lecture, the task reduces to three lines of code:

import numpy as np

textwidth = 80

DP = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    DP[i] = np.min([DP[j] + badness(words[i:j],textwidth) for j in range(i+1,len(words)+1)])

With:

def badness(line,textwidth):

    # Number of gaps
    length_line = len(line) - 1

    for word in line:
        length_line += len(word)

    if length_line > textwidth: return float('inf')

    return ( textwidth - length_line )**3
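To see the backwards loop produce actual numbers, here is a small self-contained trace on a hypothetical four-word input with width 6. It uses the built-in `min` instead of `np.min` (otherwise the same loop and the same cubed badness as above) and prints DP[i] for every suffix.

```python
def badness(line, textwidth):
    # Characters used: word lengths plus one space per gap.
    length_line = len(line) - 1 + sum(len(word) for word in line)
    if length_line > textwidth:
        return float("inf")
    return (textwidth - length_line) ** 3


words = ["aaa", "bb", "cc", "ddddd"]  # hypothetical toy input
textwidth = 6

DP = [0] * (len(words) + 1)  # DP[len(words)] = 0: nothing left to place
for i in range(len(words) - 1, -1, -1):
    # Every DP[j] with j > i is already filled in, so this is only a lookup.
    DP[i] = min(DP[j] + badness(words[i:j], textwidth)
                for j in range(i + 1, len(words) + 1))

for i, cost in enumerate(DP):
    print(i, words[i:], cost)
```

DP[0] is 29, from the lines `aaa` / `bb cc` / `ddddd` (27 + 1 + 1). Filling the first line as full as possible (`aaa bb`, badness 0) would leave `cc` alone on a line and cost 0 + 64 + 1 = 65, which is why every j has to be compared rather than just taking the longest line that fits.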

He mentions that one can add a second list to keep track of the break positions. You can do so by altering the code to:

DP = [0]*(len(words)+1)
breaks = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    temp = [DP[j] + badness(words[i:j],textwidth) for j in range(i+1,len(words)+1)]

    index = np.argmin(temp)

    # temp[0] corresponds to j = i + 1, so shift the index back to a word position
    breaks[i] = index + i + 1
    DP[i] = temp[index]

To recover the text, just use the list of breaking positions:

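def reconstruct_text(words,breaks):

    lines = []
    linebreaks = []

    i = 0
    while True:

        linebreaks.append(breaks[i])
        i = breaks[i]

        if i == len(words):
            # Sentinel: linebreaks[-1] == 0 makes the first slice start at word 0
            linebreaks.append(0)
            break

    # Skip the sentinel itself, which would otherwise produce an empty last line
    for i in range( len(linebreaks) - 1 ):
        lines.append( ' '.join( words[ linebreaks[i-1] : linebreaks[i] ] ).strip() )

    return lines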

Result: (text = reconstruct_text(words,breaks))

Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy
eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam
voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit
amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam
nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed
diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.

One might be tempted to also pad the lines with extra whitespace. This is pretty tricky (since one might come up with various aesthetic rules), but a naive attempt might be:

import re

def spacing(text,textwidth,maxspace=4):

    for i in range(len(text)):

        length_line = len(text[i])

        if length_line < textwidth:

            status_length = length_line
            whitespaces_remain = textwidth - status_length
            Nwhitespaces = text[i].count(' ')

            # If the line has no gaps (a single word), or the whitespace to add
            # per existing gap exceeds maxspace, don't do anything.
            if Nwhitespaces == 0 or whitespaces_remain/Nwhitespaces > maxspace-1:pass
            else:
                text[i] = text[i].replace(' ',' '*( 1 + int(whitespaces_remain/Nwhitespaces)) )
                status_length = len(text[i])

                # Periods have highest priority for whitespace insertion
                periods = text[i].split('.')

                # Can we add a whitespace behind each period?
                if len(periods) - 1 + status_length <= textwidth:
                    text[i] = '. '.join(periods).strip()

                status_length = len(text[i])
                whitespaces_remain = textwidth - status_length
                Nwords = len(text[i].split())
                Ngaps = Nwords - 1

                if whitespaces_remain != 0:factor = Ngaps / whitespaces_remain

                # List of whitespaces in line i
                gaps = re.findall(r'\s+', text[i])

                temp = text[i].split()
                for k in range(Ngaps):
                    temp[k] = ''.join([temp[k],gaps[k]])

                for j in range(whitespaces_remain):
                    if status_length >= textwidth:pass
                    else:
                        replace = temp[int(factor*j)]
                        replace = ''.join([replace, " "])
                        temp[int(factor*j)] = replace

                text[i] = ''.join(temp)

    return text

This gives you: (text = spacing(text,textwidth))

Lorem  ipsum  dolor  sit  amet, consetetur  sadipscing  elitr,  sed  diam nonumy
eirmod  tempor  invidunt  ut labore  et  dolore  magna aliquyam  erat,  sed diam
voluptua.   At  vero eos  et accusam  et justo  duo dolores  et ea  rebum.  Stet
clita  kasd  gubergren,  no  sea  takimata sanctus  est  Lorem  ipsum  dolor sit
amet.   Lorem  ipsum  dolor  sit amet,  consetetur  sadipscing  elitr,  sed diam
nonumy  eirmod  tempor invidunt  ut labore  et dolore  magna aliquyam  erat, sed
diam  voluptua.  At vero eos et accusam et  justo duo dolores et ea rebum.  Stet
clita  kasd gubergren, no sea  takimata sanctus est Lorem  ipsum dolor sit amet.
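If the goal is only to make every line exactly textwidth characters wide, a simpler alternative (a swapped-in technique, not the period-aware rules above) is to spread the missing spaces over the gaps as evenly as possible, giving the leftmost gaps one extra space when the division is not exact. This is a sketch; the name `pad_line` is hypothetical.

```python
def pad_line(line, textwidth):
    """Stretch `line` to exactly `textwidth` chars by widening the gaps between words."""
    words = line.split()
    gaps = len(words) - 1
    missing = textwidth - len(" ".join(words))
    if gaps <= 0 or missing <= 0:
        # Empty line, single word (nothing to stretch) or already full: leave as-is.
        return " ".join(words)
    base, extra = divmod(missing, gaps)
    pieces = []
    for k, word in enumerate(words[:-1]):
        # Every gap gets 1 + base spaces; the first `extra` gaps get one more.
        pieces.append(word + " " * (1 + base + (1 if k < extra else 0)))
    pieces.append(words[-1])
    return "".join(pieces)


print(repr(pad_line("bb cc dd", 11)))
```

For example, `pad_line("bb cc dd", 11)` returns `'bb   cc  dd'`. Applied to the output of `reconstruct_text`, the last line of a paragraph would usually be skipped and left ragged.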


Answer 3:

I just watched the lecture and thought I would put here whatever I could understand. I have written the code in a format similar to the questioner's. I have followed the recursion from the lecture (the recursive call is left commented out; the code looks the values up in justification_map instead).
Point #3 defines the recurrence. This is basically a bottom-up approach, wherein you calculate the value of the function for a higher input first and then use it to calculate the value for a lower input.
The lecture explains it as:
DP(i) = min(DP(j) + badness(i, j))
for j varying from i+1 to n.
Here, i varies from n down to 0 (bottom to top!).
As DP(n) = 0,
DP(n-1) = DP(n) + badness(n-1, n),
and then you calculate DP(n-2) from DP(n-1) and DP(n) and take the minimum of the two.
This way you can go down to i=0, and that's the final answer for the badness!
In point #4, as you can see, there are two nested loops: one over i, and another over j inside it.
When i=0, j takes n values; when i=1, it takes n-1 values; ...; and when i=n, it takes none.
Hence the total time is the sum of these, n(n+1)/2,
which is O(n^2).
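Written out, the count from point #4 (i running from 0 to n, and j from i+1 to n for each i) is:

```latex
\sum_{i=0}^{n} \bigl|\{\, j : i+1 \le j \le n \,\}\bigr|
  = \sum_{i=0}^{n} (n - i)
  = n + (n-1) + \dots + 1 + 0
  = \frac{n(n+1)}{2}
  = O(n^2)
```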
Point #5 just identifies the solution, which is DP(0)!
Hope this helps!

import math

justification_map = {}
min_map = {}

def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) - 1 # spaces
    return total

def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('inf') # a line that does not fit is infinitely bad
    else:
        return math.pow(page_width - line_len, 3)

def justify(i, n, words, page_width):
    if i == n:

        return 0
    ans = []
    for j in range(i+1, n+1):
        #ans.append(justify(j, n, words, page_width)+ badness(words[i:j], page_width))
        ans.append(justification_map[j]+ badness(words[i:j], page_width))
    min_map[i] = ans.index(min(ans)) + 1
    return min(ans)

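def main():
    print("Enter page width")
    page_width = int(input())
    print("Enter text")
    paragraph = input()
    words = paragraph.split()
    n = len(words)
    #justification_map[n] = 0
    # Fill the table bottom-up: justification_map[j] is ready for every j > i
    for i in reversed(range(n+1)):
        justification_map[i] = justify(i, n, words, page_width)

    print("Minimum badness achieved: ", justification_map[0])

    # min_map[key] is the number of words on the line starting at key;
    # follow it to print the break positions
    key = 0
    while key < n:
        key = key + min_map[key]
        print(key)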

if __name__ == '__main__':
    main()


Answer 4:

This is what I think according to your definition.

import math

class Text(object):
    def __init__(self, words, width):
        self.words = words
        self.page_width = width
        self.str_arr = words
        self.memo = {}

    def total_length(self, str_arr):
        total = 0
        for string in str_arr:
            total = total + len(string)
        total = total + len(str_arr) - 1 # spaces between words
        return total

    def badness(self, str_arr):
        line_len = self.total_length(str_arr)
        if line_len > self.page_width:
            return float('inf')
        else:
            return math.pow(self.page_width - line_len, 3)

    def dp(self):
        n = len(self.str_arr)
        # Base case: the empty suffix (starting at n) costs nothing
        self.memo[n] = 0

        return self.judge(0)

    def judge(self, i):
        # judge(i) = minimum badness of the suffix str_arr[i:]
        if i in self.memo:
            return self.memo[i]

        self.memo[i] = float('inf')
        # j == len(self.str_arr) puts all remaining words on the last line
        for j in range(i+1, len(self.str_arr) + 1):
            bad = self.judge(j) + self.badness(self.str_arr[i:j])
            if bad < self.memo[i]:
                self.memo[i] = bad

        return self.memo[i]


Answer 5:

Java implementation

Given the maximum line width L, the idea for justifying the text T is to consider all suffixes of the text (suffixes of words rather than characters, to be precise). Dynamic programming is nothing but "careful brute force". If you consider the brute-force approach, you need to do the following:

  1. Consider putting 1, 2, ..., n words on the first line.
  2. For each case in step 1 (say i words are put on line 1), consider putting 1, 2, ..., n - i words on the second line, then the remaining words on the third line, and so on.
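For contrast, the brute force from the two steps above can be sketched as plain recursion with no table, assuming the lecture's cubed-slack badness (the name `brute_force_cost` is hypothetical). Since each of the n - 1 gaps between words either is or is not a line break, it explores up to 2^(n-1) layouts, and it solves the same suffix over and over.

```python
def brute_force_cost(words, width):
    """Try every way to break `words` into lines; exponential, for comparison only."""
    if not words:
        return 0  # nothing left to place
    best = float("inf")
    # Step 1: put k = 1, 2, ..., n words on the first line.
    for k in range(1, len(words) + 1):
        line_len = sum(len(w) for w in words[:k]) + (k - 1)
        if line_len > width:
            break  # adding more words only makes the line longer
        # Step 2: recursively try every layout of the remaining words.
        # The same suffix words[k:] is re-solved many times; caching it is the DP step.
        cost = (width - line_len) ** 3 + brute_force_cost(words[k:], width)
        best = min(best, cost)
    return best
```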

Instead, let's just consider the problem of finding the cost of putting a word at the beginning of a line. In general, we can define DP(i) as the minimum cost of justifying the suffix words[i:], with word i (0-indexed) at the beginning of a line.

How can we form a recurrence relation for DP(i)?

If the j-th word begins the next line, then the current line contains words[i:j) (j exclusive), and the cost of laying out the rest, starting with the j-th word, is DP(j). Hence, for that choice of j, the cost is DP(j) + the cost of putting words[i:j) on the current line. As we want to minimise the total cost, DP(i) can be defined as follows.

Recurrence relation:

DP(i) = min { DP(j) + cost of putting words[i:j) on the current line } for all j in [i+1, n]

Note that j = n signifies that no words are left for the next line.

The base case: DP(n) = 0, i.e. at this point there are no words left to be written.

To summarise:

  1. Subproblems: suffixes words[i:], so n + 1 subproblems.
  2. Guess: where to start the next line; the number of choices is n - i, i.e. O(n).
  3. Recurrence: DP(i) = min { DP(j) + cost of putting words[i:j) on the current line }. With memoization, the expression inside the curly braces should take O(1) time, and the minimum is taken over O(n) choices. Since i varies from n down to 0, the total complexity is brought down to O(n^2).
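The O(1) claim in step 3 holds only if the cost of a line is itself O(1) to compute; summing the word lengths of words[i:j) afresh each time costs O(j - i), which would make the double loop O(n^3). One common way to get O(1) (an addition, not part of this answer's code) is a prefix-sum array over word lengths, sketched below with hypothetical helper names.

```python
from itertools import accumulate


def make_line_width(words):
    """Return a function giving the width of words[i:j] (single spaces) in O(1)."""
    # prefix[k] = total length of words[0:k]; built once in O(n).
    prefix = [0] + list(accumulate(len(w) for w in words))

    def line_width(i, j):
        # Letters in words[i:j] plus one space for each of the j - i - 1 gaps.
        return prefix[j] - prefix[i] + (j - i - 1)

    return line_width


words = ["aaa", "bb", "cc", "ddddd"]
line_width = make_line_width(words)
print(line_width(1, 3))  # width of "bb cc", i.e. 5
```

The badness of words[i:j) then follows from `line_width(i, j)` with a single subtraction and comparison against L.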

Now, even though we have derived the minimum cost of justifying the text, we still need to solve the original problem by keeping track of the j value chosen as the minimum in the expression above, so that we can later use it to print out the justified text. This is the idea of keeping a parent pointer.

Hope this helps you understand the solution. Below is a simple implementation of the above idea.

import java.util.ArrayList;
import java.util.List;

public class TextJustify {
    class IntPair {
        //The cost or badness
        final int x;

        //The index of word at the beginning of a line
        final int y;
        IntPair(int x, int y) {this.x=x;this.y=y;}
    }
    public List<String> fullJustify(String[] words, int L) {
        IntPair[] memo = new IntPair[words.length + 1];

        //Base case
        memo[words.length] = new IntPair(0, 0);


        for(int i = words.length - 1; i >= 0; i--) {
            int score = Integer.MAX_VALUE;
            int nextLineIndex = i + 1;
            for(int j = i + 1; j <= words.length; j++) {
                int badness = calcBadness(words, i, j, L);
                if(badness < 0 || badness == Integer.MAX_VALUE) break;
                int currScore = badness + memo[j].x;
                if(currScore < 0 || currScore == Integer.MAX_VALUE) break;
                if(score > currScore) {
                    score = currScore;
                    nextLineIndex = j;
                }
            }
            memo[i] = new IntPair(score, nextLineIndex);
        }

        List<String> result = new ArrayList<>();
        int i = 0;
        while(i < words.length) {
            String line = getLine(words, i, memo[i].y);
            result.add(line);
            i = memo[i].y;
        }
        return result;
    }

    private int calcBadness(String[] words, int start, int end, int width) {
        int length = 0;
        for(int i = start; i < end; i++) {
            length += words[i].length();
            if(length > width) return Integer.MAX_VALUE;
            length++;
        }
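        length--; // remove the space counted after the last word
        int temp = width - length;
        // Squares the slack; the lecture notes cube it. Both penalize loose lines.
        return temp * temp;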
    }


    private String getLine(String[] words, int start, int end) {
        StringBuilder sb = new StringBuilder();
        for(int i = start; i < end - 1; i++) {
            sb.append(words[i] + " ");
        }
        sb.append(words[end - 1]);

        return sb.toString();
    }
}