如何在两个排序数组的并集中找到 kth 最小元素?

这是一个家庭作业问题,二进制搜索已经被引入:

Given two arrays, respectively N and M elements in ascending order, not necessarily unique:
什么是时间有效的算法,以找到 K最小的元素在两个数组的联合?

They say it takes O(logN + logM) where N and M are the arrays lengths.

Let's name the arrays a and b. Obviously we can ignore all a[i] and b[i] where i > k.
First let's compare a[k/2] and b[k/2]. Let b[k/2] > a[k/2]. Therefore we can discard also all b[i], where i > k/2.

现在我们有了所有的 a[i],其中 i < k 和所有的 b[i],其中 i < k/2可以找到答案。

下一步是什么?

99109 次浏览

没问题,继续,小心索引。

To simplify a bit I'll assume that N and M are > k, so the complexity here is O(log k), which is O(log N + log M).

伪代码:

i = k/2
j = k - i
step = k/4
while step > 0
if a[i-1] > b[j-1]
i -= step
j += step
else
i += step
j -= step
step /= 2


if a[i-1] > b[j-1]
return a[i-1]
else
return b[j-1]

在演示中,您可以使用循环不变式 i + j = k,但我不会完成您所有的功课:)

以下是我根据 Jules Olleon 的解决方案编写的代码:

int getNth(vector<int>& v1, vector<int>& v2, int n)
{
int step = n / 4;


int i1 = n / 2;
int i2 = n - i1;


while(!(v2[i2] >= v1[i1 - 1] && v1[i1] > v2[i2 - 1]))
{
if (v1[i1 - 1] >= v2[i2 - 1])
{
i1 -= step;
i2 += step;
}
else
{
i1 += step;
i2 -= step;
}


step /= 2;
if (!step) step = 1;
}


if (v1[i1 - 1] >= v2[i2 - 1])
return v1[i1 - 1];
else
return v2[i2 - 1];
}


int main()
{
int a1[] = {1,2,3,4,5,6,7,8,9};
int a2[] = {4,6,8,10,12};


//int a1[] = {1,2,3,4,5,6,7,8,9};
//int a2[] = {4,6,8,10,12};


//int a1[] = {1,7,9,10,30};
//int a2[] = {3,5,8,11};
vector<int> v1(a1, a1+9);
vector<int> v2(a2, a2+5);




cout << getNth(v1, v2, 5);
return 0;
}

我希望我没有回答你们的作业,因为这个问题已经问了一年多了。下面是一个需要花费 Log (len (a) + len (b))时间的尾递归解决方案。

假设: 输入是正确的,也就是说,K[0,len (a) + len (b)]范围内。

基本情况:

  • 如果其中一个数组的长度为0,则答案是第二个数组的 Kth 元素。

削减步骤:

  • + B中值指数小于 K:
    • If mid element of a is greater than mid element of B, we can ignore the first half of B, adjust K.
    • 否则,忽略 的前半部分,调整 K
  • 如果 K小于 b的中间指数之和:
    • 如果 的中间元素大于 B的中间元素,我们可以安全地忽略 的后半部分。
    • 否则,我们可以忽略 b的后半部分。

密码:

def kthlargest(arr1, arr2, k):
if len(arr1) == 0:
return arr2[k]
elif len(arr2) == 0:
return arr1[k]


mida1 = len(arr1) // 2  # integer division
mida2 = len(arr2) // 2
if mida1 + mida2 < k:
if arr1[mida1] > arr2[mida2]:
return kthlargest(arr1, arr2[mida2+1:], k - mida2 - 1)
else:
return kthlargest(arr1[mida1+1:], arr2, k - mida1 - 1)
else:
if arr1[mida1] > arr2[mida2]:
return kthlargest(arr1[:mida1], arr2, k)
else:
return kthlargest(arr1, arr2[:mida2], k)

请注意,我的解决方案是在每次调用中创建较小数组的新副本,这可以很容易地通过传递原始数组的开始和结束索引来消除。

检查这个代码。

import math
def findkthsmallest():


A=[1,5,10,22,30,35,75,125,150,175,200]
B=[15,16,20,22,25,30,100,155,160,170]
lM=0
lN=0
hM=len(A)-1
hN=len(B)-1
k=17


while True:
if k==1:
return min(A[lM],B[lN])




cM=hM-lM+1
cN=hN-lN+1
tmp = cM/float(cM+cN)
iM=int(math.ceil(tmp*k))
iN=k-iM
iM=lM+iM-1
iN=lN+iN-1
if A[iM] >= B[iN]:
if iN == hN or A[iM] < B[iN+1]:
return A[iM]
else:
k = k - (iN-lN+1)
lN=iN+1
hM=iM-1
if B[iN] >= A[iM]:
if iM == hM or B[iN] < A[iM+1]:
return B[iN]
else:
k = k - (iM-lM+1)
lM=iM+1
hN=iN-1
if hM < lM:
return B[lN+k-1]
if hN < lN:
return A[lM+k-1]


if __name__ == '__main__':
print findkthsmallest();

Here's a C++ iterative version of @ lambdapilgrim 的解决方案 (see the explanation of the algorithm there):

#include <cassert>
#include <iterator>


