Thursday, June 30, 2011

TopCoder SRM 165 Div 2 500-pt

http://www.topcoder.com/tc?module=Static&d1=match_editorials&d2=srm165

We are trying to find the number of processors that processes K tasks in the least amount of time. We are told it takes 1 processor 1 ms to process 1 task, so N processors need K / N ms to process K tasks. That gets rounded up to ceil(K / N) when K doesn't divide evenly, because some processor has to take the extra task. But parallelization requires time to sync the processors: we are told it takes OVERHEAD ms per pair of processors. How many pairs are there? There are N choose 2 pairs, which is N!/(2!(N-2)!). So the total time for N processors is (N choose 2) * OVERHEAD + ceil(K / N).

Ok, so this is easy: we just loop through, plug in a different N every iteration, and return the best one. But don't forget, N! can become a very big number, so you have to protect against overflow. There are two ways: 1) use a bigger type, like long or BigInteger (a long only holds factorials up to 20!, so with factorials it really has to be BigInteger); 2) notice that in N!/(2!(N-2)!) you can cancel out a lot of the numbers on top because they also appear on the bottom, leaving N(N-1)/2. I did 1.

Another thing to worry about is that computing N! is a recursion that computes N*(N-1)*...*1. The thing to notice is that N! is actually N * (N-1)!. Since the main loop goes from 2 upward, you've already computed (N-1)! in the previous iteration. All you need to do is store each result in a hash table and look it up instead of recomputing it every time. Instead of an N-deep recursion on every iteration, each new factorial costs one O(1) hash table lookup plus a single BigInteger multiplication (which does get slower as the numbers grow).

Another optimization is that the time function has a minimum point (the shortest time), and once the time starts going up it never comes back down: the overhead added by one more processor keeps growing, while the time it saves, about K/(N(N+1)), keeps shrinking. So as soon as the time increases, you can return the best N found so far. In fact, adding that last optimization cut my execution time from ~50 ms to ~5 ms.
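For comparison, here is a minimal sketch of option 2, which is not what I did. Cancelling the factorials leaves N choose 2 = N(N-1)/2, which fits easily in a long for every N the loop tries, so neither BigInteger nor a factorial cache is needed. The class name ParallelSpeedupClosedForm is made up for this sketch; it assumes the same search bound (N < 1000) and the same early exit as my solution.

```java
// Sketch of option 2 (hypothetical class name): cancel the factorials so
// N choose 2 = N! / (2! (N-2)!) = N * (N - 1) / 2, computed with long arithmetic.
public class ParallelSpeedupClosedForm {

    // Total time for n processors: (n choose 2) * overhead + ceil(k / n).
    static long time(int n, int k, int overhead) {
        long pairs = (long) n * (n - 1) / 2;        // n choose 2, no factorials
        long perProcessor = ((long) k + n - 1) / n; // integer ceiling of k / n
        return pairs * overhead + perProcessor;
    }

    // Smallest n with the minimum total time; stops once the time starts rising.
    static int numProcessors(int k, int overhead) {
        long bestTime = time(1, k, overhead);
        long prevTime = bestTime;
        int bestN = 1;
        for (int n = 2; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t > prevTime) {
                break;              // past the minimum
            }
            if (t < bestTime) {     // strict '<' keeps the smallest n on ties
                bestTime = t;
                bestN = n;
            }
            prevTime = t;
        }
        return bestN;
    }

    public static void main(String[] args) {
        System.out.println(numProcessors(12, 1)); // prints 2
    }
}
```

With K = 12 and OVERHEAD = 1, 2 and 3 processors both take 7 ms, and the strict comparison keeps the smaller count, so the answer is 2.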


import java.util.HashMap;
import java.math.BigInteger;
public class ParallelSpeedup
{
    // cache of factorials that have already been computed
    HashMap<Integer, BigInteger> facts = new HashMap<Integer, BigInteger>();
    public int numProcessors(int k, int overhead)
    {
   
        // time(n) = (n! / (2! (n-2)!)) * overhead + ceil(k / n)
       
        BigInteger bestTime = BigInteger.valueOf(k);
        BigInteger prevTime = BigInteger.valueOf(k);
        int bestN = 1;
        for(int n = 2; n < 1000; n++)
        {
            BigInteger newTime = getTime(n, k, overhead);
            // once the time starts rising it never comes back down,
            // so the first increase means we've passed the minimum
            if(newTime.compareTo(prevTime) > 0)
                return bestN;
            // strict '<' keeps the smallest n when two counts tie
            if(newTime.compareTo(bestTime) < 0)
            {
                bestTime = newTime;
                bestN = n;
            }
            prevTime = newTime;
        }
        return bestN;
    }
    private BigInteger getTime(int n, int k, int overhead)
    {
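        // n choose 2 = n! / (2! (n-2)!), using the cached factorials
        BigInteger nchoose2 = fact(n).divide(fact(2).multiply(fact(n-2)));
        // ceil(k / n): the busiest processor's share of the tasks
        BigInteger perPro = BigInteger.valueOf((int)Math.ceil((double)k / (double)n));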
        return nchoose2.multiply(BigInteger.valueOf(overhead)).add(perPro);
    }
    private BigInteger fact(int n)
    {
        if(n <= 1)
            return BigInteger.ONE; // 0! = 1! = 1
        else if (facts.containsKey(n))
        {
            return facts.get(n);
        }
        else
        {
            BigInteger n1 = BigInteger.valueOf(n).multiply(fact(n-1));
            facts.put(n, n1);
            return n1;
        }
    }
}
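The early return in the loop above relies on the time never coming back down once it starts rising. A quick way to convince yourself of that is to compare the early-exit scan against a full scan over a range of inputs. This is a hypothetical test harness (EarlyExitCheck and its method names are made up), and it uses the long-based formula n(n-1)/2 * overhead + ceil(k/n) rather than the BigInteger version; it assumes overhead is at least 1, as in the tested range.

```java
// Hypothetical cross-check, not part of the solution above: compare the
// early-exit scan with an exhaustive scan over many (k, overhead) pairs.
public class EarlyExitCheck {

    // (n choose 2) * overhead + ceil(k / n), using long arithmetic.
    static long time(int n, int k, int overhead) {
        return (long) n * (n - 1) / 2 * overhead + ((long) k + n - 1) / n;
    }

    // Scan every n in [1, 1000) and keep the smallest n with minimum time.
    static int exhaustive(int k, int overhead) {
        int bestN = 1;
        long bestTime = time(1, k, overhead);
        for (int n = 2; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t < bestTime) { bestTime = t; bestN = n; }
        }
        return bestN;
    }

    // Same scan, but stop as soon as the time rises above the previous step.
    static int earlyExit(int k, int overhead) {
        int bestN = 1;
        long bestTime = time(1, k, overhead);
        long prevTime = bestTime;
        for (int n = 2; n < 1000; n++) {
            long t = time(n, k, overhead);
            if (t > prevTime) break;
            if (t < bestTime) { bestTime = t; bestN = n; }
            prevTime = t;
        }
        return bestN;
    }

    // Number of (k, overhead) pairs on which the two scans disagree.
    static int countMismatches(int maxK, int maxOverhead) {
        int mismatches = 0;
        for (int k = 1; k <= maxK; k++) {
            for (int o = 1; o <= maxOverhead; o++) {
                if (exhaustive(k, o) != earlyExit(k, o)) mismatches++;
            }
        }
        return mismatches;
    }

    public static void main(String[] args) {
        System.out.println(countMismatches(2000, 10));
    }
}
```

It prints 0. The reason: adding the (n+1)th processor costs overhead*n more, which grows every step, while the time saved, ceil(k/n) - ceil(k/(n+1)), stays within 1 of k/(n(n+1)), which shrinks. So after the first increase, every later step increases too.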
