How to iterate nested lists with lambda streams?

Posted 2020-06-08 04:26

Question:

I'm trying to refactor the following code to lambda expressions with streams, especially the nested for-each loops:

public static Result match(Response rsp) throws Exception {
    Exception lastex = null;

    for (FirstNode firstNode : rsp.getFirstNodes()) {
        for (SndNode sndNode : firstNode.getSndNodes()) {
            try {
                if (sndNode.isValid()) {
                    return parse(sndNode); // return the first match, retry if it fails with ParseException
                }
            } catch (ParseException e) {
                lastex = e;
            }
        }
    }

    //throw the exception if all elements failed
    if (lastex != null) {
        throw lastex;
    }

    return null;
}

I'm starting with:

rsp.getFirstNodes().forEach().?? // how to iterate the nested 2ndNodes?

Answer 1:

I am afraid that with streams and lambdas your performance may suffer. Your current solution returns the first valid and parseable node, but it is not possible to interrupt a stream operation such as forEach partway through.

Also, because there are two different outcomes (a returned result or a thrown exception), it won't be possible to do this with a single-line expression.

Here is what I came up with. It may give you some ideas:

public static Result match(Response rsp) throws Exception {
    Map<Boolean, List<Object>> collect = rsp.getFirstNodes().stream()
            .flatMap(firstNode -> firstNode.getSndNodes().stream()) // create stream of SndNodes
            .filter(SndNode::isValid) // filter so we only have valid nodes
            .map(node -> {
                // try to parse each node and return either the result or the exception
                try {
                    return parse(node);
                } catch (ParseException e) {
                    return e;
                }
            }) // at this point we have stream of objects which may be either Result or ParseException
            .collect(Collectors.partitioningBy(o -> o instanceof Result)); // split the stream into two lists - one containing Results, the other containing ParseExceptions

    if (!collect.get(true).isEmpty()) {
        return (Result) collect.get(true).get(0);
    }
    if (!collect.get(false).isEmpty()) {
        throw (Exception) collect.get(false).get(0); // throws first exception instead of last!
    }
    return null;
}

As mentioned at the beginning, there is a possible performance issue, as this will try to parse every valid node.
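To try the partitioning idea in isolation, here is a minimal self-contained sketch. It is only an illustration under assumptions: strings stand in for SndNode, Integer stands in for Result, NumberFormatException stands in for ParseException, and the names PartitionDemo, tryParse and parseAll are made up for this example.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical stand-in: parses strings to integers instead of SndNode to Result.
final class PartitionDemo {

    // Returns either the parsed Integer or the exception that parsing raised.
    static Object tryParse(String text) {
        try {
            return Integer.valueOf(text);
        } catch (NumberFormatException e) {
            return e;
        }
    }

    // Key true holds the successful results, key false holds the exceptions.
    static Map<Boolean, List<Object>> parseAll(List<String> inputs) {
        return inputs.stream()
                .map(PartitionDemo::tryParse)
                .collect(Collectors.partitioningBy(o -> o instanceof Integer));
    }
}
```

Calling PartitionDemo.parseAll(Arrays.asList("1", "x", "2")) puts 1 and 2 under the key true and one NumberFormatException under the key false. Every input is parsed before anything is returned, which is the performance concern mentioned above.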


EDIT:

To avoid parsing all nodes, you could use reduce, but it is a bit more complex and ugly (and an extra class is needed). This version also reports all ParseExceptions instead of just the last one.

private static class IntermediateResult {

    private final SndNode node;
    private final Result result;
    private final List<ParseException> exceptions;

    private IntermediateResult(SndNode node, Result result, List<ParseException> exceptions) {
        this.node = node;
        this.result = result;
        this.exceptions = exceptions;
    }

    private Result getResult() throws ParseException {
        if (result != null) {
            return result;
        }
        if (exceptions.isEmpty()) {
            return null;
        }
        // this will show all ParseExceptions instead of just last one
        ParseException exception = new ParseException(String.format("None of %s valid nodes could be parsed", exceptions.size()));
        exceptions.forEach(exception::addSuppressed);
        throw exception;
    }

}

public static Result match(Response rsp) throws Exception {
    return Stream.concat(
                    Arrays.stream(new SndNode[] {null}), // adding null at the beginning of the stream to get an empty "aggregatedResult" at the beginning of the stream
                    rsp.getFirstNodes().stream()
                            .flatMap(firstNode -> firstNode.getSndNodes().stream())
                            .filter(SndNode::isValid)
            )
            .map(node -> new IntermediateResult(node, null, Collections.<ParseException>emptyList()))
            .reduce((aggregatedResult, next) -> {
                if (aggregatedResult.result != null) {
                    return aggregatedResult;
                }

                try {
                    return new IntermediateResult(null, parse(next.node), null);
                } catch (ParseException e) {
                    List<ParseException> exceptions = new ArrayList<>(aggregatedResult.exceptions);
                    exceptions.add(e);
                    return new IntermediateResult(null, null, Collections.unmodifiableList(exceptions));
                }
            })
            .get() // aggregatedResult after going through the whole stream, there will always be at least one because we added one at the beginning
            .getResult(); // return Result, null (if no valid nodes) or throw ParseException
}

EDIT2:

In general, it is also possible to use lazy evaluation with short-circuiting terminal operations such as findFirst(). So with a minor change of requirements (i.e. returning null instead of throwing an exception), it should be possible to do something like the code below. However, in Java 8, flatMap combined with findFirst() is not fully lazy (a JDK limitation that was fixed in Java 10), so this code may parse more nodes than necessary.

private static class ParsedNode {
    private final Result result;

    private ParsedNode(Result result) {
        this.result = result;
    }
}

public static Result match(Response rsp) throws Exception {
    return rsp.getFirstNodes().stream()
            .flatMap(firstNode -> firstNode.getSndNodes().stream())
            .filter(SndNode::isValid)
            .map(node -> {
                try {
                    // on Java 8 this may parse more nodes than needed because of flatMap
                    return new ParsedNode(parse(node));
                } catch (ParseException e) {
                    return new ParsedNode(null);
                }
            })
            .filter(parsedNode -> parsedNode.result != null)
            .findFirst().orElse(new ParsedNode(null)).result;
}
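For comparison, here is a hedged sketch of one way to keep the question's semantics (return the first successful parse, rethrow the last failure if every candidate failed, return null if there were no candidates) while still using findFirst(). It is not a definitive solution: strings stand in for SndNode, ParseFailure is a hypothetical stand-in for the checked ParseException, and the last failure is recorded through a side effect, which is only safe on a sequential stream.

```java
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

final class FirstParsable {

    // Hypothetical stand-in for the question's checked ParseException.
    static final class ParseFailure extends Exception {
        ParseFailure(String message) {
            super(message);
        }
    }

    // Hypothetical stand-in for parse(sndNode).
    static int parse(String text) throws ParseFailure {
        try {
            return Integer.parseInt(text);
        } catch (NumberFormatException e) {
            throw new ParseFailure("cannot parse: " + text);
        }
    }

    // Wraps a parse attempt: a failure is remembered and turned into an empty Optional.
    static Optional<Integer> tryParse(String text, AtomicReference<ParseFailure> lastFailure) {
        try {
            return Optional.of(parse(text));
        } catch (ParseFailure e) {
            lastFailure.set(e); // side effect: sequential streams only
            return Optional.empty();
        }
    }

    // First successful parse, or null if no candidate exists;
    // rethrows the last failure if every candidate failed.
    static Integer match(List<List<String>> groups) throws ParseFailure {
        AtomicReference<ParseFailure> lastFailure = new AtomicReference<>();
        Optional<Integer> first = groups.stream()
                .flatMap(List::stream)
                .filter(text -> !text.isEmpty()) // plays the role of isValid()
                .map(text -> tryParse(text, lastFailure))
                .filter(Optional::isPresent)
                .map(Optional::get)
                .findFirst();
        if (first.isPresent()) {
            return first.get();
        }
        if (lastFailure.get() != null) {
            throw lastFailure.get();
        }
        return null;
    }

    // Convenience wrapper that reports the outcome as text.
    static String describe(List<List<String>> groups) {
        try {
            return String.valueOf(match(groups));
        } catch (ParseFailure e) {
            return e.getMessage();
        }
    }
}
```

The recorded failure is only consulted when no candidate parsed successfully, so by then every candidate has been tried and the stored exception really is the last one.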


Answer 2:

Look at flatMap:

flatMap(Function<? super T,? extends Stream<? extends R>> mapper)
Returns a stream consisting of the results of replacing each element of this stream with the contents of a mapped stream produced by applying the provided mapping function to each element.

Code sample, assuming isValid() doesn't throw:

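// The Optional must not share a name with the lambda parameter, otherwise this does not compile.
Optional<SndNode> firstValid = rsp.getFirstNodes()
  .stream()
  .flatMap(firstNode -> firstNode.getSndNodes().stream())  // this is the key line for merging the nested streams
  .filter(sndNode -> sndNode.isValid())
  .findFirst();

if (firstValid.isPresent()) {
    try {
        return parse(firstValid.get());
    } catch (ParseException e) {
        lastex = e; // only the first valid node is tried; later nodes are not retried
    }
}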
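To see the flattening step on its own, here is a small self-contained sketch. The class name FlatMapDemo and the use of nested string lists are assumptions made for illustration; they play the role of the FirstNode and SndNode lists.

```java
import java.util.List;
import java.util.stream.Collectors;

final class FlatMapDemo {

    // Merges the inner lists into one list, keeping the encounter order.
    static List<String> flatten(List<List<String>> groups) {
        return groups.stream()
                .flatMap(List::stream) // each inner list is replaced by its elements
                .collect(Collectors.toList());
    }
}
```

For the input [["a", "b"], [], ["c"]] the result is ["a", "b", "c"]; empty inner lists simply contribute nothing.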


Answer 3:

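Try using map to transform the original source, then flatten the resulting collections:

rsp.getFirstNodes().stream()
        .map(FirstNode::getSndNodes)
        .flatMap(Collection::stream) // map alone yields a stream of collections; assumes getSndNodes() returns a Collection
        .filter(sndNode -> sndNode.isValid())
        .forEach(sndNode -> {
            // Now do the sndNode parsing operation here.
        });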


Answer 4:

You can iterate nested collections as shown below:

allAssessmentsForJob.getBody().stream().forEach(assessment -> {
    jobAssessments.stream().forEach(jobAssessment -> {
        if (assessment.getId() == jobAssessment.getAssessmentId()) {
            jobAssessment.setAssessment(assessment);
        }
    });
});
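A self-contained version of the same nested forEach pattern might look like the sketch below. The Assessment and JobAssessment classes here are hypothetical minimal versions invented for the example (with primitive long ids, so that == compares values), not the real types from the snippet above.

```java
import java.util.List;

final class NestedForEachDemo {

    // Hypothetical minimal versions of the answer's types.
    static final class Assessment {
        final long id;

        Assessment(long id) {
            this.id = id;
        }
    }

    static final class JobAssessment {
        final long assessmentId;
        Assessment assessment; // filled in by link()

        JobAssessment(long assessmentId) {
            this.assessmentId = assessmentId;
        }
    }

    // Attaches to every JobAssessment the Assessment with the matching id.
    static void link(List<Assessment> assessments, List<JobAssessment> jobAssessments) {
        assessments.forEach(assessment ->
                jobAssessments.forEach(jobAssessment -> {
                    if (assessment.id == jobAssessment.assessmentId) {
                        jobAssessment.assessment = assessment;
                    }
                }));
    }
}
```

Note that forEach always visits every pair, so this pattern fits "do something for each combination" and not "stop at the first match".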


Answer 5:

A little bit late but here is a readable approach:

Result result = rsp.getFirstNodes()
        .stream()
        .flatMap(firstNode -> firstNode.getSndNodes().stream())
        .filter(SndNode::isValid)
        .findFirst()
        .map(node -> this.parseNode(node))
        .orElse(null);

Explanation: you get all the firstNodes and stream() them. Each firstNode contributes its n SndNodes to the flattened stream. You check the SndNodes to find the first one that is valid. If there is no valid SndNode, you get null. If there is one, it gets parsed into a Result.

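The parseNode() method keeps the original parsing logic, but rethrows the checked ParseException as an unchecked exception so that it can be called from the lambda:

public Result parseNode(SndNode node) {
    try {
        ...
        ... // attempt to parse node
    } catch (ParseException e) {
        throw new IllegalStateException(e); // unchecked wrapper; the original cause is preserved
    }
}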
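The findFirst().map(...).orElse(null) chain can be tried with the self-contained sketch below. It is an illustration under assumptions: strings stand in for SndNode, Integer.valueOf stands in for parseNode (its NumberFormatException is already unchecked), and the class name FirstValidDemo is made up.

```java
import java.util.List;

final class FirstValidDemo {

    // Parses the first non-empty entry; returns null if there is none.
    static Integer firstValidParsed(List<List<String>> groups) {
        return groups.stream()
                .flatMap(List::stream)
                .filter(text -> !text.isEmpty()) // plays the role of isValid()
                .findFirst()
                .map(Integer::valueOf)           // runs only if an entry was found
                .orElse(null);
    }
}
```

One difference from the loop in the question: only the first valid entry is ever parsed. For the input [["x", "1"]] this sketch throws a NumberFormatException for "x" and never moves on to "1", whereas the original loop would retry with the next node.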


Answer 6:

You could use the fact that StreamSupport provides a stream method that takes a Spliterator, and that Iterable has a spliterator method.

You then just need a mechanism to flatten your structure into an Iterable - something like this.

class IterableIterable<T> implements Iterable<T> {

    private final Iterable<? extends Iterable<T>> i;

    public IterableIterable(Iterable<? extends Iterable<T>> i) {
        this.i = i;
    }

    @Override
    public Iterator<T> iterator() {
        return new IIT();
    }

    private class IIT implements Iterator<T> {

        // Pull an iterator.
        final Iterator<? extends Iterable<T>> iit = i.iterator();
        // The current Iterator<T>
        Iterator<T> it = null;
        // The current T.
        T next = null;

        @Override
        public boolean hasNext() {
            boolean finished = false;
            while (next == null && !finished) {
                if (it == null || !it.hasNext()) {
                    if (iit.hasNext()) {
                        it = iit.next().iterator();
                    } else {
                        finished = true;
                    }
                }
                if (it != null && it.hasNext()) {
                    next = it.next();
                }
            }
            return next != null;
        }

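        @Override
        public T next() {
            // Advance first so that next() also works without a preceding hasNext() call.
            if (!hasNext()) {
                throw new java.util.NoSuchElementException();
            }
            T n = next;
            next = null;
            return n;
        }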
    }

}

public void test() {
    List<List<String>> list = new ArrayList<>();
    List<String> first = new ArrayList<>();
    first.add("First One");
    first.add("First Two");
    List<String> second = new ArrayList<>();
    second.add("Second One");
    second.add("Second Two");
    list.add(first);
    list.add(second);
    // Check it works.
    IterableIterable<String> l = new IterableIterable<>(list);
    for (String s : l) {
        System.out.println(s);
    }
    // Stream it like this.
    Stream<String> stream = StreamSupport.stream(l.spliterator(), false);
}

You can now stream directly from your Iterable.

Initial research suggests that this should be done with flatMap, but this approach works too.
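For reference, the flatMap route over nested Iterables could be sketched as follows. The class and method names (IterableFlatten, flatten) are hypothetical; the only library calls used are Iterable.spliterator(), StreamSupport.stream(Spliterator, boolean) and Stream.flatMap.

```java
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

final class IterableFlatten {

    // Streams every element of every inner Iterable, in encounter order.
    static <T> Stream<T> flatten(Iterable<? extends Iterable<T>> nested) {
        return StreamSupport.stream(nested.spliterator(), false) // false = sequential stream
                .flatMap(inner -> StreamSupport.stream(inner.spliterator(), false));
    }
}
```

With the list of lists from test() above, IterableFlatten.flatten(list) yields the same four strings in the same order, without a hand-written iterator.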