template<class RandomAccessIterator, class Compare>
typename std::iterator_traits<RandomAccessIterator>::value_type
nsmallest_iter(RandomAccessIterator firsta, RandomAccessIterator lasta,
RandomAccessIterator firstb, RandomAccessIterator lastb,
size_t n,
Compare less) {
assert(issorted(firsta, lasta, less) && issorted(firstb, lastb, less));
for ( ; ; ) {
assert(n < static_cast<size_t>((lasta - firsta) + (lastb - firstb)));
if (firsta == lasta) return *(firstb + n);
if (firstb == lastb) return *(firsta + n);


size_t mida = (lasta - firsta) / 2;
size_t midb = (lastb - firstb) / 2;
if ((mida + midb) < n) {
if (less(*(firstb + midb), *(firsta + mida))) {
firstb += (midb + 1);
n -= (midb + 1);
}
else {
firsta += (mida + 1);
n -= (mida + 1);
}
}
else {
if (less(*(firstb + midb), *(firsta + mida)))
lasta = (firsta + mida);
else
lastb = (firstb + midb);
}
}
}

它适用于所有 0 <= n < (size(a) + size(b))索引,并且具有 O(log(size(a)) + log(size(b)))复杂性。

例子

#include <functional> // greater<>
#include <iostream>


#define SIZE(a) (sizeof(a) / sizeof(*a))


int main() {
int a[] = {5,4,3};
int b[] = {2,1,0};
int k = 1; // find minimum value, the 1st smallest value in a,b


int i = k - 1; // convert to zero-based indexing
int v = nsmallest_iter(a, a + SIZE(a), b, b + SIZE(b),
SIZE(a)+SIZE(b)-1-i, std::greater<int>());
std::cout << v << std::endl; // -> 0
return v;
}

我尝试在2个排序数组和 n 个排序数组中获得第一个 k 数、 kth 数:

// require() is recognizable by node.js but not by browser;
// for running/debugging in browser, put utils.js and this file in <script> elements,
if (typeof require === "function") require("./utils.js");


// Find K largest numbers in two sorted arrays.
function k_largest(a, b, c, k) {
var sa = a.length;
var sb = b.length;
if (sa + sb < k) return -1;
var i = 0;
var j = sa - 1;
var m = sb - 1;
while (i < k && j >= 0 && m >= 0) {
if (a[j] > b[m]) {
c[i] = a[j];
i++;
j--;
} else {
c[i] = b[m];
i++;
m--;
}
}
debug.log(2, "i: "+ i + ", j: " + j + ", m: " + m);
if (i === k) {
return 0;
} else if (j < 0) {
while (i < k) {
c[i++] = b[m--];
}
} else {
while (i < k) c[i++] = a[j--];
}
return 0;
}


// find k-th largest or smallest number in 2 sorted arrays.
function kth(a, b, kd, dir){
sa = a.length; sb = b.length;
if (kd<1 || sa+sb < kd){
throw "Mission Impossible! I quit!";
}


var k;
//finding the kd_th largest == finding the smallest k_th;
if (dir === 1){ k = kd;
} else if (dir === -1){ k = sa + sb - kd + 1;}
else throw "Direction has to be 1 (smallest) or -1 (largest).";


return find_kth(a, b, k, sa-1, 0, sb-1, 0);
}


// find k-th smallest number in 2 sorted arrays;
function find_kth(c, d, k, cmax, cmin, dmax, dmin){


sc = cmax-cmin+1; sd = dmax-dmin+1; k0 = k; cmin0 = cmin; dmin0 = dmin;
debug.log(2, "=k: " + k +", sc: " + sc + ", cmax: " + cmax +", cmin: " + cmin + ", sd: " + sd +", dmax: " + dmax + ", dmin: " + dmin);


c_comp = k0-sc;
if (c_comp <= 0){
cmax = cmin0 + k0-1;
} else {
dmin = dmin0 + c_comp-1;
k -= c_comp-1;
}


d_comp = k0-sd;
if (d_comp <= 0){
dmax = dmin0 + k0-1;
} else {
cmin = cmin0 + d_comp-1;
k -= d_comp-1;
}
sc = cmax-cmin+1; sd = dmax-dmin+1;


debug.log(2, "#k: " + k +", sc: " + sc + ", cmax: " + cmax +", cmin: " + cmin + ", sd: " + sd +", dmax: " + dmax + ", dmin: " + dmin + ", c_comp: " + c_comp + ", d_comp: " + d_comp);


if (k===1) return (c[cmin]<d[dmin] ? c[cmin] : d[dmin]);
if (k === sc+sd) return (c[cmax]>d[dmax] ? c[cmax] : d[dmax]);


m = Math.floor((cmax+cmin)/2);
n = Math.floor((dmax+dmin)/2);


debug.log(2, "m: " + m + ", n: "+n+", c[m]: "+c[m]+", d[n]: "+d[n]);


if (c[m]<d[n]){
if (m === cmax){ // only 1 element in c;
return d[dmin+k-1];
}


k_next = k-(m-cmin+1);
return find_kth(c, d, k_next, cmax, m+1, dmax, dmin);
} else {
if (n === dmax){
return c[cmin+k-1];
}


k_next = k-(n-dmin+1);
return find_kth(c, d, k_next, cmax, cmin, dmax, n+1);
}
}


