All permutations of +-r, +-s

Published 2020-08-24 06:47

Question:

Given two numbers r and s, I would like to get a list of all permutations of n +-r and m +-s. For example (with r=3.14 and s=2.71),

n = 1
m = 1
out = [
    (+r, +s), (+r, -s), (-r, +s), (-r, -s), 
    (+s, +r), (+s, -r), (-s, +r), (-s, -r)
    ]
n = 1
m = 2
out = [
    (+r, +s, +s), (+r, -s, +s), (-r, +s, +s), (-r, -s, +s), ...
    (+s, +r, +s), (-s, +r, +s), (+s, -r, +s), (-s, -r, +s), ...
    ...
    ]
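A quick way to sanity-check any solution against these examples is to count the expected results. First choose which n of the n + m positions hold +-r, which can be done in C(n+m, n) ways. Then pick a sign for every position, which can be done in 2^(n+m) ways. Below is a minimal sketch. The helper name expected_count is hypothetical. It assumes r and s are nonzero and |r| != |s|, so that no two signed values coincide.

```python
from math import comb  # Python 3.8+

def expected_count(n: int, m: int) -> int:
    """Distinct tuples with n entries of +-r and m entries of +-s.

    Assumes r, s nonzero and |r| != |s| (hypothetical helper).
    """
    # C(n+m, n) ways to place the r's, times 2 sign choices per position
    return comb(n + m, n) * 2 ** (n + m)

print(expected_count(1, 1))  # 8
print(expected_count(1, 2))  # 24
```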

With itertools.product([+r, -r], repeat=n) I can get the list of signed rs, and likewise the signed ss with repeat=m. Then I would only need to interleave them, but I'm not sure this is the right approach.

Efficiency is not overly important, so I wouldn't mind a solution that produces many repeated results only to make them unique afterwards.
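The interleaving idea from the question can be made to work directly. Take the signed r tuples from product, take the signed s tuples from product, and then merge each pair into every choice of r positions given by combinations. Below is a minimal sketch. The function name interleave_signed is hypothetical and does not come from the question or from any library.

```python
from itertools import combinations, product

def interleave_signed(r, s, n, m):
    """Yield every tuple with n entries of +-r and m entries of +-s (hypothetical helper)."""
    signed_rs = list(product((+r, -r), repeat=n))  # all sign choices for the r's
    signed_ss = list(product((+s, -s), repeat=m))  # all sign choices for the s's
    for r_pos in combinations(range(n + m), n):    # positions occupied by r's
        r_pos_set = set(r_pos)
        for rs in signed_rs:
            for ss in signed_ss:
                r_iter, s_iter = iter(rs), iter(ss)
                # walk the positions, drawing the next r or the next s as appropriate
                yield tuple(next(r_iter) if k in r_pos_set else next(s_iter)
                            for k in range(n + m))

out = list(interleave_signed(3.14, 2.71, 1, 2))
print(len(out), len(set(out)))  # 24 24
```

No duplicates are produced: each output is fixed by its r positions and its two sign tuples, and these are all distinct when |r| != |s|.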

Answer 1:

Update: general solution added.

Here is a solution that is a bit more complicated in code but does not produce repeated elements and can be evaluated lazily:

from itertools import combinations, product, chain

r = 3.14
s = 2.71
n = 1
m = 2
idx = combinations(range(n + m), n)
vs = ((r if j in i else s for j in range(n + m)) for i in idx)
res = chain.from_iterable(product(*((+vij, -vij) for vij in vi)) for vi in vs)
print("\n".join(map(str, res)))

Output:

(3.14, 2.71, 2.71)
(3.14, 2.71, -2.71)
(3.14, -2.71, 2.71)
(3.14, -2.71, -2.71)
(-3.14, 2.71, 2.71)
(-3.14, 2.71, -2.71)
(-3.14, -2.71, 2.71)
(-3.14, -2.71, -2.71)
(2.71, 3.14, 2.71)
(2.71, 3.14, -2.71)
(2.71, -3.14, 2.71)
(2.71, -3.14, -2.71)
(-2.71, 3.14, 2.71)
(-2.71, 3.14, -2.71)
(-2.71, -3.14, 2.71)
(-2.71, -3.14, -2.71)
(2.71, 2.71, 3.14)
(2.71, 2.71, -3.14)
(2.71, -2.71, 3.14)
(2.71, -2.71, -3.14)
(-2.71, 2.71, 3.14)
(-2.71, 2.71, -3.14)
(-2.71, -2.71, 3.14)
(-2.71, -2.71, -3.14)

Explanation

We can think of the output as permutations containing n +/- r elements and m +/- s elements. In other words, they are tuples of n + m elements where n are +/- r and the rest are +/- s. idx contains tuples with all the possible positions for the +/- r elements; for example, for the first eight results it is (0,).
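To see what idx holds, it helps to materialize it for the example sizes and for a slightly larger case. The short sketch below does this. The n = 2, m = 2 case is our own illustration and is not part of the answer.

```python
from itertools import combinations

n, m = 1, 2
# each tuple lists the positions that will hold +-r
print(list(combinations(range(n + m), n)))  # [(0,), (1,), (2,)]

n, m = 2, 2
print(list(combinations(range(n + m), n)))
# [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```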

Then, for each of these tuples i, we create a "template" tuple in vs. Each template is just a tuple of size n + m with r at the indices in i and s everywhere else. So, for the tuple (0,) in idx you would get (r, s, s). If n + m is very big, you could add a preliminary step idx = map(set, idx) to make the in test faster, but I'm not sure at which point that would be worth it.
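Materializing the templates makes this step concrete. The sketch below builds them as tuples instead of lazy generators and applies the map(set, idx) variant. It has not been benchmarked, so whether the set conversion pays off for large n + m is left open.

```python
from itertools import combinations

r, s = 3.14, 2.71
n, m = 1, 2
idx = map(set, combinations(range(n + m), n))  # sets give O(1) average membership tests
# one template per choice of r positions: r at those indices, s elsewhere
templates = [tuple(r if j in i else s for j in range(n + m)) for i in idx]
print(templates)
# [(3.14, 2.71, 2.71), (2.71, 3.14, 2.71), (2.71, 2.71, 3.14)]
```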

Finally, for each of these templates vi in vs, I need to consider every combination of positive and negative values for its elements. That is a Cartesian product of (+vi[0], -vi[0]), (+vi[1], -vi[1]), and so on. All that remains is to chain the generators of these products to get the final result.
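For a single template, the Cartesian product step looks like the sketch below. The template (r, s, s) is taken from the example. Chaining three such products of eight tuples each gives the 24 lines of output shown above.

```python
from itertools import product

r, s = 3.14, 2.71
template = (r, s, s)
# one (+v, -v) pair per position; product picks a sign independently for each
signed = list(product(*((+v, -v) for v in template)))
print(len(signed))            # 8
print(signed[0], signed[-1])  # (3.14, 2.71, 2.71) (-3.14, -2.71, -2.71)
```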

General solution

To build a general solution for an arbitrary number of different elements, you need to consider partitions of the set of indices. For example, for n = 3 and m = 5, these are all the possible ways to split {0, 1, 2, 3, 4, 5, 6, 7} into two parts of sizes 3 and 5. Here is an implementation for that:
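In the general case, the number of templates is a multinomial coefficient. With group sizes k_1, ..., k_g, an index set of size N = k_1 + ... + k_g can be split in N! / (k_1! ... k_g!) ways, and each template then contributes 2^N signed tuples. Below is a minimal sketch of that count. The helper name count_templates is hypothetical.

```python
from math import factorial

def count_templates(*sizes):
    """Multinomial coefficient: ways to assign group labels to sum(sizes) positions."""
    total = factorial(sum(sizes))
    for k in sizes:
        total //= factorial(k)  # each partial quotient is an exact integer
    return total

print(count_templates(3, 5))     # 56, the n = 3, m = 5 example
print(count_templates(1, 2))     # 3
print(count_templates(1, 2, 1))  # 12
```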

from itertools import chain, product


def partitions(*sizes):
    if not sizes or all(s <= 0 for s in sizes):
        yield ()
    for i_size, size in enumerate(sizes):
        if size <= 0:
            continue
        next_sizes = sizes[:i_size] + (sizes[i_size] - 1,) + sizes[i_size + 1:]
        for p in partitions(*next_sizes):
            yield (i_size,) + p


def signed_permutations(*elems):
    values, sizes = zip(*elems)
    templates = partitions(*sizes)
    return chain.from_iterable(
        product(*((+values[ti], -values[ti]) for ti in t)) for t in templates)


r = 3.14
s = 2.71
n = 1
m = 2
res = signed_permutations((r, n), (s, m))
print("\n".join(map(str, res)))

The idea is the same: you build the "templates" and then take the Cartesian products of them. This time, the templates contain indices of the values instead of the values themselves.
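As a cross-check, the same index templates can be produced the slow way: take every permutation of the multiset of indices and discard the repeats. This brute-force variant is our own substitute and is not part of the answer. It is only practical for small sizes, because permutations generates N! tuples before deduplication. The names brute_force_templates and signed_from_templates are hypothetical.

```python
from itertools import chain, permutations, product

def brute_force_templates(*sizes):
    """All distinct arrangements of group indices, via permutations + dedup (slow)."""
    labels = [i for i, size in enumerate(sizes) for _ in range(size)]
    return sorted(set(permutations(labels)))

def signed_from_templates(values, templates):
    # same Cartesian-product step as in the answer
    return list(chain.from_iterable(
        product(*((+values[t], -values[t]) for t in tpl)) for tpl in templates))

templates = brute_force_templates(1, 2)
print(templates)  # [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
res = signed_from_templates((3.14, 2.71), templates)
print(len(res))   # 24
```

For sizes (1, 2), the sorted templates happen to come out in the same order that partitions yields them.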



Answer 2:

You could also combine the permutations of r and s with the product of +1 and -1 signs and zip the two. This way, the entire construction is a bit more readable IMHO:

>>> from itertools import permutations, product
>>> n, m = 1, 2
>>> r, s = 3.14, 2.71
>>> [[x*i for x,i in zip(perm, prod)] for perm in permutations([r]*n + [s]*m) 
...                                   for prod in product((+1, -1), repeat=n+m)]
[[3.14, 2.71, 2.71],
 [3.14, 2.71, -2.71],
 ...
 [-2.71, -2.71, 3.14],
 [-2.71, -2.71, -3.14]]
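permutations treats the repeated r and s entries as distinct items. As a result, this comprehension returns every signed tuple more than once whenever n > 1 or m > 1. For n = 1, m = 2 it builds 3! × 2^3 = 48 lists for only 24 distinct results. If uniqueness matters, one possible fix is sketched below. It converts each list to a tuple and deduplicates with dict.fromkeys, which, unlike set, keeps first-occurrence order.

```python
from itertools import permutations, product

n, m = 1, 2
r, s = 3.14, 2.71
# same construction as above, but as tuples so they are hashable
signed = (tuple(x * i for x, i in zip(perm, prod))
          for perm in permutations([r] * n + [s] * m)
          for prod in product((+1, -1), repeat=n + m))
unique = list(dict.fromkeys(signed))  # drops repeats, keeps first-seen order
print(len(unique))  # 24
```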


Answer 3:

First use product to get every sign combination, then apply permutations to each resulting tuple. Then concatenate all the results and pass them to set() to remove duplicates:

import itertools

r, s = 3.14, 2.71
n, m = 1, 2

arr = set(itertools.chain.from_iterable([
    itertools.permutations(x)
    for x in itertools.product(*([[+r, -r]] * n + [[+s, -s]] * m))
    ]))
print(arr)
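Because set() discards order, the printed result comes out in arbitrary order. The sketch below checks this approach against the combinations-based generator from Answer 1. It also sorts the output to make the printout deterministic; the sorting is our addition.

```python
import itertools

r, s = 3.14, 2.71
n, m = 1, 2

# Answer 3: product of signs, permutations, then set() to deduplicate
arr = set(itertools.chain.from_iterable(
    itertools.permutations(x)
    for x in itertools.product(*([[+r, -r]] * n + [[+s, -s]] * m))))

# Answer 1: combinations-based lazy generator, no duplicates produced
idx = itertools.combinations(range(n + m), n)
vs = ((r if j in i else s for j in range(n + m)) for i in idx)
res = list(itertools.chain.from_iterable(
    itertools.product(*((+v, -v) for v in vi)) for vi in vs))

print(len(arr), len(res), arr == set(res))  # 24 24 True
for t in sorted(arr):  # sorted only to make the printout deterministic
    print(t)
```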