Threads

A Java Thread lets a task be split into segments that run in parallel. On a single machine, threading is a natural way to boost performance. For instance, a quad-core, hyper-threaded computer can run up to 8 threads at once in hardware (though I’d recommend leaving 1 for the system, so 7 max). Each cluster node supports 32 threads, so if you are granted access to a single node, threading can provide significant performance gains over serial implementations (using multiple nodes simultaneously requires MPI and will be the topic of a future post). In this post, we provide a basic introduction to creating threads to parallelize a simple task.
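As a rough illustration of the "leave one for the system" rule of thumb, the sketch below asks the JVM how many logical processors it sees and subtracts one. The class name WorkerCount and the method recommended() are hypothetical names chosen for this example; only Runtime.availableProcessors() is part of the standard library.

```java
// Sketch: derive a worker-thread count from the hardware, leaving one for the system.
// WorkerCount and recommended() are hypothetical names used only for illustration.
class WorkerCount {
    /** Returns the number of logical processors minus one, but never less than 1. */
    static int recommended() {
        int logical = Runtime.getRuntime().availableProcessors();
        return Math.max(1, logical - 1);
    }

    public static void main(String[] args) {
        System.out.println("Worker threads: " + recommended());
    }
}
```

On a quad-core, hyper-threaded machine the JVM typically reports 8 logical processors, so this would suggest 7 workers, though the reported value can differ in containers or virtual machines.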

Threading example

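The class below computes a simple distributed average, using threads to stand in for distributed processes. The \(n\) values are split into \(m\) partitions of sizes \(n_{1},\ldots,n_{m}\); the equations used to compute the distributed average (which also show why it is correct) are as follows: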

\begin{equation}
n=\sum_{i=1}^{n_{1}}1+\sum_{i=1}^{n_{2}}1+\cdots+\sum_{i=1}^{n_{m}}1=n_{1}+n_{2}+\cdots+n_{m}\label{eq:1avg}\tag{1}
\end{equation}
\begin{equation}
\sum_{i=1}^{n}x_{i}=\sum_{i=1}^{n_{1}}x_{i}+\sum_{i=1}^{n_{2}}x_{i}+\cdots+\sum_{i=1}^{n_{m}}x_{i}\label{eq:2avg}\tag{2}
\end{equation}
\begin{equation}
\frac{\sum_{i=1}^{n}x_{i}}{n}=\frac{(\ref{eq:2avg})}{(\ref{eq:1avg})}\tag{3}
\end{equation}

In this example, the integers from 1 to \(N\) = 1 billion are summed and averaged across three threads. The result is then validated against the closed-form average of \(1 \ldots N\):

\begin{equation}
\frac{\left(\frac{N\left(N+1\right)}{2}\right)}{N}=\frac{N\left(N+1\right)}{2N}=\frac{N+1}{2}\tag{4}
\end{equation}
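Before adding threads, the equations above can be checked serially on a small input. The sketch below is not part of the example class: PartialAverage, average(), and the way the blocks are laid out are assumptions made for illustration. It accumulates a sum and a count per block, combines them, and compares the result with the closed form.

```java
// Sketch of equations (1)-(4) on a small input, without threads.
// PartialAverage, average(), and the block layout are illustrative assumptions.
class PartialAverage {
    /** Averages 1..n by combining per-block sums and counts over m blocks. */
    static double average(long n, int m) {
        long totalSum = 0;    // right-hand side of equation (2)
        long totalCount = 0;  // right-hand side of equation (1)
        long size = n / m;    // base block size; the last block absorbs the remainder
        for (int j = 0; j < m; j++) {
            long first = j * size + 1;
            long last = (j == m - 1) ? n : (j + 1) * size;
            long blockSum = 0;
            for (long i = first; i <= last; i++) blockSum += i;
            totalSum += blockSum;
            totalCount += last - first + 1;
        }
        return (double) totalSum / (double) totalCount;  // equation (3)
    }

    public static void main(String[] args) {
        long n = 10;
        System.out.println(average(n, 3));   // combined from three blocks: 5.5
        System.out.println((n + 1) / 2.0);   // closed form from equation (4): 5.5
    }
}
```

With n = 10 and three blocks (1-3, 4-6, 7-10), the combined sum is 55 and the combined count is 10, so both lines print 5.5.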

The method execute(threads, N) creates the threads (the loop that fills tList), starts them (t.start()), and waits for them to finish (t.join()). Each thread wraps a new instance of, in this example, CompThread. There are several elements of note. First, CompThread implements Runnable, which lets a Thread execute its code on a separate thread of execution (not a separate process; all threads share the same JVM and memory). Second, it overrides the run() method, which holds the work to be done. This method is invoked on the new thread by start(). Note: do not confuse t.start() with t.run(), as the latter will not create a new thread but will execute run() in the invoking one, i.e., non-parallel.
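The difference between start() and run() is easy to observe by recording which thread actually executes the task. The following is a minimal sketch, assuming Java 8 or later for the lambda syntax; StartVsRun, observe(), and the thread name "worker" are hypothetical and not part of the example class.

```java
// Sketch: run() executes in the calling thread, start() in a new one.
// StartVsRun, observe(), and the thread name "worker" are illustrative assumptions.
class StartVsRun {
    /** Returns the names of the threads that executed the task via run() and via start(). */
    static String[] observe() throws InterruptedException {
        final String[] names = new String[2];

        Thread direct = new Thread(() -> names[0] = Thread.currentThread().getName(), "worker");
        direct.run();               // no new thread: the lambda runs in the caller

        Thread started = new Thread(() -> names[1] = Thread.currentThread().getName(), "worker");
        started.start();            // new thread: the lambda runs in "worker"
        started.join();             // join() also makes the write to names[1] visible here
        return names;
    }

    public static void main(String[] args) throws InterruptedException {
        String[] names = observe();
        System.out.println("run():   " + names[0]);
        System.out.println("start(): " + names[1]);
    }
}
```

When launched from main, the first line reports the main thread and the second reports "worker", even though both Thread objects were given the same name.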

The methods update(sum, count) and getAverage() use the synchronized modifier. This ensures that only one thread at a time can execute a synchronized method on the same Threads object, which avoids collisions. For example, if two threads attempt to alter sum simultaneously, one might overwrite the other's update, leading to inconsistent results.
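To see synchronized in isolation, the sketch below has several threads add to one shared total through a synchronized method. SafeCounter, add(), get(), and runDemo() are hypothetical names for illustration; the pattern mirrors update(sum, count) but is not taken from the example class.

```java
// Sketch: several threads add to a shared total through a synchronized method.
// SafeCounter, add(), get(), and runDemo() are hypothetical names for illustration.
class SafeCounter {
    private long total;

    /** Only one thread at a time may execute this on a given SafeCounter. */
    synchronized void add(long value) { total += value; }

    synchronized long get() { return total; }

    /** Starts nThreads threads that each call add(1) perThread times. */
    static long runDemo(int nThreads, int perThread) throws InterruptedException {
        SafeCounter counter = new SafeCounter();
        Thread[] workers = new Thread[nThreads];
        for (int t = 0; t < nThreads; t++) {
            workers[t] = new Thread(() -> {
                for (int i = 0; i < perThread; i++) counter.add(1);
            });
            workers[t].start();
        }
        for (Thread w : workers) w.join();
        return counter.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runDemo(4, 100_000));   // 400000
    }
}
```

Removing synchronized from add() would make the result unpredictable, because total += value is a read-modify-write sequence that two threads can interleave.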

The method interval(N, thread, nthreads) is quite important, as it creates non-intersecting ranges. In our example, it produces the following 0-based ranges for the three threads: 0-333,333,332, 333,333,333-666,666,665, and 666,666,666-999,999,999 (notice the last thread is given 333,333,334 values, 1 more than the others, to ensure all values are included). CompThread then shifts each range up by 1 so that the values summed are 1 through N.
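The range arithmetic is easier to follow on a small input. The sketch below reproduces the same calculation as a static helper; the class name Ranges is a hypothetical wrapper for this example, and the values N = 10 with three threads are chosen only to keep the output short.

```java
// Sketch: the same range arithmetic as interval(N, thread, nthreads), as a static helper.
// Ranges is a hypothetical wrapper class used only for this example.
class Ranges {
    /** Returns {start, end} (0-based, inclusive) for the 1-based thread number `thread`. */
    static long[] interval(long n, int thread, int nThreads) {
        long size = n / nThreads;                        // base range size
        long start = size * (thread - 1);
        long end = (thread == nThreads) ? n - 1          // last thread takes the remainder
                                        : size * thread - 1;
        return new long[] { start, end };
    }

    public static void main(String[] args) {
        for (int t = 1; t <= 3; t++) {
            long[] r = interval(10, t, 3);
            System.out.println("thread " + t + ": " + r[0] + "-" + r[1]);
        }
    }
}
```

For N = 10 this prints the ranges 0-2, 3-5, and 6-9: the first two threads get 3 values each and the last gets 4, just as the last thread above gets one extra value.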

import static java.lang.System.out;
import java.util.ArrayList;
import java.util.List;
 
/**
 * This example computes a distributed average, using threads as the 
 * mechanism of distribution.
 * @author Ray Hylock
 */
public class Threads {
    // shared state, updated by every thread through update()
    private long sum;
    private long count;
     
    /**
     * Creates, executes, and waits for threads to finish.
     * @param threads   the number of threads
     * @param N         the number to inclusively average (i.e., SUM(1...N)/N)
     * @throws InterruptedException 
     */
    private void execute(final int threads, final long N) 
            throws InterruptedException{
        // create the threads
        List<Thread> tList = new ArrayList<Thread>();
        for(int i=0; i<threads; i++)
            tList.add(new Thread(new CompThread(interval(N, i+1, threads))));
         
        // start threads - done separately so instantiation can be excluded when timing
        for(Thread t : tList) t.start();    // calls run() in CompThread
         
        // wait for them to finish
        for(Thread t : tList) t.join();
    }
     
    /**
     * Update sum and count in a thread-safe fashion.
     * @param sum   the sum to add
     * @param count the count to add
     */
    private synchronized void update(final long sum, final long count){
        this.sum += sum;
        this.count += count;
    }
     
    /**
     * Get the average in a thread-safe fashion.
     * @return the average
     */
    private synchronized double getAverage(){
        return (double)sum/(double)count;
    }
     
    /**
     * Compute {@code thread}'s range within {@code N}. This is useful when
     * segmenting a task into non-intersecting portions.
     * @param N         the value to segment
     * @param thread    the current thread number
     * @param nthreads  the number of threads
     * @return          a two-element array holding the start (index 0) and 
     *                  end (index 1) values
     */
    private long[] interval(final long N, final int thread, final int nthreads){
        long interval[] = new long[2];
        interval[0] = (thread == 1) ? 0 : (N/nthreads)*(thread-1);
        interval[1] = (thread == nthreads) ? N-1 : (N/nthreads)*thread - 1;
        return interval;
    }
     
    /**
     * The computational thread. Implements {@link Runnable}, so {@code run()} 
     * is executed when {@code start()} is invoked on an instantiated
     * {@link CompThread} object.
     */
    private class CompThread implements Runnable {
        private long start, end;    // range to compute
        private long sum;
         
        /**
         * Create a new computational thread.
         * @param interval the start-to-end range
         */
        public CompThread (final long interval[]){
            // convert 0-index to 1-index (e.g., sum 1...100 instead of 0...99)
            start = interval[0] + 1;
            end = interval[1] + 1;
        }
         
        @Override
        public void run(){
            for(long i=start; i<=end; i++) sum += i;
            update(sum, end-start+1);
        }
    }
     
    /**
     * Main method.
     * @param args  command line arguments
     * @throws InterruptedException 
     */
    public static void main(String args[]) throws InterruptedException{
        int threads = 3;
        long N = 1_000_000_000;
        Threads t = new Threads();
        t.execute(threads, N);
        double dist = t.getAverage();
        double check = (((double)N+1d)/2d);     // ((N(N+1)/2)/N) -> (N+1)/2
        out.println("Distributed: " + dist);
        out.println("Check: " + check);
        out.println("Concur? " + (dist == check));
    }
}

Output from the example class:

Distributed: 5.000000005E8
Check: 5.000000005E8
Concur? true