function traverse_at(a, ae, h, l, k, at, worker, wp){
var n = ae ? ae.length : 0;
var get_node;
switch (at){
case "k": get_node = function(idx){
var node = {};
var pos = l[idx] + Math.floor(k/n) - 1;
if (pos<l[idx]){ node.pos = l[idx]; }
else if (pos > h[idx]){ node.pos = h[idx];}
else{ node.pos = pos; }


node.idx = idx;
node.val = a[idx][node.pos];
debug.log(6, "pos: "+pos+"\nnode =");
debug.log(6, node);
return node;
};
break;
case "l": get_node = function(idx){
debug.log(6, "a["+idx+"][l["+idx+"]]: "+a[idx][l[idx]]);
return a[idx][l[idx]];
};
break;
case "h": get_node = function(idx){
debug.log(6, "a["+idx+"][h["+idx+"]]: "+a[idx][h[idx]]);
return a[idx][h[idx]];
};
break;
case "s": get_node = function(idx){
debug.log(6, "h["+idx+"]-l["+idx+"]+1: "+(h[idx] - l[idx] + 1));
return h[idx] - l[idx] + 1;
};
break;
default: get_node = function(){
debug.log(1, "!!! Exception: get_node() returns null.");
return null;
};
break;
}


worker.init();


debug.log(6, "--* traverse_at() *--");


var i;
if (!wp){
for (i=0; i<n; i++){
worker.work(get_node(ae[i]));
}
} else {
for (i=0; i<n; i++){
worker.work(get_node(ae[i]), wp);
}
}


return worker.getResult();
}


sumKeeper = function(){
var res = 0;
return {
init     : function(){ res = 0;},
getResult: function(){
debug.log(5, "@@ sumKeeper.getResult: returning: "+res);
return res;
},
work     : function(node){ if (node!==null) res += node;}
};
}();


maxPicker = function(){
var res = null;
return {
init     : function(){ res = null;},
getResult: function(){
debug.log(5, "@@ maxPicker.getResult: returning: "+res);
return res;
},
work     : function(node){
if (res === null){ res = node;}
else if (node!==null && node > res){ res = node;}
}
};
}();


minPicker = function(){
var res = null;
return {
init     : function(){ res = null;},
getResult: function(){
debug.log(5, "@@ minPicker.getResult: returning: ");
debug.log(5, res);
return res;
},
work     : function(node){
if (res === null && node !== null){ res = node;}
else if (node!==null &&
node.val !==undefined &&
node.val < res.val){ res = node; }
else if (node!==null && node < res){ res = node;}
}
};
}();


// find k-th smallest number in n sorted arrays;
// need to consider the case where some of the subarrays are taken out of the selection;
function kth_n(a, ae, k, h, l){
var n = ae.length;
debug.log(2, "------**  kth_n()  **-------");
debug.log(2, "n: " +n+", k: " + k);
debug.log(2, "ae: ["+ae+"],  len: "+ae.length);
debug.log(2, "h: [" + h + "]");
debug.log(2, "l: [" + l + "]");


for (var i=0; i<n; i++){
if (h[ae[i]]-l[ae[i]]+1>k) h[ae[i]]=l[ae[i]]+k-1;
}
debug.log(3, "--after reduction --");
debug.log(3, "h: [" + h + "]");
debug.log(3, "l: [" + l + "]");


if (n === 1)
return a[ae[0]][k-1];
if (k === 1)
return traverse_at(a, ae, h, l, k, "l", minPicker);
if (k === traverse_at(a, ae, h, l, k, "s", sumKeeper))
return traverse_at(a, ae, h, l, k, "h", maxPicker);


var kn = traverse_at(a, ae, h, l, k, "k", minPicker);
debug.log(3, "kn: ");
debug.log(3, kn);


var idx = kn.idx;
debug.log(3, "last: k: "+k+", l["+kn.idx+"]: "+l[idx]);
k -= kn.pos - l[idx] + 1;
l[idx] = kn.pos + 1;
debug.log(3, "next: "+"k: "+k+", l["+kn.idx+"]: "+l[idx]);
if (h[idx]<l[idx]){ // all elements in a[idx] selected;
//remove a[idx] from the arrays.
debug.log(4, "All elements selected in a["+idx+"].");
debug.log(5, "last ae: ["+ae+"]");
ae.splice(ae.indexOf(idx), 1);
h[idx] = l[idx] = "_"; // For display purpose only.
debug.log(5, "next ae: ["+ae+"]");
}


return kth_n(a, ae, k, h, l);
}


function find_kth_in_arrays(a, k){


if (!a || a.length<1 || k<1) throw "Mission Impossible!";


var ae=[], h=[], l=[], n=0, s, ts=0;
for (var i=0; i<a.length; i++){
s = a[i] && a[i].length;
if (s>0){
ae.push(i); h.push(s-1); l.push(0);
ts+=s;
}
}


if (k>ts) throw "Too few elements to choose from!";


return kth_n(a, ae, k, h, l);
}


/////////////////////////////////////////////////////
// tests
// To show everything: use 6.
debug.setLevel(1);


var a = [2, 3, 5, 7, 89, 223, 225, 667];
var b = [323, 555, 655, 673];
//var b = [99];
var c = [];


debug.log(1, "a = (len: " + a.length + ")");
debug.log(1, a);
debug.log(1, "b = (len: " + b.length + ")");
debug.log(1, b);


for (var k=1; k<a.length+b.length+1; k++){
debug.log(1, "================== k: " + k + "=====================");


if (k_largest(a, b, c, k) === 0 ){
debug.log(1, "c = (len: "+c.length+")");
debug.log(1, c);
}


try{
result = kth(a, b, k, -1);
debug.log(1, "===== The " + k + "-th largest number: " + result);
} catch (e) {
debug.log(0, "Error message from kth(): " + e);
}
debug.log("==================================================");
}


debug.log(1, "################# Now for the n sorted arrays ######################");
debug.log(1, "####################################################################");


x = [[1, 3, 5, 7, 9],
[-2, 4, 6, 8, 10, 12],
[8, 20, 33, 212, 310, 311, 623],
[8],
[0, 100, 700],
[300],
[],
null];


