import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.Socket;
import java.net.URL;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.Security;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509KeyManager;
import javax.net.ssl.X509TrustManager;

/**
 * Command-line test of HTTPS client-certificate authentication using a key
 * held in a PKCS#11 token.
 *
 * <p>The program installs a SunPKCS11 provider, opens the token as a
 * {@code PKCS11} {@link KeyStore}, prints its aliases, and then fetches an
 * HTTPS URL while presenting the selected key. Server certificate validation
 * and hostname verification are deliberately disabled: this is a diagnostic
 * tool, not production code.
 */
public class PKCS11Test {

    /** SunPKCS11 configuration used when no config file is supplied. */
    private static final String DEFAULT_CONFIG =
        "name = SoftToken\n"
        + "library = /usr/local/lib/soft-pkcs11.so";

    public PKCS11Test() {
    }

    /**
     * Runs the test.
     *
     * @param args {@code alias password [url] [config file]}; see {@link #usage()}.
     *             Any failure is printed and the JVM exits with status 1.
     */
    public void run(String[] args) {
        try {
            createPKCS11Provider(args.length > 3 ? args[3] : null);
            KeyStore ks = createKeystore(args[1]);
            printKeystoreInfo(ks);
            configureHTTPSUrlConnection(ks, args[0], args[1]);
            // you can test this against "openssl s_server -www -verify 0"
            // you need a cert and key in server.pem, of course
            fetchUrl(new URL(args.length > 2 ? args[2] : "https://localhost:4433"));
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Creates a SunPKCS11 provider and registers it with {@link Security}.
     *
     * <p>The legacy {@code SunPKCS11(InputStream)} constructor (Java 1.4 to 8)
     * is tried first. It is looked up reflectively so this file still compiles
     * where the class is absent. That constructor does not exist on Java 9+,
     * so the JDK's pre-registered "SunPKCS11" provider is configured through
     * {@code Provider.configure(String)} instead.
     *
     * @param configFile path of a SunPKCS11 configuration file, or {@code null}
     *                   to use {@link #DEFAULT_CONFIG}
     * @throws IOException if the configuration file cannot be read or the
     *                     provider cannot be created
     */
    private void createPKCS11Provider(String configFile) throws IOException {
        Provider pkcs11;
        try {
            pkcs11 = createLegacyProvider(configFile);
            if (pkcs11 == null) {
                pkcs11 = createConfiguredProvider(configFile);
            }
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            // Reflection wraps the provider's own failure; report that instead.
            Throwable cause = e instanceof InvocationTargetException
                ? ((InvocationTargetException) e).getTargetException()
                : e;
            IOException failure = new IOException("Error instantiating the PKCS11 provider");
            failure.initCause(cause);
            throw failure;
        }

        Security.addProvider(pkcs11);
    }

    /**
     * Creates a provider via the {@code SunPKCS11(InputStream)} constructor.
     * The configuration stream is always closed afterwards.
     *
     * @param configFile configuration file path, or {@code null} for the default
     * @return the new provider, or {@code null} if this JDK has no such class
     *         or constructor
     * @throws Exception if the config cannot be opened or the constructor fails
     */
    private Provider createLegacyProvider(String configFile) throws Exception {
        Constructor ctor;
        try {
            Class pkcs11Class = Class.forName("sun.security.pkcs11.SunPKCS11");
            ctor = pkcs11Class.getConstructor(new Class[] { InputStream.class });
        } catch (ClassNotFoundException e) {
            return null;
        } catch (NoSuchMethodException e) {
            return null;
        }

        InputStream is;
        if (configFile == null) {
            is = new ByteArrayInputStream(DEFAULT_CONFIG.getBytes());
        } else {
            is = new FileInputStream(configFile);
        }
        try {
            return (Provider) ctor.newInstance(new Object[] { is });
        } finally {
            is.close();
        }
    }

    /**
     * Configures the JDK's built-in "SunPKCS11" provider with
     * {@code Provider.configure(String)} (Java 9+). The method is invoked
     * reflectively so this file still compiles on older JDKs.
     *
     * @param configFile configuration file path, or {@code null} for the default
     * @return the configured provider
     * @throws Exception if the provider is unavailable or configuration fails
     */
    private Provider createConfiguredProvider(String configFile) throws Exception {
        Provider template = Security.getProvider("SunPKCS11");
        if (template == null) {
            throw new IllegalStateException("SunPKCS11 provider is not available");
        }
        // A leading "--" tells SunPKCS11 that the argument is the configuration
        // text itself rather than a file name.
        String configArg = configFile == null ? "--" + DEFAULT_CONFIG : configFile;
        Method configure = Provider.class.getMethod("configure", new Class[] { String.class });
        return (Provider) configure.invoke(template, new Object[] { configArg });
    }

    /**
     * Opens the first installed {@code PKCS11} keystore.
     *
     * @param kspassword token PIN, or {@code null}
     */
    private KeyStore createKeystore(String kspassword)
    throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
        KeyStore ks = KeyStore.getInstance("PKCS11");
        ks.load(null, kspassword == null ? null : kspassword.toCharArray());
        return ks;
    }

    /** Prints the keystore's provider, type, size and every alias to stdout. */
    private void printKeystoreInfo(KeyStore keystore) throws KeyStoreException {
        System.out.println("Provider : " + keystore.getProvider().getName());
        System.out.println("Type : " + keystore.getType());
        System.out.println("Size : " + keystore.size());

        Enumeration en = keystore.aliases();
        while (en.hasMoreElements()) {
            System.out.println("Alias: " + en.nextElement());
        }
    }

    /**
     * Installs a default SSL socket factory and hostname verifier for
     * {@link HttpsURLConnection}.
     *
     * <p>The client presents the key for {@code alias}. If {@code alias} is
     * {@code null}, a SunX509 key manager chooses among the keystore's keys.
     * Any server certificate and any hostname are accepted.
     *
     * @param keystore keystore holding the client key
     * @param alias    alias of the client key, or {@code null}
     * @param password key password, or {@code null}
     */
    private void configureHTTPSUrlConnection(KeyStore keystore, String alias, String password)
    throws NoSuchAlgorithmException, KeyStoreException, UnrecoverableKeyException, KeyManagementException {
        // naive trust manager that does not check cert validity
        TrustManager[] trustAllCerts = new TrustManager[] { new TrustAllManager() };

        KeyManager[] managers;
        if (alias != null) {
            managers = new KeyManager[] { new AliasKeyManager(keystore, alias, password) };
        } else {
            KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
            kmf.init(keystore, password == null ? null : password.toCharArray());
            managers = kmf.getKeyManagers();
        }

        // "TLS" is the standard algorithm name; modern JDKs map "SSL" to it anyway.
        SSLContext sc = SSLContext.getInstance("TLS");
        sc.init(managers, trustAllCerts, new java.security.SecureRandom());
        HttpsURLConnection.setDefaultSSLSocketFactory(sc.getSocketFactory());

        // we don't care if the hostname doesn't match the server certificate
        HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() {
            public boolean verify(String hostname, SSLSession session) {
                return true;
            }
        });
    }

    /**
     * Fetches {@code url} and copies the response body to stdout line by line.
     * The response stream is always closed.
     */
    private void fetchUrl(URL url) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(url.openStream()));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                System.out.println(line);
            }
        } finally {
            br.close();
        }
    }

    /**
     * Converts a keystore certificate chain to an X.509 chain element by element.
     * A plain array cast is unsafe because {@link KeyStore#getCertificateChain}
     * may return a {@code Certificate[]} whose runtime type is not
     * {@code X509Certificate[]}.
     *
     * @param chain chain as returned by the keystore, possibly {@code null}
     * @return the X.509 chain, or {@code null} if {@code chain} is {@code null}
     *         or contains a non-X.509 certificate
     */
    static X509Certificate[] toX509Chain(Certificate[] chain) {
        if (chain == null) {
            return null;
        }
        X509Certificate[] x509Chain = new X509Certificate[chain.length];
        for (int i = 0; i < chain.length; i++) {
            if (!(chain[i] instanceof X509Certificate)) {
                return null;
            }
            x509Chain[i] = (X509Certificate) chain[i];
        }
        return x509Chain;
    }

    /**
     * Trust manager that accepts every certificate. This is for testing only.
     * {@link #getAcceptedIssuers()} returns an empty, non-null array, as the
     * {@link X509TrustManager} contract requires.
     */
    static final class TrustAllManager implements X509TrustManager {
        public X509Certificate[] getAcceptedIssuers() {
            return new X509Certificate[0];
        }

        public void checkClientTrusted(X509Certificate[] certs, String authType) {
        }

        public void checkServerTrusted(X509Certificate[] certs, String authType) {
        }
    }

    /**
     * Key manager that always selects one fixed alias, whatever key type or
     * issuers are requested.
     */
    private static final class AliasKeyManager implements X509KeyManager {

        private final KeyStore keystore;
        private final String keyAlias;
        private final String password;

        AliasKeyManager(KeyStore keystore, String alias, String password) {
            this.keystore = keystore;
            this.keyAlias = alias;
            this.password = password;
        }

        public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
            return keyAlias;
        }

        public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
            return keyAlias;
        }

        /** Returns the chain for {@code alias}, or {@code null} if unavailable. */
        public X509Certificate[] getCertificateChain(String alias) {
            try {
                return toX509Chain(keystore.getCertificateChain(alias));
            } catch (KeyStoreException e) {
                e.printStackTrace();
                return null;
            }
        }

        public String[] getClientAliases(String keyType, Principal[] issuers) {
            return new String[] { keyAlias };
        }

        /**
         * Returns the private key for {@code alias}. Returns {@code null} if the
         * keystore lookup fails or the entry is not a private key.
         */
        public PrivateKey getPrivateKey(String alias) {
            try {
                return (PrivateKey) keystore.getKey(alias, password == null ? null : password.toCharArray());
            } catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }

        public String[] getServerAliases(String keyType, Principal[] issuers) {
            return new String[] { keyAlias };
        }
    }

    /** Prints command-line usage to stdout. */
    private static void usage() {
        System.out.println("Usage:");
        System.out.println("\tjava PKCS11Test alias password [url] [config file]");
        System.out.println();
        System.out.println("Where:");
        System.out.println("\talias refers to the certificate to use");
        System.out.println("\tpassword is the PIN/password to unlock the hardware token");
        System.out.println("\t[url] defaults to https://localhost:4433/");
        System.out.println("\t[config file] assumes soft-pkcs11.so is in /usr/local/lib if not set");
        System.out.println();
    }

    public static void main(String[] args) {
        if (args.length < 2) {
            usage();
            System.exit(1);
        }
        new PKCS11Test().run(args);
    }

}
