Finding the number of all longest increasing subsequences
Below is the full Java code of an improved LIS algorithm that finds not only the length of the longest increasing subsequence but also the number of subsequences of that length. I prefer to use generics so that it works not only for integers but for any comparable type.
/**
 * Runs lisNumberAndLength on a sample sequence, prints the result, and fails
 * unless it reports 3 longest increasing subsequences of length 4
 * (5-8-10-15, 5-6-10-15 and 1-2-3-4).
 */
@Test
public void testLisNumberAndLength() {
    List<Integer> input = Arrays.asList(16, 5, 8, 6, 1, 10, 5, 2, 15, 3, 2, 4, 1);
    // result = {number of longest increasing subsequences, their length}
    int[] result = lisNumberAndLength(input);
    System.out.println(String.format(
            "This sequence has %s longest increasing subsequenses of length %s",
            result[0], result[1]
    ));
    // A test that only prints can never fail; pin the known answer.
    if (result[0] != 3 || result[1] != 4) {
        throw new AssertionError("expected 3 subsequences of length 4, got "
                + result[0] + " of length " + result[1]);
    }
}
/**
 * Body of improved LIS algorithm.
 *
 * Counts the longest strictly increasing subsequences of {@code input}.
 *
 * {@code subs.get(k)} is the row for subsequences of length {@code k + 1}. It holds
 * (value, sum) pairs whose values strictly decrease along the row. Each pair's
 * {@code sum} is the running total of the counts of that pair and of every pair
 * before it in the row. {@code tails.get(k)} mirrors the value of the last pair of
 * row k, so tail values increase from row to row and can be binary-searched.
 *
 * Elements are compared only through {@code compareTo}. Elements that compare
 * equal never extend each other, because the subsequences are strictly increasing.
 *
 * @param input the sequence to analyse
 * @return {@code {count, length}}: the number of longest strictly increasing
 *         subsequences and their length; {@code {0, 0}} for an empty input.
 *         The count is an {@code int} and silently overflows once it exceeds
 *         {@code Integer.MAX_VALUE}.
 */
public <T extends Comparable<T>> int[] lisNumberAndLength(List<T> input) {
    if (input.isEmpty())
        return new int[] {0, 0};
    List<List<Sub<T>>> subs = new ArrayList<>();
    List<Sub<T>> tails = new ArrayList<>();
    for (T e : input) {
        // Row for the new sub: the first row whose tail value is >= e.
        int pos = search(tails, new Sub<>(e, 0), false);
        int sum = 1; // e alone forms one subsequence when it extends no row
        if (pos > 0) {
            List<Sub<T>> pRow = subs.get(pos - 1); // previous row
            // Index of the leftmost element of pRow that is <= e ...
            int index = search(pRow, new Sub<T>(e, 0), true);
            if (pRow.get(index).value.compareTo(e) < 0) {
                index--;
            }
            // ... now index is the last pair with value >= e (or -1), so the total
            // count of pairs with value < e is (row total) - (running sum at index).
            sum = pRow.get(pRow.size() - 1).sum;
            if (index >= 0) {
                sum -= pRow.get(index).sum;
            }
        }
        if (pos >= subs.size()) { // add a new row
            List<Sub<T>> row = new ArrayList<>();
            row.add(new Sub<>(e, sum));
            subs.add(row);
            tails.add(new Sub<>(e, 0));
        } else { // add sub to existing row
            List<Sub<T>> row = subs.get(pos);
            Sub<T> tail = row.get(row.size() - 1);
            // compareTo (not equals) keeps rows strictly decreasing for types whose
            // equals is inconsistent with compareTo, e.g. BigDecimal("2.0") vs "2.00".
            if (tail.value.compareTo(e) == 0) {
                tail.sum += sum;
            } else {
                row.add(new Sub<>(e, tail.sum + sum));
                tails.set(pos, new Sub<>(e, 0));
            }
        }
    }
    List<Sub<T>> lastRow = subs.get(subs.size() - 1);
    Sub<T> last = lastRow.get(lastRow.size() - 1);
    return new int[]{last.sum, subs.size()};
}
/**
 * Implementation of binary search in a sorted list.
 *
 * With {@code reversed == false}, the list must be sorted in ascending order, and
 * the result is the index of the first element that is >= {@code v}. With
 * {@code reversed == true}, the list must be sorted in descending order, and the
 * result is the index of the first element that is <= {@code v}. In both cases
 * {@code a.size()} is returned when no such element exists (0 for an empty list).
 * These results are stated for strictly monotonic lists.
 *
 * Only the sign of {@code compareTo} is used. Implementations that return large
 * magnitudes (including {@code Integer.MIN_VALUE}, which cannot be negated) are
 * therefore handled correctly in reversed mode.
 */
public <T> int search(List<? extends Comparable<T>> a, T v, boolean reversed) {
    if (a.isEmpty())
        return 0;
    int sign = reversed ? -1 : 1;
    int right = a.size() - 1;
    Comparable<T> vRight = a.get(right);
    // signum first: -1 * Integer.MIN_VALUE overflows back to Integer.MIN_VALUE.
    if (Integer.signum(vRight.compareTo(v)) * sign < 0)
        return right + 1;
    int left = 0;
    Comparable<T> vLeft = a.get(left);
    for (;;) {
        if (right - left <= 1) {
            if (Integer.signum(vRight.compareTo(v)) * sign >= 0
                    && Integer.signum(vLeft.compareTo(v)) * sign < 0)
                return right;
            else
                return left;
        }
        int pos = (left + right) >>> 1;
        Comparable<T> vPos = a.get(pos);
        if (vPos.equals(v)) {
            return pos;
        }
        if (Integer.signum(vPos.compareTo(v)) * sign > 0) {
            right = pos;
            vRight = vPos;
        } else {
            left = pos;
            vLeft = vPos;
        }
    }
}
/**
 * A (value, sum) pair stored in a row of the LIS table.
 *
 * Ordering uses {@code value} alone; {@code sum} never takes part in comparisons.
 * equals/hashCode are not overridden, so equals is identity-based and therefore
 * inconsistent with compareTo.
 */
public static class Sub<T extends Comparable<T>> implements Comparable<Sub<T>> {
    /** Element of the input sequence. */
    T value;
    /** Counter attached to the value (the LIS code keeps running totals here). */
    int sum;

    public Sub(T value, int sum) {
        this.value = value;
        this.sum = sum;
    }

    @Override
    public int compareTo(Sub<T> another) {
        T otherValue = another.value;
        return value.compareTo(otherValue);
    }

    @Override
    public String toString() {
        return String.format("(%s, %s)", value, sum);
    }
}
Explanation
Since the explanation is long, I will call the initial sequence "seq" and any of its subsequences a "sub". So the task is to count the longest increasing subs that can be obtained from the seq.
As I mentioned before, the idea is to keep the counts of all possible longest subs obtained in the previous steps. So let's create a numbered list of rows, where the number of each row equals the length of the subs stored in it. Let's store subs as pairs of numbers (v, c), where "v" is the "value" of the ending element and "c" is the "count" of subs of the given length that end with "v". For example:
1: (16, 1) // that means that so far we have 1 sub of length 1 which ends by 16.
We will build such a list step by step, taking elements from the initial sequence in order. On every step we try to append the element to the longest sub it can extend, and we record the changes.
Building a list
Let's build the list using sequence from your example, since it has all possible options:
16 5 8 6 1 10 5 2 15 3 2 4 1
First, take element 16. Our list is empty so far, so we just put one pair in it:
1: (16, 1) <= one sub that ends by 16
Next is 5. It cannot extend a sub that ends with 16, so it starts a new sub of length 1. We create the pair (5, 1) and put it into row 1:
1: (16, 1)(5, 1)
Element 8 comes next. It cannot form the sub [16, 8] of length 2, but it can form the sub [5, 8]. This is where the algorithm comes in. First, we go through the rows of the list, looking at the "value" of the last pair in each row. If our element is greater than the last values of all rows, it can extend the existing longest sub(s), increasing their length by one. So the value 8 creates a new row in the list, because it is greater than the last values of all existing rows (i.e. > 5):
1: (16, 1)(5, 1)
2: (8, ?) <=== need to resolve how many longest subs ending by 8 can be obtained
Element 8 can continue 5, but it cannot continue 16. So we scan the previous row from its end, summing the "counts" of the pairs whose "value" is less than 8:
(16, 1)(5, 1)^ // sum = 0
(16, 1)^(5, 1) // sum = 1
^(16, 1)(5, 1) // value 16 >= 8: stop. count = sum = 1, so write 1 in pair next to 8
1: (16, 1)(5, 1)
2: (8, 1) <=== so far we have 1 sub of length 2 which ends by 8.
Why don't we store the value 8 among the subs of length 1 (the first row)? Because we need subs of the maximum possible length, and 8 can continue some previous subs. Every later number greater than 8 will also continue such a sub, so there is no need to keep 8 in a sub shorter than it can be.
Next is 6. We search the rows by their last "values":
1: (16, 1)(5, 1) <=== 5 < 6, go next
2: (8, 1)
1: (16, 1)(5, 1)
2: (8, 1 ) <=== 8 >= 6, so 6 should be put here
Found the room for 6, need to calculate a count:
take previous line
(16, 1)(5, 1)^ // sum = 0
(16, 1)^(5, 1) // 5 < 6: sum = 1
^(16, 1)(5, 1) // 16 >= 6: stop, write count = sum = 1
1: (16, 1)(5, 1)
2: (8, 1)(6, 1)
After processing 1:
1: (16, 1)(5, 1)(1, 1) <===
2: (8, 1)(6, 1)
After processing 10:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)
3: (10, 2) <=== count is 2 because both "values" 8 and 6 from previous row are less than 10, so we summarized their "counts": 1 + 1
After processing 5:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1) <===
3: (10, 2)
After processing 2:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1) <===
3: (10, 2)
After processing 15:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)
4: (15, 2) <===
After processing 3:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)(3, 1) <===
4: (15, 2)
After processing 2:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2) <===
3: (10, 2)(3, 1)
4: (15, 2)
If, while searching the rows by their last elements, we find an equal element, we calculate its "count" again from the previous row and add it to the existing "count".
After processing 4:
1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2)
3: (10, 2)(3, 1)
4: (15, 2)(4, 1) <===
After processing 1:
1: (16, 1)(5, 1)(1, 2) <===
2: (8, 1)(6, 1)(5, 1)(2, 2)
3: (10, 2)(3, 1)
4: (15, 2)(4, 1)
So what do we have after processing the whole initial sequence? Looking at the last row, we see that there are 3 longest subs, each consisting of 4 elements: 2 end with 15 and 1 ends with 4.
What about complexity?
On every iteration, when taking the next element from the initial sequence, we run two loops: first over the rows to find the place for the element, and second over the previous row to sum the counts. So for every element we make up to n iterations. In the worst cases, a seq in increasing order produces n rows with 1 pair each, and a seq in descending order produces 1 row with n pairs. Overall this gives O(n^2) complexity, which is not what we want.
First, it is obvious that in every intermediate state the rows are sorted in increasing order of their last "value". So instead of a brute-force loop, we can perform a binary search, whose complexity is O(log n).
Second, we don't need to sum the "counts" by looping through the row elements every time. We can keep running sums as each new pair is added to a row, like:
1: (16, 1)(5, 2) <=== instead of 1, put 1 + "count" of previous element in the row
So the second number no longer shows the count of longest subs ending with the given value. Instead it shows the cumulative count of all subs in that row that end with an element greater than or equal to the "value" of the pair.
Thus, "counts" are replaced by "sums". Instead of iterating over the elements of the previous row, we just perform a binary search (possible because the pairs in any row are always ordered by their "values"). The "sum" for the new pair is then: the "sum" of the last element in the previous row, minus the "sum" of the element to the left of the found position in the previous row, plus the "sum" of the previous element in the current row.
So when processing 4:
1: (16, 1)(5, 2)(1, 3)
2: (8, 1)(6, 2)(5, 3)(2, 5)
3: (10, 2)(3, 3)
4: (15, 2) <=== room for (4, ?)
search in row 3 by "values" < 4:
3: (10, 2)^(3, 3)
4 will be paired with 3 - 2 + 2 = 3, computed as ("sum" of the last pair of the previous row) - ("sum" of the pair to the left of the found position in the previous row) + ("sum" of the previous pair in the current row):
4: (15, 2)(4, 3)
In this case, final count of all longest subs is "sum" from the last pair of the last row of the list, i. e. 3, not 3 + 2.
So, by using binary search for both the row search and the sum search, we arrive at O(n log n) complexity.
As for memory: after processing the whole array we have at most n pairs, so memory consumption with dynamic arrays is O(n). Dynamic arrays or collections need some extra time to allocate and resize, but each insertion is still amortized O(1), because we never sort or rearrange elements during the process. So the O(n log n) estimate stands.
C++ implementation of the above logic:
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros for the program below.
// Container shortcuts.
#define pb push_back
#define pob pop_back
// Pair / integer type aliases. Macros expand at the point of use, so defining
// pll before ll is fine.
#define pll pair<ll, ll>
#define pii pair<int, int>
#define ll long long
#define ull unsigned long long
// Loop shortcuts: they assign to variables i, j, k, l that the caller must have
// declared; the *r variants count down to b inclusive.
// NOTE(review): fork(a,b) is a function-like macro named like POSIX fork(), so
// any later fork() call in this file would fail to expand.
#define fori(a,b) for(i=a;i<b;i++)
#define forj(a,b) for(j=a;j<b;j++)
#define fork(a,b) for(k=a;k<b;k++)
#define forl(a,b) for(l=a;l<b;l++)
#define forir(a,b) for(i=a;i>=b;i--)
#define forjr(a,b) for(j=a;j>=b;j--)
// Modulus constant; main below does not use it and prints the count unreduced.
#define mod 1000000007
// Fast-I/O statement. NOTE: this replaces every later `boost` token in the file.
#define boost std::ios::sync_with_stdio(false)
// Heterogeneous comparator for rows of (value, count) pairs sorted by
// *decreasing* value. Only pair.first is compared with the int key. With
// std::upper_bound(row, x, comp_pair_int_rev()) it yields the first pair whose
// value is strictly less than x.
// operator() is const, so the comparator can be called through a const object or
// reference, like standard function objects such as std::greater.
struct comp_pair_int_rev
{
    bool operator()(const std::pair<int, int> &a, const int &b) const
    {
        return a.first > b;
    }
    bool operator()(const int &a, const std::pair<int, int> &b) const
    {
        return a > b.first;
    }
};
// Heterogeneous comparator for (value, count) pairs sorted by *increasing*
// value. Only pair.first is compared with the int key. With
// std::lower_bound(tails, x, comp_pair_int()) it yields the first pair whose
// value is >= x.
// operator() is const, so the comparator can be called through a const object or
// reference, like standard function objects such as std::less.
struct comp_pair_int
{
    bool operator()(const std::pair<int, int> &a, const int &b) const
    {
        return a.first < b;
    }
    bool operator()(const int &a, const std::pair<int, int> &b) const
    {
        return a < b.first;
    }
};
int main()
{
int n,i,mx=0,p,q,r,t;
cin>>n;
int a[n];
vector<vector<pii > > v(100005);
vector<pii > v1(100005);
fori(0,n)
cin>>a[i];
v[1].pb({a[0], 1} );
v1[1]= {a[0], 1};
mx=1;
fori(1,n)
{
if(a[i]<=v1[1].first)
{
r=v1[1].second;
if(v1[1].first==a[i])
v[1].pob();
v1[1]= {a[i], r+1};
v[1].pb({a[i], r+1});
}
else if(a[i]>v1[mx].first)
{
q=upper_bound(v[mx].begin(), v[mx].end(), a[i], comp_pair_int_rev() )-v[mx].begin();
if(q==0)
{
r=v1[mx].second;
}
else
{
r=v1[mx].second-v[mx][q-1].second;
}
v1[++mx]= {a[i], r};
v[mx].pb({a[i], r});
}
else if(a[i]==v1[mx].first)
{
q=upper_bound(v[mx-1].begin(), v[mx-1].end(), a[i], comp_pair_int_rev() )-v[mx-1].begin();
if(q==0)
{
r=v1[mx-1].second;
}
else
{
r=v1[mx-1].second-v[mx-1][q-1].second;
}
p=v1[mx].second;
v1[mx]= {a[i], p+r};
v[mx].pob();
v[mx].pb({a[i], p+r});
}
else
{
p=lower_bound(v1.begin()+1, v1.begin()+mx+1, a[i], comp_pair_int() )-v1.begin();
t=v1[p].second;
if(v1[p].first==a[i])
{
v[p].pob();
}
q=upper_bound(v[p-1].begin(), v[p-1].end(), a[i], comp_pair_int_rev() )-v[p-1].begin();
if(q==0)
{
r=v1[p-1].second;
}
else
{
r=v1[p-1].second-v[p-1][q-1].second;
}
v1[p]= {a[i], t+r};
v[p].pb({a[i], t+r});
}
}
cout<<v1[mx].second;
return 0;
}
That said, I completely agree with Alex that this can be done very easily using a segment tree.
Here is the logic for finding the length of the LIS using a segment tree in O(N log N):
https://www.quora.com/What-is-the-approach-to-find-the-length-of-the-strictly-increasing-longest-subsequence
Here is an approach that finds the number of LIS but takes O(N^2) time:
https://codeforces.com/blog/entry/48677
We use a segment tree (as in the first link) to optimize the approach given in the second link.
Here is the logic:
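First, sort the array in ascending order while keeping each element's original index. Among equal values, put the larger original index first, so that an element never extends an equal one. Initialise a segment tree over the original positions with zeroes. The segment tree should answer two things for a given range (use a pair for this):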
a. max of first.
b. the sum of second over the entries whose first equals that maximum.
iterate through sorted array.
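Let j be the original index of the current element. Query the range [0, j-1], then update position j with (first + 1, second) of the result; if the query returns (0, 0), update it with (1, 1) instead. After the last element, the answer is the second component stored at the root.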
Here is my code in C++:
#include<bits/stdc++.h>
// Competitive-programming shorthand macros for the segment-tree program below.
// Iterate over a container. NOTE(review): `typeof` is a GNU extension; standard
// C++ would use decltype.
#define tr(container, it) for(typeof(container.begin()) it = container.begin(); it != container.end(); it++)
#define ll long long
#define pb push_back
// Turns every later `endl` into a plain '\n' character (no stream flush).
#define endl '\n'
// In this program pii is a pair of long long (ll int == long long int).
#define pii pair<ll int,ll int>
#define vi vector<ll int>
#define all(a) (a).begin(),(a).end()
// Every later token F / S is rewritten to first / second.
#define F first
#define S second
#define sz(x) (ll int)x.size()
// Modulus constant; main below prints the count without reducing by it.
#define hell 1000000007
#define rep(i,a,b) for(ll int i=a;i<b;i++)
#define lbnd lower_bound
#define ubnd upper_bound
#define bs binary_search
#define mp make_pair
using namespace std;
// Leaf capacity used to size the global segment tree (4*N nodes).
#define N 100005
// Larger of two long longs; returns b when the values are equal.
// Non-template, so calls such as max(ll, int) resolve here instead of std::max.
long long max(long long a, long long b)
{
    return b < a ? a : b;
}
// Globals for the segment-tree program.
// n, l and r are not read through these globals here: query() has its own l/r
// parameters and main() declares a local n that shadows this one.
ll n,l,r;
// Segment-tree nodes, 1-based (children of node c are 2c and 2c+1). Each node
// holds (F, S): the largest F among the leaves of its range, and the sum of S
// over the leaves attaining it. Pre-sized to 4*N nodes, all (0, 0).
vector< pii > seg(4*N);
pii query(ll cur,ll st,ll end,ll l,ll r)
{
if(l<=st&&r>=end)
return seg[cur];
if(r<st||l>end)
return mp(0,0); /* 2-change here */
ll mid=(st+end)>>1;
pii ans1=query(2*cur,st,mid,l,r);
pii ans2=query(2*cur+1,mid+1,end,l,r);
if(ans1.F>ans2.F)
return ans1;
if(ans2.F>ans1.F)
return ans2;
return make_pair(ans1.F,ans2.S+ans1.S); /* 3-change here */
}
void update(ll cur,ll st,ll end,ll pos,ll upd1, ll upd2)
{
if(st==end)
{
// a[pos]=upd; /* 4-change here */
seg[cur].F=upd1;
seg[cur].S=upd2; /* 5-change here */
return;
}
ll mid=(st+end)>>1;
if(st<=pos&&pos<=mid)
update(2*cur,st,mid,pos,upd1,upd2);
else
update(2*cur+1,mid+1,end,pos,upd1,upd2);
seg[cur].F=max(seg[2*cur].F,seg[2*cur+1].F);
if(seg[2*cur].F==seg[2*cur+1].F)
seg[cur].S = seg[2*cur].S+seg[2*cur+1].S;
else
{
if(seg[2*cur].F>seg[2*cur+1].F)
seg[cur].S = seg[2*cur].S;
else
seg[cur].S = seg[2*cur+1].S;
/* 6-change here */
}
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int TESTS=1;
// cin>>TESTS;
while(TESTS--)
{
int n ;
cin >> n;
vector< pii > arr(n);
rep(i,0,n)
{
cin >> arr[i].F;
arr[i].S = -i;
}
sort(all(arr));
update(1,0,n-1,-arr[0].S,1,1);
rep(i,1,n)
{
pii x = query(1,0,n-1,-1,-arr[i].S - 1 );
update(1,0,n-1,-arr[i].S,x.F+1,max(x.S,1));
}
cout<<seg[1].S;//answer
}
return 0;
}