debug.log(1, "x = (len: "+x.length+")");
debug.log(1, x);


for (var i=0, num=0; i<x.length; i++){
if (x[i]!== null) num += x[i].length;
}
debug.log(1, "totoal number of elements: "+num);


// to test k in specific ranges:
var start = 0, end = 25;
for (k=start; k<end; k++){
debug.log(1, "=========================== k: " + k + "===========================");


try{
result = find_kth_in_arrays(x, k);
debug.log(1, "====== The " + k + "-th smallest number: " + result);
} catch (e) {
debug.log(1, "Error message from find_kth_in_arrays: " + e);
}
debug.log(1, "=================================================================");
}
debug.log(1, "x = (len: "+x.length+")");
debug.log(1, x);
debug.log(1, "totoal number of elements: "+num);

The complete code with debug utils can be found at: https://github.com/brainclone/teasers/tree/master/kth

这是我在 C 语言中的实现,你可以参考@Jules Olléon 对算法的解释: 算法背后的思想是,我们维护 i + j = k,并找到这样的 i 和 j,以便 a [ i-1] < b [ j-1] < a [ i ](或反过来)。既然‘ a’中的 i 元素小于 b [ j-1] ,‘ b’中的 j-1元素小于 b [ j-1] ,那么 b [ j-1]就是 i + j-1 + 1 = kth 最小元素。为了找到这样的 i,j 算法对数组进行二分搜索。

int find_k(int A[], int m, int B[], int n, int k) {
if (m <= 0 )return B[k-1];
else if (n <= 0) return A[k-1];
int i =  ( m/double (m + n))  * (k-1);
if (i < m-1 && i<k-1) ++i;
int j = k - 1 - i;


int Ai_1 = (i > 0) ? A[i-1] : INT_MIN, Ai = (i<m)?A[i]:INT_MAX;
int Bj_1 = (j > 0) ? B[j-1] : INT_MIN, Bj = (j<n)?B[j]:INT_MAX;
if (Ai >= Bj_1 && Ai <= Bj) {
return Ai;
} else if (Bj >= Ai_1 && Bj <= Ai) {
return Bj;
}
if (Ai < Bj_1) { // the answer can't be within A[0,...,i]
return find_k(A+i+1, m-i-1, B, n, j);
} else { // the answer can't be within A[0,...,i]
return find_k(A, m, B+j+1, n-j-1, i);
}
}

这是我的解决办法。C + + 代码打印出 kth 最小值以及使用循环获得 kth 最小值的迭代次数,我认为循环的次序是 log (k)。然而,代码要求 k 小于第一个数组的长度,这是一个限制。

#include <iostream>
#include <vector>
#include<math.h>
using namespace std;


template<typename comparable>
comparable kthSmallest(vector<comparable> & a, vector<comparable> & b, int k){


int idx1; // Index in the first array a
int idx2; // Index in the second array b
comparable maxVal, minValPlus;
float iter = k;
int numIterations = 0;


if(k > a.size()){ // Checks if k is larger than the size of first array
cout << " k is larger than the first array" << endl;
return -1;
}
else{ // If all conditions are satisfied, initialize the indexes
idx1 = k - 1;
idx2 = -1;
}


for ( ; ; ){
numIterations ++;
if(idx2 == -1 || b[idx2] <= a[idx1] ){
maxVal = a[idx1];
minValPlus = b[idx2 + 1];
idx1 = idx1 - ceil(iter/2); // Binary search
idx2 = k - idx1 - 2; // Ensures sum of indices  = k - 2
}
else{
maxVal = b[idx2];
minValPlus = a[idx1 + 1];
idx2 = idx2 - ceil(iter/2); // Binary search
idx1 = k - idx2 - 2; // Ensures sum of indices  = k - 2
}
if(minValPlus >= maxVal){ // Check if kth smallest value has been found
cout << "The number of iterations to find the " << k << "(th) smallest value is    " << numIterations << endl;
return maxVal;


}
else
iter/=2; // Reduce search space of binary search
}
}


int main(){
//Test Cases
vector<int> a = {2, 4, 9, 15, 22, 34, 45, 55, 62, 67, 78, 85};
vector<int> b = {1, 3, 6, 8, 11, 13, 15, 20, 56, 67, 89};
// Input k < a.size()
int kthSmallestVal;
for (int k = 1; k <= a.size() ; k++){
kthSmallestVal = kthSmallest<int>( a ,b ,k );
cout << k <<" (th) smallest Value is " << kthSmallestVal << endl << endl << endl;
}
}

在 C # 代码下面找到两个排序数组合并中的第 k 个最小元素。时间复杂度: O (logk)

        public static int findKthSmallestElement1(int[] A, int startA, int endA, int[] B, int startB, int endB, int k)
{
int n = endA - startA;
int m = endB - startB;


if (n <= 0)
return B[startB + k - 1];
if (m <= 0)
return A[startA + k - 1];
if (k == 1)
return A[startA] < B[startB] ? A[startA] : B[startB];


int midA = (startA + endA) / 2;
int midB = (startB + endB) / 2;


if (A[midA] <= B[midB])
{
if (n / 2 + m / 2 + 1 >= k)
return findKthSmallestElement1(A, startA, endA, B, startB, midB, k);
else
return findKthSmallestElement1(A, midA + 1, endA, B, startB, endB, k - n / 2 - 1);
}
else
{
if (n / 2 + m / 2 + 1 >= k)
return findKthSmallestElement1(A, startA, midA, B, startB, endB, k);
else
return findKthSmallestElement1(A, startA, endA, B, midB + 1, endB, k - m / 2 - 1);


}
}

