http://www.topcoder.com/tc?module=Static&d1=match_editorials&d2=srm165

We are trying to find which number of processors requires the least amount of time to process K tasks. We are told it takes 1 ms for 1 processor to process 1 task, so N processors need ceil(K / N) ms to process K tasks. The division rounds up because the processor that gets the most tasks decides when the work is done. But parallelization requires time to sync the processors: we are told it takes OVERHEAD ms per pair of processors to sync up initially. How many pairs are there? There are N choose 2 pairs, which is N!/(2!(N-2)!). So the total time for N processors is (N choose 2) * OVERHEAD + ceil(K / N).

OK, so this is easy: we just loop through, plug in a different N every iteration, and return the best one. But don't forget that N! becomes a very big number very quickly, so you have to protect against overflow. There are two ways:

1) Use a bigger type such as BigInteger. A long only holds factorials up to 20!, so a long is not enough on its own.
2) Notice that in N!/(2!(N-2)!) most of the factors on top cancel against (N-2)! on the bottom, leaving N(N-1)/2, which fits easily in a long.

I did 1.

Another thing to worry about is that computing N! is a recursion that computes N*(N-1)*...*1. The thing to notice is that N! is just N * (N-1)!. Since the main loop counts up from 2, you have already computed (N-1)! in the previous iteration. All you need to do is store each result in a hash table and look it up instead of recomputing it every time. Instead of an N-deep recursion on every iteration, each step is reduced to an O(1) hash-table lookup plus a single multiplication.

Another optimization comes from the fact that the time function has a minimum point (the shortest time). The sync overhead grows with every processor you add, while the time saved on the tasks keeps shrinking. So once the time starts going up again, you can return. In fact, adding that last optimization cut my execution time from ~50 ms to ~5 ms.
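Option 2 is not what the code below does, but here is a minimal sketch of it. It assumes the same search as that code: try N from 2 to 999 and stop once the time goes up. Because N!/(2!(N-2)!) = N(N-1)/2, plain long arithmetic is enough and no factorials or hash table are needed. The class name `ParallelSpeedupLong` and the integer ceiling `(k + n - 1) / n` are illustrative choices, not part of the original solution.

```java
// Sketch of option 2 (hypothetical class name): the factorials cancel,
// so N choose 2 = N*(N-1)/2 and everything fits in a long.
class ParallelSpeedupLong {

    // N choose 2 without factorials: N!/(2!(N-2)!) = N*(N-1)/2
    static long pairs(long n) {
        return n * (n - 1) / 2;
    }

    // Total time for n processors: pair sync overhead plus the busiest processor's tasks.
    static long time(int n, int k, int overhead) {
        long ceilDiv = ((long) k + n - 1) / n; // integer ceil(k / n)
        return pairs(n) * overhead + ceilDiv;
    }

    // Same search as the BigInteger version: scan n upward, stop once the time increases.
    public int numProcessors(int k, int overhead) {
        long prevTime = time(1, k, overhead); // n = 1: no pairs, k ms of work
        long bestTime = prevTime;
        int bestN = 1;
        for (int n = 2; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t > prevTime) {
                break; // past the minimum
            }
            if (t < bestTime) { // strictly better only, so ties keep the smaller n
                bestTime = t;
                bestN = n;
            }
            prevTime = t;
        }
        return bestN;
    }

    public static void main(String[] args) {
        ParallelSpeedupLong solver = new ParallelSpeedupLong();
        System.out.println(solver.numProcessors(100, 1)); // 5
        System.out.println(solver.numProcessors(10, 1));  // 2
    }
}
```

Take k = 100 and overhead = 1. The times for N = 1..6 are 100, 51, 37, 31, 30, 32, so the scan stops at N = 6 and returns 5.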

```java
import java.math.BigInteger;
import java.util.HashMap;

public class ParallelSpeedup
{
    // memoized factorials: n -> n!
    HashMap<Integer, BigInteger> facts = new HashMap<Integer, BigInteger>();

    public int numProcessors(int k, int overhead)
    {
        // total time for n processors: n!/(2!(n-2)!) * overhead + ceil(k/n)
        BigInteger bestTime = BigInteger.valueOf(k); // n = 1: no sync overhead, k ms of work
        BigInteger prevTime = BigInteger.valueOf(k);
        int bestN = 1;
        int diff = 0;
        for (int n = 2; n < 1000; n++)
        {
            BigInteger newTime = getTime(n, k, overhead);
            // check if we passed the minimum
            if (newTime.compareTo(prevTime) > 0)
                return bestN;
            diff = newTime.compareTo(bestTime);
            if (diff < 0) {
                bestTime = newTime;
                bestN = n;
            }
            prevTime = newTime;
        }
        return bestN;
    }

    private BigInteger getTime(int n, int k, int overhead)
    {
        // n choose 2 = number of processor pairs that must sync
        BigInteger nchoose2 = fact(n).divide(fact(2).multiply(fact(n - 2)));
        // the busiest processor handles ceil(k/n) tasks
        BigInteger perPro = BigInteger.valueOf((int) Math.ceil((double) k / (double) n));
        return nchoose2.multiply(BigInteger.valueOf(overhead)).add(perPro);
    }

    private BigInteger fact(int n)
    {
        if (n <= 1)
            return BigInteger.valueOf(1);
        else if (facts.containsKey(n))
        {
            return facts.get(n); // reuse a factorial computed in an earlier iteration
        }
        else
        {
            BigInteger n1 = BigInteger.valueOf(n).multiply(fact(n - 1));
            facts.put(n, n1);
            return n1;
        }
    }
}
```
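The early return relies on the time going down and then back up, with no better N further along. One way to gain some confidence in it is to compare the early-exit scan against a brute-force scan over every N. The sketch below does that with long arithmetic and the closed form N(N-1)/2 instead of factorials. The class and method names are hypothetical. The tested ranges (k up to 2000, overhead 1 to 10) are assumptions for illustration, not limits taken from the problem statement.

```java
// Hypothetical sanity check: compares the early-exit scan with a full scan over every n.
class EarlyExitCheck {

    // Total time for n processors, using n*(n-1)/2 for the number of pairs.
    static long time(int n, int k, int overhead) {
        long pairs = (long) n * (n - 1) / 2;
        long ceilDiv = ((long) k + n - 1) / n; // integer ceil(k / n)
        return pairs * overhead + ceilDiv;
    }

    // Stops at the first n whose time is worse than the previous n's.
    static int earlyExit(int k, int overhead) {
        long prev = time(1, k, overhead);
        long best = prev;
        int bestN = 1;
        for (int n = 2; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t > prev) {
                return bestN;
            }
            if (t < best) {
                best = t;
                bestN = n;
            }
            prev = t;
        }
        return bestN;
    }

    // Tries every n from 1 to 999 and keeps the smallest n with the lowest time.
    static int fullScan(int k, int overhead) {
        long best = Long.MAX_VALUE;
        int bestN = 1;
        for (int n = 1; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t < best) {
                best = t;
                bestN = n;
            }
        }
        return bestN;
    }

    // Returns the first (k, overhead) pair where the two scans disagree, or null if none does.
    static int[] findMismatch(int maxK, int maxOverhead) {
        for (int k = 1; k <= maxK; k++) {
            for (int o = 1; o <= maxOverhead; o++) {
                if (earlyExit(k, o) != fullScan(k, o)) {
                    return new int[] {k, o};
                }
            }
        }
        return null;
    }

    public static void main(String[] args) {
        int[] bad = findMismatch(2000, 10);
        System.out.println(bad == null
                ? "early exit matches full scan"
                : "mismatch at k=" + bad[0] + ", overhead=" + bad[1]);
    }
}
```

A check like this only covers the inputs it tries, so it supports the shortcut rather than proving it.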

