AES Encryption

This post discusses AES-256 in Java. To use a 256-bit key, the Java Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files must be installed (recent JDKs enable unlimited strength by default). If you’d rather not worry about adding the files for the JCE, simply change keyLength in AESTest.java from 256 to 128 (the maximum key length under the limited-strength policy).
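
Whether 256-bit keys are permitted can also be checked at run time instead of being discovered through an exception. The following is a minimal sketch, assuming a hypothetical helper class named KeyLengthCheck that is not part of AES.java. Cipher.getMaxAllowedKeyLength is the standard JCE call; it reports Integer.MAX_VALUE when the unlimited policy is active.

```java
import java.security.NoSuchAlgorithmException;
import javax.crypto.Cipher;

/** Hypothetical helper for illustration; not part of the original post. */
final class KeyLengthCheck {
    /** Returns the largest AES key length, in bits, that the installed policy permits. */
    static int maxAesKeyLength() {
        try {
            return Cipher.getMaxAllowedKeyLength("AES");
        } catch (NoSuchAlgorithmException ex) {
            throw new IllegalStateException("AES is not available", ex);
        }
    }

    /** Picks 256 bits when the policy allows it, otherwise falls back to 128. */
    static int chooseKeyLength() {
        return maxAesKeyLength() >= 256 ? 256 : 128;
    }

    public static void main(String[] args) {
        System.out.println("Max AES key length: " + maxAesKeyLength());
        System.out.println("Using key length:   " + chooseKeyLength());
    }
}
```

The value returned by chooseKeyLength() could be passed as keyLength in AESTest.java so the same code runs under either policy.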

The Advanced Encryption Standard

In January 1997, the National Institute of Standards and Technology (NIST) began its search to replace the aging Data Encryption Standard (DES). In September 1997, a call went out for proposals that, among other things, required the use of 128-bit blocks and 128-, 192-, and 256-bit key lengths. In 1998, 15 algorithms were accepted as candidates, 5 of which received further scrutiny, with Rijndael (by Vincent Rijmen and Joan Daemen) eventually selected to become the Advanced Encryption Standard (AES) in 2001 (as of this post, AES-256 is an NSA-approved TOP SECRET standard).

AES and HIPAA

On April 27, 2009, HHS issued guidelines for rendering electronic Protected Health Information (ePHI) “unusable, unreadable, or indecipherable to unauthorized individuals.” To meet this definition, the data must be encrypted (Security Rule (45 CFR §164.304)) per NIST Special Publication 800-111 (source). That document states, “whenever possible, AES should be used for the encryption algorithm because of its strength and speed.” Coupled with the NSA’s TOP SECRET recommendation, AES-256 is a strong standard for ePHI encryption and is therefore the focus of this post.

AES Class

AES.java is a simple implementation of AES. The class will create and store the initialization vector and random salt for each new key length, allowing for repeated use.

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.security.SecureRandom;
import java.security.spec.KeySpec;
import java.util.Base64;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
 
/**
 * A simple AES encryption/decryption class with variable key length. As it is
 * written, this class will create a static initialization vector and salt for 
 * each key length entered. These values are saved to files for repeated use -
 * i.e., long-term use.
 * @author Ray Hylock
 */
class AES {
    // static variables
    private static final String IV_FILE = "iv";
    private static final String SALT_FILE = "salt";
    private static final String EXT = ".k";
    private static final int SALT_BYTES = 32;
    private static final int ITERATION_COUNT = 65536;
    private static final String CHARSET = "UTF-16";
     
    // instantiation variables
    private final String filePath;
    private final int keyLength;
    private Cipher cipherEnc;
    private Cipher cipherDec;
    private byte[] iv = null;
    private byte[] salt = null;
 
    /**
     * Instantiates a new AES object with the specified pass phrase and key 
     * length.
     * @param passPhrase    the pass phrase
     * @param keyLength     the key length
     * @throws Exception 
     */
    public AES(char[] passPhrase, int keyLength) throws Exception{
        this("", passPhrase, keyLength);
    }
     
    /**
     * Instantiates a new AES object with the specified file path, pass 
     * phrase, and key length.
     * @param filePath      the file path
     * @param passPhrase    the pass phrase
     * @param keyLength     the key length
     * @throws Exception 
     */
    public AES(String filePath, char[] passPhrase, int keyLength) throws Exception{
        // create file path
        String fp = filePath.trim();
        if(fp.length() > 0){
            char c = fp.charAt(fp.length()-1);
            this.filePath = (c == '/' || c == '\\') ? fp : fp + "/"; 
        } else {
            this.filePath = fp;
        }
         
        // create directory structure
        File f = new File(this.filePath+IV_FILE);
        if(this.filePath.length() > 0) f.getParentFile().mkdirs();
         
        // prepare for encryption/decryption
        this.keyLength = keyLength;
        initEncryption(passPhrase);
        initDecrypter(passPhrase);
    }
     
    /**
     * Initializes the encryption portion.
     * @param passPhrase pass phrase
     * @throws Exception 
     */
    private void initEncryption(char[] passPhrase) throws Exception {
        // get/set salt
        if((salt = readSalt()) == null) {
            salt = randomSalt();
            writeSalt(salt);
        }
         
        // setup AES
        SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
        KeySpec spec = new PBEKeySpec(passPhrase, salt, ITERATION_COUNT, keyLength);
        SecretKey tmp = factory.generateSecret(spec);
        SecretKey secret = new SecretKeySpec(tmp.getEncoded(), "AES");
 
        // setup cipher
        cipherEnc = Cipher.getInstance("AES/CBC/PKCS5Padding");
        if((iv = readIV()) == null){
            cipherEnc.init(Cipher.ENCRYPT_MODE, secret);
            iv = cipherEnc.getParameters().getParameterSpec(IvParameterSpec.class).getIV();
            writeIV(iv);
        } else {
            cipherEnc.init(Cipher.ENCRYPT_MODE, secret, new IvParameterSpec(iv));
        }
    }
     
    /**
     * Initialize the decryption portion.
     * @param passPhrase pass phrase
     * @throws Exception 
     */
    private void initDecrypter(char[] passPhrase) throws Exception {
        // setup AES
        SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
        KeySpec spec = new PBEKeySpec(passPhrase, salt, ITERATION_COUNT, keyLength);
        SecretKey tmp = factory.generateSecret(spec);
        SecretKey secret = new SecretKeySpec(tmp.getEncoded(), "AES");
         
        // setup cipher
        cipherDec = Cipher.getInstance("AES/CBC/PKCS5Padding");
        cipherDec.init(Cipher.DECRYPT_MODE, secret, new IvParameterSpec(iv));
    }
     
    /**
     * Reads the IV from disk.
     * @return the IV
     */
    private byte[] readIV() {
        FileInputStream fis = null;
        try {
            fis = new FileInputStream(filePath+IV_FILE+"_"+keyLength+EXT);
            int l = fis.read();
            byte[] iv = new byte[l];
            fis.read(iv);
            fis.close();
            return iv;
        } catch (FileNotFoundException ex) {
            // file not found so we create the output
        } catch (IOException ex) {
            System.err.println("Error reading IV file: \n" + ex.getMessage());
        } finally {
            try {
                if(fis != null) fis.close();
            } catch (IOException ex) {
                System.err.println("Failed to close the IV input stream.\n"
                    + ex.getMessage());
            }
        }
        return null;
    }
     
    /**
     * Write the IV to disk.
     * @param iv IV to write
     */
    private void writeIV(byte[] iv) {
        FileOutputStream fos = null;
        try {
            fos = new FileOutputStream(filePath+IV_FILE+"_"+keyLength+EXT);
            fos.write((byte) iv.length);
            fos.write(iv);
            fos.flush();
            fos.close();
        } catch (FileNotFoundException ex) {
            System.err.println("Error writing IV file: \n" + ex.getMessage());
        } catch (IOException ex) {
            System.err.println("Error writing IV file: \n" + ex.getMessage());
        } finally {
            try {
                if(fos != null) fos.close();
            } catch (IOException ex) {
                System.err.println("Failed to close the IV output stream.\n"
                    + ex.getMessage());
            }
        }
    }
     
    /**
     * Read the salt from disk.
     * @return the salt
     */
    private byte[] readSalt() {
        FileInputStream fis = null;
        try {
            fis = new FileInputStream(filePath+SALT_FILE+"_"+keyLength+EXT);
            int l = fis.read();
            byte[] salt = new byte[l];
            fis.read(salt);
            fis.close();
            return salt;
        } catch (FileNotFoundException ex) {
            // file not found so we create the output
        } catch (IOException ex) {
            System.err.println("Error reading SALT file: \n" + ex.getMessage());
        } finally {
            try {
                if(fis != null) fis.close();
            } catch (IOException ex) {
                System.err.println("Failed to close the SALT input stream.\n"
                    + ex.getMessage());
            }
        }
        return null;
    }
     
    /**
     * Write the salt to disk.
     * @param salt salt to write
     */
    private void writeSalt(byte[] salt) {
        FileOutputStream fos = null;
        try {
            fos = new FileOutputStream(filePath+SALT_FILE+"_"+keyLength+EXT);
            fos.write((byte) salt.length);
            fos.write(salt);
            fos.flush();
            fos.close();
        } catch (FileNotFoundException ex) {
            System.err.println("Error writing salt file: \n" + ex.getMessage());
        } catch (IOException ex) {
            System.err.println("Error writing salt file: \n" + ex.getMessage());
        } finally {
            try {
                if(fos != null) fos.close();
            } catch (IOException ex) {
                System.err.println("Failed to close the SALT output stream.\n"
                    + ex.getMessage());
            }
        }
    } 
     
    /**
     * Creates a random salt.
     * @return random salt as a {@code byte[]}
     */
    private byte[] randomSalt(){
        final Random r = new SecureRandom();
        byte salt[] = new byte[SALT_BYTES];
        r.nextBytes(salt);
        return salt;
    }
 
    /**
     * Encrypt the message.
     * @param message   message to encrypt   
     * @return          the encrypted bytes as a Base64 encoded string
     * @throws Exception 
     */
    public String encrypt(String message) throws Exception {
        byte[] bytes = message.getBytes(CHARSET);
        byte[] encrypted = encrypt(bytes);
        return Base64.getEncoder().encodeToString(encrypted);
    }
 
    /**
     * Encrypt the message.
     * @param message   message to encrypt
     * @return          the encrypted bytes
     * @throws Exception 
     */
    public byte[] encrypt(byte[] message) throws Exception {
        return cipherEnc.doFinal(message);
    }
 
    /**
     * Decrypt the encrypted message.
     * @param message   encrypted message to decrypt
     * @return          the bytes as string
     * @throws Exception 
     */
    public String decrypt(String message) throws Exception {
        byte[] bytes = Base64.getDecoder().decode(message);
        byte[] decrypted = decrypt(bytes);
        return new String(decrypted, CHARSET);
    }
 
    /**
     * Decrypt the encrypted message.
     * @param message   encrypted message to decrypt
     * @return          the bytes
     * @throws Exception 
     */
    public byte[] decrypt(byte[] message) throws Exception {
        return cipherDec.doFinal(message);
    }
}
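
The four read/write helpers in the class above share one file layout: a single length byte followed by the raw bytes. As a possible simplification, they could be collapsed into one pair of methods built on java.nio.file.Files. This is only a sketch; the class name ParamFiles and its methods are hypothetical and are not part of the original AES class.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of the original AES class. */
final class ParamFiles {
    /** Writes one length byte followed by the value (same layout as writeIV/writeSalt). */
    static void write(Path file, byte[] value) {
        if (value.length > 255) {
            throw new IllegalArgumentException("Value too long for a one-byte length prefix");
        }
        byte[] out = new byte[value.length + 1];
        out[0] = (byte) value.length;
        System.arraycopy(value, 0, out, 1, value.length);
        try {
            Files.write(file, out);
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    /** Returns the stored value, or null when the file does not exist yet. */
    static byte[] read(Path file) {
        if (!Files.exists(file)) return null;
        try {
            byte[] in = Files.readAllBytes(file);
            // Reject empty or truncated files instead of returning partial data.
            if (in.length == 0 || in.length != (in[0] & 0xFF) + 1) {
                throw new IOException("Corrupt parameter file: " + file);
            }
            return Arrays.copyOfRange(in, 1, in.length);
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("salt_256", ".k");
        write(file, new byte[] {1, 2, 3});
        System.out.println(Arrays.toString(read(file))); // prints [1, 2, 3]
        Files.delete(file);
    }
}
```

Unlike the original helpers, this sketch throws on a corrupt file rather than printing a message and returning null, so a damaged IV or salt is not silently replaced by a new one.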

AES Test Class

The example test class instantiates an AES object using the given password. It then encrypts and decrypts a simple phrase, outputting the results. It is important to note the use of a character array instead of a String for the password. Strings in Java are immutable, so a password held in a String cannot be overwritten. In addition, string literals and interned Strings are kept in what’s called the string pool, where variables with the same value reference the same String object. This minimizes memory consumption, but the value remains in memory until the garbage collector removes it. Thus, it is possible for someone to retrieve your password, in this example, by examining memory even after you have set the variable to null in your code. Hence, we use a primitive array of characters (line 4) that we then explicitly overwrite and set to null following its use, allowing garbage collection (lines 9-10).

public class AESTest {
    public static void main(String[] args) throws Exception {
        // initialize
        char[] passPhrase = {'s','i','m','p','l','e','p','a','s','s'};
        int keyLength = 256;    // bits of encryption
        AES aes = new AES(passPhrase, keyLength);
         
        // clear passPhrase
        java.util.Arrays.fill(passPhrase, '\u0000');
        passPhrase = null;
         
        // encrypt
        String message = "test message";
        String ciphertext = aes.encrypt(message);
        System.out.println(String.format("encrypt(%s) = %s", 
                message, ciphertext));
         
        // decrypt
        String plaintext = aes.decrypt(ciphertext);
        System.out.println(String.format("decrypt(%s) = %s", 
                ciphertext, plaintext));
    }
}

The above test outputs the following (note: your results will differ depending upon the pass phrase and the initialized IV and salt values):

encrypt(test message) = 3ZB9/Nx/TeJSgOjnObq+HyAAU/09d5muA1/YFVAX3oM=
decrypt(3ZB9/Nx/TeJSgOjnObq+HyAAU/09d5muA1/YFVAX3oM=) = test message
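
The note above says the output depends on the pass phrase, IV, and salt. The flip side is that once those values are fixed, as they are after AES.java stores them on disk, encrypting the same message always yields the same ciphertext. The sketch below illustrates this with a hypothetical IvDemo class that is not part of the original code. To stay short, it assumes a random 128-bit key used directly rather than one derived from a pass phrase.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

/** Hypothetical demo class for illustration; not part of AES.java. */
final class IvDemo {
    /** Encrypts with AES/CBC/PKCS5Padding, the same transformation AES.java uses. */
    static byte[] encrypt(byte[] key, byte[] iv, byte[] message) {
        try {
            Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
            cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"),
                    new IvParameterSpec(iv));
            return cipher.doFinal(message);
        } catch (GeneralSecurityException ex) {
            throw new IllegalStateException("AES/CBC encryption failed", ex);
        }
    }

    /** Returns n bytes drawn from a SecureRandom. */
    static byte[] randomBytes(int n) {
        byte[] bytes = new byte[n];
        new SecureRandom().nextBytes(bytes);
        return bytes;
    }

    public static void main(String[] args) {
        byte[] key = randomBytes(16);   // 128-bit key, allowed under either policy
        byte[] ivA = randomBytes(16);   // the AES block size is 16 bytes
        byte[] ivB = randomBytes(16);
        byte[] msg = "test message".getBytes(StandardCharsets.UTF_8);

        // Same key and same IV: the two ciphertexts are identical.
        System.out.println(Arrays.equals(encrypt(key, ivA, msg), encrypt(key, ivA, msg)));
        // Same key but a different IV: the ciphertexts differ.
        System.out.println(Arrays.equals(encrypt(key, ivA, msg), encrypt(key, ivB, msg)));
    }
}
```

The first comparison is always true. The second is false unless the two random IVs happen to be equal, which is vanishingly unlikely.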