许多人回答这个“两个排序数组中的第 k 个最小元素”的问题,但通常只有一般的想法,没有一个明确的工作代码或边界条件分析。

在这里,我想用我的正确的 Java 代码来帮助一些新手理解它。A1A2是两个排序的升序数组,分别以 size1size2作为长度。我们需要从这两个数组的并集中找到 k 个最小的元素。这里我们合理地假设 (k > 0 && k <= size1 + size2),这意味着 A1A2不能同时为空。

首先,让我们用一个缓慢的 O (k)算法来处理这个问题。该方法是比较两个数组的第一个元素 A1[0]A2[0]。把小的那个,说 A1[0]放到我们的口袋里。然后比较 A1[1]A2[0],以此类推。重复这个动作,直到我们的口袋达到 k元素。非常重要: 在第一步,我们只能承诺在我们的口袋 A1[0]。我们不能包括或排除 A2[0]! ! !

The following O(k) code gives you one element before the correct answer. Here I use it to show my idea, and analysis boundary condition. I have correct code after this one:

private E kthSmallestSlowWithFault(int k) {
int size1 = A1.length, size2 = A2.length;


int index1 = 0, index2 = 0;
// base case, k == 1
if (k == 1) {
if (size1 == 0) {
return A2[index2];
} else if (size2 == 0) {
return A1[index1];
} else if (A1[index1].compareTo(A2[index2]) < 0) {
return A1[index1];
} else {
return A2[index2];
}
}


/* in the next loop, we always assume there is one next element to compare with, so we can
* commit to the smaller one. What if the last element is the kth one?
*/
if (k == size1 + size2) {
if (size1 == 0) {
return A2[size2 - 1];
} else if (size2 == 0) {
return A1[size1 - 1];
} else if (A1[size1 - 1].compareTo(A2[size2 - 1]) < 0) {
return A1[size1 - 1];
} else {
return A2[size2 - 1];
}
}


/*
* only when k > 1, below loop will execute. In each loop, we commit to one element, till we
* reach (index1 + index2 == k - 1) case. But the answer is not correct, always one element
* ahead, because we didn't merge base case function into this loop yet.
*/
int lastElementFromArray = 0;
while (index1 + index2 < k - 1) {
if (A1[index1].compareTo(A2[index2]) < 0) {
index1++;
lastElementFromArray = 1;
// commit to one element from array A1, but that element is at (index1 - 1)!!!
} else {
index2++;
lastElementFromArray = 2;
}
}
if (lastElementFromArray == 1) {
return A1[index1 - 1];
} else {
return A2[index2 - 1];
}
}

最强大的思想是,在每个循环中,我们总是使用基本案例方法。在提交到当前最小的元素之后,我们离目标更近了一步: k-th 最小元素。永远不要跳到中间,让自己困惑和迷失!

通过观察上面的代码基本情况 k == 1, k == size1+size2,并结合 A1A2不能都是空的。我们可以把下面的逻辑变成更简洁的风格。

下面是一个缓慢但正确的工作代码:

private E kthSmallestSlow(int k) {
// System.out.println("this is an O(k) speed algorithm, very concise");
int size1 = A1.length, size2 = A2.length;


int index1 = 0, index2 = 0;
while (index1 + index2 < k - 1) {
if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
index1++; // here we commit to original index1 element, not the increment one!!!
} else {
index2++;
}
}
// below is the (index1 + index2 == k - 1) base case
// also eliminate the risk of referring to an element outside of index boundary
if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
return A1[index1];
} else {
return A2[index2];
}
}

Now we can try a faster algorithm runs at O(log k). Similarly, compare A1[k/2] with A2[k/2]; if A1[k/2] is smaller, then all the elements from A1[0] to A1[k/2] should be in our pocket. The idea is to not just commit to one element in each loop; the first step contains k/2 elements. Again, we can NOT include or exclude A2[0] to A2[k/2] anyway. So in the first step, we can't go more than k/2 elements. For the second step, we can't go more than k/4 elements...

在每一步之后,我们更加接近 k 元素。同时,每一步都变得越来越小,直到我们到达 (step == 1),也就是 (k-1 == index1+index2)。然后我们可以再次参考简单而强大的基本情况。

下面是正确的工作代码:

private E kthSmallestFast(int k) {
// System.out.println("this is an O(log k) speed algorithm with meaningful variables name");
int size1 = A1.length, size2 = A2.length;


int index1 = 0, index2 = 0, step = 0;
while (index1 + index2 < k - 1) {
step = (k - index1 - index2) / 2;
int step1 = index1 + step;
int step2 = index2 + step;
if (size1 > step1 - 1
&& (size2 <= step2 - 1 || A1[step1 - 1].compareTo(A2[step2 - 1]) < 0)) {
index1 = step1; // commit to element at index = step1 - 1
} else {
index2 = step2;
}
}
// the base case of (index1 + index2 == k - 1)
if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
return A1[index1];
} else {
return A2[index2];
}
}

有些人可能会担心,如果 (index1+index2)跳过 k-1?我们会错过基本情况 (k-1 == index1+index2)吗?这不可能。你可以把0.5 + 0.25 + 0.125... 加起来,你永远不会超过1。

当然,将上面的代码转换成递归算法是非常容易的:

private E kthSmallestFastRecur(int k, int index1, int index2, int size1, int size2) {
// System.out.println("this is an O(log k) speed algorithm with meaningful variables name");


// the base case of (index1 + index2 == k - 1)
if (index1 + index2 == k - 1) {
if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
return A1[index1];
} else {
return A2[index2];
}
}


int step = (k - index1 - index2) / 2;
int step1 = index1 + step;
int step2 = index2 + step;
if (size1 > step1 - 1 && (size2 <= step2 - 1 || A1[step1 - 1].compareTo(A2[step2 - 1]) < 0)) {
index1 = step1;
} else {
index2 = step2;
}
return kthSmallestFastRecur(k, index1, index2, size1, size2);
}

Hope the above analysis and Java code could help you to understand. But never copy my code as your homework! Cheers ;)

上面提供的第一个伪代码对许多值都不起作用, 这里有两个数组。 int[] a = { 1, 5, 6, 8, 9, 11, 15, 17, 19 }; Int [] b = {4,7,8,13,15,18,20,24,26} ;

It did not work for k=3 and k=9 in it. I have another solution. It is given below.

private static void traverse(int pt, int len) {
int temp = 0;


if (len == 1) {
int val = 0;
while (k - (pt + 1) - 1 > -1 && M[pt] < N[k - (pt + 1) - 1]) {


if (val == 0)
val = M[pt] < N[k - (pt + 1) - 1] ? N[k - (pt + 1) - 1]
: M[pt];
else {
int t = M[pt] < N[k - (pt + 1) - 1] ? N[k - (pt + 1) - 1]
: M[pt];
val = val < t ? val : t;


}


++pt;
}


if (val == 0)
val = M[pt] < N[k - (pt + 1) - 1] ? N[k - (pt + 1) - 1] : M[pt];


System.out.println(val);
return;
}


temp = len / 2;


if (M[pt + temp - 1] < N[k - (pt + temp) - 1]) {
traverse(pt + temp, temp);


} else {
traverse(pt, temp);
}


}

但是... 对于 k = 5也不起作用。有一个 k 的偶数/奇数捕获,它不会让它变得简单。

public class KthSmallestInSortedArray {


public static void main(String[] args) {
int a1[] = {2, 3, 10, 11, 43, 56},
a2[] = {120, 13, 14, 24, 34, 36},
k = 4;


System.out.println(findKthElement(a1, a2, k));


}


private static int findKthElement(int a1[], int a2[], int k) {


/** Checking k must less than sum of length of both array **/
if (a1.length + a2.length < k) {
throw new IllegalArgumentException();
}


/** K must be greater than zero **/
if (k <= 0) {
throw new IllegalArgumentException();
}


/**
* Finding begin, l and end such that
* begin <= l < end
* a1[0].....a1[l-1] and
* a2[0]....a2[k-l-1] are the smallest k numbers
*/
int begin = Math.max(0, k - a2.length);
int end = Math.min(a1.length, k);


while (begin < end) {
int l = begin + (end - begin) / 2;


/** Can we include a1[l] in the k smallest numbers */
if ((l < a1.length) &&
(k - l > 0) &&
(a1[l] < a2[k - l - 1])) {


begin = l + 1;


} else if ((l > 0) &&
(k - l < a2.length) &&
(a1[l - 1] > a2[k - 1])) {


/**
* This is the case where we can discard
* a[l-1] from the set of k smallest numbers
*/
end = l;


} else {


/**
* We found our answer since both inequalities were
* false
*/
begin = l;
break;
}
}


if (begin == 0) {
return a2[k - 1];
} else if (begin == k) {
return a1[k - 1];
} else {
return Math.max(a1[begin - 1], a2[k - begin - 1]);
}
}
}

Here is mine solution in java . Will try to further optimize it

  public class FindKLargestTwoSortedArray {


public static void main(String[] args) {
int[] arr1 = { 10, 20, 40, 80 };
int[] arr2 = { 15, 35, 50, 75 };


FindKLargestTwoSortedArray(arr1, 0, arr1.length - 1, arr2, 0,
arr2.length - 1, 6);
}




public static void FindKLargestTwoSortedArray(int[] arr1, int start1,
int end1, int[] arr2, int start2, int end2, int k) {


if ((start1 <= end1 && start1 >= 0 && end1 < arr1.length)
&& (start2 <= end2 && start2 >= 0 && end2 < arr2.length)) {


int midIndex1 = (start1 + (k - 1) / 2);
midIndex1 = midIndex1 >= arr1.length ? arr1.length - 1 : midIndex1;
int midIndex2 = (start2 + (k - 1) / 2);
midIndex2 = midIndex2 >= arr2.length ? arr2.length - 1 : midIndex2;




if (arr1[midIndex1] == arr2[midIndex2]) {
System.out.println("element is " + arr1[midIndex1]);
} else if (arr1[midIndex1] < arr2[midIndex2]) {


if (k == 1) {
System.out.println("element is " + arr1[midIndex1]);
return;
} else if (k == 2) {
System.out.println("element is " + arr2[midIndex2]);
return;
}else if (midIndex1 == arr1.length-1 || midIndex2 == arr2.length-1 ) {
if(k==(arr1.length+arr2.length)){
System.out.println("element is " + arr2[midIndex2]);
return;
}else if(k==(arr1.length+arr2.length)-1){
System.out.println("element is " + arr1[midIndex1]);
return;
}


}


int remainingElementToSearch = k - (midIndex1-start1);
FindKLargestTwoSortedArray(
arr1,
midIndex1,
(midIndex1 + remainingElementToSearch) >= arr1.length ? arr1.length-1
: (midIndex1 + remainingElementToSearch), arr2,
start2, midIndex2, remainingElementToSearch);


} else if (arr1[midIndex1] > arr2[midIndex2]) {
FindKLargestTwoSortedArray(arr2, start2, end2, arr1, start1,
end1, k);
}


} else {
return;
}


}
}

这个灵感来自于 优酷视频的 Algo

链接到代码 复杂度(log (n) + log (m))

链接到 Code (log (n) * log (m))

实现 (log (n) + log (m))解决方案

我想补充一下我对这个问题的解释。 这是一个经典的问题,我们必须使用这两个数组排序的事实。 我们已经给出了两个大小为 sz1的排序数组 arr1和大小为 sz2的 arr2

A)假设

检查 k 是否有效

K 为 > (sz1 + sz2)

那么我们就不能在两个排序数组的联合中找到 kth 最小元素,所以返回无效数据。 b)Now if above condition holds false and we have valid and feasible value of k,

Managing Edge Cases

我们将在数组的前面附加-无穷大值,在末尾附加 + 无穷大值,以覆盖 k = 1,2和 k = (sz1 + sz2-1) ,(sz1 + sz2)等边缘情况。

现在这两个数组的大小分别是 (sz1 + 2)(sz2 + 2)

主要算法

现在,我们将对 arr1进行二进制搜索。我们将在 arr1上进行二进制搜索,查找索引 i,StartIndex < = i < = endIndex

这样,如果我们使用约束{(i + j) = k }在 arr2中找到对应的索引 j,那么如果

如果 (arr2[ j-1] < arr1[ i ] < arr2[ j ]),那么 arr1[ i ]是第 kth 最小的(情况1)

Else 如果 (arr1[ i-1] < arr2[ j ] < arr1[ i ]),那么 arr2[ i ]是第 kth 最小的(Case 2)

else signifies either Arr1[ i ] < arr2[ j-1] < arr2[ j ] (Case3)

Arr2[ j-1] < arr2[ j ] < arr1[ i ](个案4)

既然我们知道 Kth 最小元素具有比它小的(k-1)元素是两个数组 ryt 的联合,那么,

个案1中,我们所做的是确保总共有(k-1)个小于 arr1[ i ]的元素,因为 arr1数组中小于 arr1[ i ]的元素的数目比我们知道的要小(arr2[ j-1] < arr1[ i ] < arr2[ j ]) ,而 arr2中小于 arr1[ i ]的元素的数目是 j-1,因为 j 是使用(i-1) + (j-1) = (k-1)求得的。所以 kth 最小的元素是 arr1[ i ]

但是答案可能不总是来自第一个数组 ie arr1,所以我们检查了 个案2,它也满足类似情况1的要求,因为(i-1) + (j-1) = (k-1)。现在,如果我们有(arr1[ i-1] < arr2[ j ] < arr1[ i ]) ,我们有一个小于 arr2[ j ]的 k-1元素的总和,它是两个数组中最小的元素。

个案3中,为了将它形成任意一种情况1或情况2,我们需要增量 i,j 将根据约束{(i + j) = k } ie 在二进制搜索中移动到右侧即 make startIndex = midIndex

个案4中,为了将其形成任意一种情况1或情况2,我们需要递减 i,j 将根据约束{(i + j) = k } ie 在二进制搜索中移动到左侧即 make endIndex = midIndex。

现在,如何在 arr1上的二进制搜索开始时确定 startIndex 和 endIndex with startindex = 1 and endIndex = ??.We need to decide.

如果 k > sz1,endIndex = (sz1 + 1) ,else endIndex = k;

因为如果 k 大于第一个数组的大小,我们可能需要对整个数组进行二进制搜索,否则我们只需要取第一个 k 元素,因为 sz1-k 元素永远不能计算 kth 最小值。

下面显示的代码

// Complexity    O(log(n)+log(m))


#include<bits/stdc++.h>
using namespace std;
#define f(i,x,y) for(int i = (x);i < (y);++i)
#define F(i,x,y) for(int i = (x);i > (y);--i)
int max(int a,int b){return (a > b?a:b);}
int min(int a,int b){return (a < b?a:b);}
int mod(int a){return (a > 0?a:((-1)*(a)));}
#define INF 1000000








int func(int *arr1,int *arr2,int sz1,int sz2,int k)


{


if((k <= (sz1+sz2))&&(k > 0))


{
int s = 1,e,i,j;
if(k > sz1)e = sz1+1;
else e = k;
while((e-s)>1)
{
i = (e+s)/2;
j = ((k-1)-(i-1));
j++;
if(j > (sz2+1)){s = i;}
else if((arr1[i] >= arr2[j-1])&&(arr1[i] <= arr2[j]))return arr1[i];
else if((arr2[j] >= arr1[i-1])&&(arr2[j] <= arr1[i]))return arr2[j];
else if(arr1[i] < arr2[j-1]){s = i;}
else if(arr1[i] > arr2[j]){e = i;}
else {;}
}
i = e,j = ((k-1)-(i-1));j++;
if((arr1[i] >= arr2[j-1])&&(arr1[i] <= arr2[j]))return arr1[i];
else if((arr2[j] >= arr1[i-1])&&(arr2[j] <= arr1[i]))return arr2[j];
else
{
i = s,j = ((k-1)-(i-1));j++;
if((arr1[i] >= arr2[j-1])&&(arr1[i] <= arr2[j]))return arr1[i];
else return arr2[j];
}


}


else


{
cout << "Data Invalid" << endl;
return -INF;


}


}










int main()


{
int n,m,k;
cin >> n >> m >> k;
int arr1[n+2];
int arr2[m+2];
f(i,1,n+1)
cin >> arr1[i];
f(i,1,m+1)
cin >> arr2[i];
arr1[0] = -INF;
arr2[0] = -INF;
arr1[n+1] = +INF;
arr2[m+1] = +INF;
int val = func(arr1,arr2,n,m,k);
if(val != -INF)cout << val << endl;
return 0;


}

用于解决复杂性(log (n) * log (m))

我没有利用约束{(i-1) + (j-1) = (k-1)}可以找到每个 i 的 j 的优势,所以对于每个 i,我进一步在第二个数组上应用二进制搜索来找到 j,使得 arr2[ j ] < = arr1[ i ]。因此,该解决方案可以进一步优化

基本上,通过这种方法,您可以在每个步骤中丢弃 k/2元素。 K 将从 k = > k/2 = > k/4 = > ... 递归变化,直到达到1。 所以,时间复杂度是 O (logk)

在 k = 1时,我们得到两个数组中最小的一个。

下面的代码是用 JAVA 编写的。请注意,我们正在从索引中减去代码中的1(- 1) ,因为 Java 数组的索引从0开始,而不是从1,K = 3由数组第二个索引中的元素表示。开始

private int kthElement(int[] arr1, int[] arr2, int k) {
if (k < 1 || k > (arr1.length + arr2.length))
return -1;
return helper(arr1, 0, arr1.length - 1, arr2, 0, arr2.length - 1, k);
}




private int helper(int[] arr1, int low1, int high1, int[] arr2, int low2, int high2, int k) {
if (low1 > high1) {
return arr2[low2 + k - 1];
} else if (low2 > high2) {
return arr1[low1 + k - 1];
}
if (k == 1) {
return Math.min(arr1[low1], arr2[low2]);
}
int i = Math.min(low1 + k / 2, high1 + 1);
int j = Math.min(low2 + k / 2, high2 + 1);
if (arr1[i - 1] > arr2[j - 1]) {
return helper(arr1, low1, high1, arr2, j, high2, k - (j - low2));
} else {
return helper(arr1, i, high1, arr2, low2, high2, k - (i - low1));
}
}
#include <bits/stdc++.h>
using namespace std;


int findKthElement(int a[],int start1,int end1,int b[],int start2,int end2,int k){


if(start1 >= end1)return b[start2+k-1];
if(start2 >= end2)return a[start1+k-1];
if(k==1)return min(a[start1],b[start2]);
int aMax = INT_MAX;
int bMax = INT_MAX;
if(start1+k/2-1 < end1) aMax = a[start1 + k/2 - 1];
if(start2+k/2-1 < end2) bMax = b[start2 + k/2 - 1];


if(aMax > bMax){
return findKthElement(a,start1,end1,b,start2+k/2,end2,k-k/2);
}
else{
return findKthElement(a,start1 + k/2,end1,b,start2,end2,k-k/2);
}
}


int main(void){
int t;
scanf("%d",&t);
while(t--){
int n,m,k;
cout<<"Enter the size of 1st Array"<<endl;
cin>>n;
int arr[n];
cout<<"Enter the Element of 1st Array"<<endl;
for(int i = 0;i<n;i++){
cin>>arr[i];
}
cout<<"Enter the size of 2nd Array"<<endl;
cin>>m;
int arr1[m];
cout<<"Enter the Element of 2nd Array"<<endl;
for(int i = 0;i<m;i++){
cin>>arr1[i];
}
cout<<"Enter The Value of K";
cin>>k;
sort(arr,arr+n);
sort(arr1,arr1+m);
cout<<findKthElement(arr,0,n,arr1,0,m,k)<<endl;
}


return 0;
}

时间复杂度为 O (log (min (n,m)))

Most of the answers I found here focus on both arrays. while it's good but it's harder to implement as there are a lot of edge cases that we need to take care of. Besides that most of the implementations are recursive which adds the space complexity of recursion stack. So instead of focusing on both arrays I decided to just focus on the smaller array and do the binary search on just the smaller array and adjust the pointer for the second array based on the value of the pointer in the first Array. By the following implementation, we have the complexity of O(log(min(n,m)) with O(1) space complexity.

    public static int kth_two_sorted(int []a, int b[],int k){
if(a.length > b.length){
return kth_two_sorted(b,a,k);
}
if(a.length + a.length < k){
throw new RuntimeException("wrong argument");
}
int low = 0;
int high = k;
if(a.length <= k){
high = a.length-1;
}
while(low <= high){
int sizeA = low+(high - low)/2;
int sizeB = k - sizeA;
boolean shrinkLeft = false;
boolean extendRight = false;
if(sizeA != 0){
if(sizeB !=b.length){
if(a[sizeA-1] > b[sizeB]){
shrinkLeft = true;
high = sizeA-1;
}
}
}
if(sizeA!=a.length){
if(sizeB!=0){
if(a[sizeA] < b[sizeB-1]){
extendRight = true;
low = sizeA;
}
}
}
if(!shrinkLeft && !extendRight){
return Math.max(a[sizeA-1],b[sizeB-1]) ;
}
}
throw  new IllegalArgumentException("we can't be here");
}

我们有一个数组 a[low, high]范围,我们进一步通过算法来缩小这个范围。sizeA显示了来自 k的项目中有多少来自数组 a,并且它是从 lowhigh的值派生出来的。sizeB是相同的定义,除了我们计算的价值,这样的方式 sizeA+sizeB=k。基于这两个边界上的值得出的结论是,我们必须在数组 a中向右扩展或向左收缩。如果我们停留在相同的位置,这意味着我们找到了解决方案,我们将返回 aa0位置和 a3的 a2位置的最大值。