diff deps/Platinum/ThirdParty/Neptune/Source/Core/NptTls.cpp @ 0:3425707ddbf6

Initial import (hopefully this mercurial stuff works...)
author fraserofthenight
date Mon, 06 Jul 2009 08:06:28 -0700
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deps/Platinum/ThirdParty/Neptune/Source/Core/NptTls.cpp	Mon Jul 06 08:06:28 2009 -0700
@@ -0,0 +1,492 @@
+/*****************************************************************
+|
+|   Neptune - TLS/SSL Support
+|
+| Copyright (c) 2002-2008, Axiomatic Systems, LLC.
+| All rights reserved.
+|
+| Redistribution and use in source and binary forms, with or without
+| modification, are permitted provided that the following conditions are met:
+|     * Redistributions of source code must retain the above copyright
+|       notice, this list of conditions and the following disclaimer.
+|     * Redistributions in binary form must reproduce the above copyright
+|       notice, this list of conditions and the following disclaimer in the
+|       documentation and/or other materials provided with the distribution.
+|     * Neither the name of Axiomatic Systems nor the
+|       names of its contributors may be used to endorse or promote products
+|       derived from this software without specific prior written permission.
+|
+| THIS SOFTWARE IS PROVIDED BY AXIOMATIC SYSTEMS ''AS IS'' AND ANY
+| EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+| WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+| DISCLAIMED. IN NO EVENT SHALL AXIOMATIC SYSTEMS BE LIABLE FOR ANY
+| DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+| (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+| LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+| ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+| (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+| SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+|
+ ****************************************************************/
+
+#if defined(NPT_CONFIG_ENABLE_TLS)
+
+/*----------------------------------------------------------------------
+|   includes
++---------------------------------------------------------------------*/
+#include "NptConfig.h"
+#include "NptTls.h"
+#include "NptLogging.h"
+#include "NptUtils.h"
+#include "NptSockets.h"
+
+#include "ssl.h"
+
+/*----------------------------------------------------------------------
+|   logging
++---------------------------------------------------------------------*/
+//NPT_SET_LOCAL_LOGGER("neptune.tls")
+
+/*----------------------------------------------------------------------
+|   constants
++---------------------------------------------------------------------*/
+const unsigned int NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE = 16;
+
+/*----------------------------------------------------------------------
+|   types
++---------------------------------------------------------------------*/
+typedef NPT_Reference<NPT_TlsSessionImpl> NPT_TlsSessionImplReference;
+
+/*----------------------------------------------------------------------
+|   NPT_Tls_MapResult
++---------------------------------------------------------------------*/
+static NPT_Result
+NPT_Tls_MapResult(int err)
+{
+    switch (err) {
+        case SSL_ERROR_CONN_LOST:         return NPT_ERROR_CONNECTION_ABORTED;
+        case SSL_ERROR_TIMEOUT:           return NPT_ERROR_TIMEOUT;
+        case SSL_ERROR_EOS:               return NPT_ERROR_EOS;
+        case SSL_ERROR_NOT_SUPPORTED:     return NPT_ERROR_NOT_SUPPORTED;
+        case SSL_ERROR_INVALID_HANDSHAKE: return NPT_ERROR_TLS_INVALID_HANDSHAKE;
+        case SSL_ERROR_INVALID_PROT_MSG:  return NPT_ERROR_TLS_INVALID_PROTOCOL_MESSAGE;
+        case SSL_ERROR_INVALID_HMAC:      return NPT_ERROR_TLS_INVALID_HMAC;
+        case SSL_ERROR_INVALID_VERSION:   return NPT_ERROR_TLS_INVALID_VERSION;
+        case SSL_ERROR_INVALID_SESSION:   return NPT_ERROR_TLS_INVALID_SESSION;
+        case SSL_ERROR_NO_CIPHER:         return NPT_ERROR_TLS_NO_CIPHER;
+        case SSL_ERROR_BAD_CERTIFICATE:   return NPT_ERROR_TLS_BAD_CERTIFICATE;
+        case SSL_ERROR_INVALID_KEY:       return NPT_ERROR_INVALID_KEY;
+        case 0:                           return NPT_SUCCESS;
+        default:                          return NPT_FAILURE;
+    }
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsContextImpl
++---------------------------------------------------------------------*/
+class NPT_TlsContextImpl {
+public:
+    NPT_TlsContextImpl() :
+        m_SSL_CTX(ssl_ctx_new(SSL_SERVER_VERIFY_LATER, NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE)) {};
+    ~NPT_TlsContextImpl() { ssl_ctx_free(m_SSL_CTX); }
+    
+   NPT_Result LoadKey(NPT_TlsKeyFormat     key_format, 
+                      const unsigned char* key_data,
+                      NPT_Size             key_data_size,
+                      const char*          password);
+    
+    SSL_CTX* m_SSL_CTX;
+};
+
+/*----------------------------------------------------------------------
+|   NPT_TlsContextImpl::LoadKey
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsContextImpl::LoadKey(NPT_TlsKeyFormat     key_format, 
+                            const unsigned char* key_data,
+                            NPT_Size             key_data_size,
+                            const char*          password)
+{
+    int object_type;
+    switch (key_format) {
+        case NPT_TLS_KEY_FORMAT_RSA_PRIVATE: object_type = SSL_OBJ_RSA_KEY; break;
+        case NPT_TLS_KEY_FORMAT_PKCS8:       object_type = SSL_OBJ_PKCS8;   break;
+        case NPT_TLS_KEY_FORMAT_PKCS12:      object_type = SSL_OBJ_PKCS12;  break;
+        default: return NPT_ERROR_INVALID_PARAMETERS;
+    }
+    
+    int result = ssl_obj_memory_load(m_SSL_CTX, object_type, key_data, key_data_size, password);
+    return NPT_Tls_MapResult(result);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsStreamAdapter
++---------------------------------------------------------------------*/
+struct NPT_TlsStreamAdapter
+{
+    static int Read(SSL_SOCKET* _self, void* buffer, unsigned int size) {
+        NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self;
+        NPT_Size bytes_read = 0;
+        NPT_Result result = self->m_Input->Read(buffer, size, &bytes_read);
+        if (NPT_FAILED(result)) {
+            switch (result) {
+                case NPT_ERROR_EOS:     return SSL_ERROR_EOS;
+                case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT;
+                default:                return SSL_ERROR_CONN_LOST;
+            }
+        }
+        return bytes_read;
+    }
+
+    static int Write(SSL_SOCKET* _self, const void* buffer, unsigned int size) {
+        NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self;
+        NPT_Size bytes_written = 0;
+        NPT_Result result = self->m_Output->Write(buffer, size, &bytes_written);
+        if (NPT_FAILED(result)) {
+            switch (result) {
+                case NPT_ERROR_EOS:     return SSL_ERROR_EOS;
+                case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT;
+                default:                return SSL_ERROR_CONN_LOST;
+            }
+        }
+        return bytes_written;
+    }
+    
+    NPT_TlsStreamAdapter(NPT_InputStreamReference  input, 
+                         NPT_OutputStreamReference output) :
+        m_Input(input), m_Output(output) {
+        m_Base.Read  = Read;
+        m_Base.Write = Write;
+    }
+    
+    SSL_SOCKET                m_Base;
+    NPT_InputStreamReference  m_Input;
+    NPT_OutputStreamReference m_Output;
+};
+
+
+/*----------------------------------------------------------------------
+|   NPT_TlsSessionImpl
++---------------------------------------------------------------------*/
+class NPT_TlsSessionImpl {
+public:
+    NPT_TlsSessionImpl(SSL_CTX*                   context,
+                       NPT_InputStreamReference&  input,
+                       NPT_OutputStreamReference& output) :
+        m_SSL_CTX(context),
+        m_SSL(NULL),
+        m_StreamAdapter(input, output) {}
+    ~NPT_TlsSessionImpl() { ssl_free(m_SSL); }
+    
+    // methods
+    NPT_Result Handshake();
+    NPT_Result GetSessionId(NPT_DataBuffer& session_id);
+    NPT_UInt32 GetCipherSuiteId();
+    NPT_Result GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info);
+    
+    // members
+    SSL_CTX*             m_SSL_CTX;
+    SSL*                 m_SSL;
+    NPT_TlsStreamAdapter m_StreamAdapter;
+};
+
+/*----------------------------------------------------------------------
+|   NPT_TlsSessionImpl::Handshake
++---------------------------------------------------------------------*/
+NPT_Result
+NPT_TlsSessionImpl::Handshake()
+{
+    if (m_SSL == NULL) {
+        // we have not performed the handshake yet 
+        m_SSL = ssl_client_new(m_SSL_CTX, &m_StreamAdapter.m_Base, NULL, 0);
+    }
+    
+    int result = ssl_handshake_status(m_SSL);
+    return NPT_Tls_MapResult(result);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsSessionImpl::GetSessionId
++---------------------------------------------------------------------*/
+NPT_Result
+NPT_TlsSessionImpl::GetSessionId(NPT_DataBuffer& session_id)
+{
+    if (m_SSL == NULL) {
+        // no handshake done
+        session_id.SetDataSize(0);
+        return NPT_ERROR_INVALID_STATE;
+    }
+    
+    // return the session id
+    session_id.SetData(ssl_get_session_id(m_SSL),
+                       ssl_get_session_id_size(m_SSL));
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsSessionImpl::GetCipherSuiteId
++---------------------------------------------------------------------*/
+NPT_UInt32
+NPT_TlsSessionImpl::GetCipherSuiteId()
+{
+    if (m_SSL == NULL) {
+        // no handshake done
+        return 0;
+    }
+    
+    return ssl_get_cipher_id(m_SSL);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsSessionImpl::GetPeerCertificateInfo
++---------------------------------------------------------------------*/
+NPT_Result
+NPT_TlsSessionImpl::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info)
+{
+    cert_info.subject.common_name         = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_COMMON_NAME);
+    cert_info.subject.organization        = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATION);
+    cert_info.subject.organizational_name = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATIONAL_NAME);
+    cert_info.issuer.common_name          = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_COMMON_NAME);
+    cert_info.issuer.organization         = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATION);
+    cert_info.issuer.organizational_name  = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATIONAL_NAME);
+    
+    ssl_get_cert_fingerprints(m_SSL, cert_info.fingerprint.md5, cert_info.fingerprint.sha1);
+    
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsInputStream
++---------------------------------------------------------------------*/
+class NPT_TlsInputStream : public NPT_InputStream {
+public:
+    NPT_TlsInputStream(NPT_TlsSessionImplReference& session) :
+        m_Session(session),
+        m_Position(0),
+        m_RecordCacheData(NULL),
+        m_RecordCacheSize(0) {}
+    
+    // NPT_InputStream methods
+    virtual NPT_Result Read(void*     buffer, 
+                            NPT_Size  bytes_to_read, 
+                            NPT_Size* bytes_read = NULL);
+    virtual NPT_Result Seek(NPT_Position)           { return NPT_ERROR_NOT_SUPPORTED; }
+    virtual NPT_Result Tell(NPT_Position& offset)   { offset = m_Position; return NPT_SUCCESS; }
+    virtual NPT_Result GetSize(NPT_LargeSize& size) { size=0; return NPT_ERROR_NOT_SUPPORTED; }
+    virtual NPT_Result GetAvailable(NPT_LargeSize& available);
+
+private:
+    NPT_TlsSessionImplReference m_Session;
+    NPT_Position                m_Position;
+    uint8_t*                    m_RecordCacheData;
+    NPT_Size                    m_RecordCacheSize;
+};
+
+/*----------------------------------------------------------------------
+|   NPT_TlsInputStream::Read
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsInputStream::Read(void*     buffer, 
+                         NPT_Size  bytes_to_read, 
+                         NPT_Size* bytes_read)
+{
+    // setup default values
+    if (bytes_read) *bytes_read = 0;
+    
+    // quick check for edge case
+    if (bytes_to_read == 0) return NPT_SUCCESS;
+    
+    // read a new record if we don't have one cached
+    if (m_RecordCacheData == NULL) {
+        int ssl_result;
+        do {
+            ssl_result = ssl_read(m_Session->m_SSL, &m_RecordCacheData);
+        } while (ssl_result == 0);
+        if (ssl_result < 0) {
+            return NPT_Tls_MapResult(ssl_result);
+        } 
+        m_RecordCacheSize = ssl_result;
+    }
+    
+    // we now have data in cache
+    if (bytes_to_read > m_RecordCacheSize) {
+        // read at most what's in the cache
+        bytes_to_read = m_RecordCacheSize;
+    }
+    NPT_CopyMemory(buffer, m_RecordCacheData, bytes_to_read);
+    if (bytes_read) *bytes_read = bytes_to_read;
+
+    // update the record cache
+    m_RecordCacheSize -= bytes_to_read;
+    if (m_RecordCacheSize == 0) {
+        // nothing left in the cache
+        m_RecordCacheData = NULL;
+    } else {
+        // move the cache pointer
+        m_RecordCacheData += bytes_to_read;
+    }
+    
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsInputStream::GetAvailable
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsInputStream::GetAvailable(NPT_LargeSize& /*available*/)
+{
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsOutputStream
++---------------------------------------------------------------------*/
+class NPT_TlsOutputStream : public NPT_OutputStream {
+public:
+    NPT_TlsOutputStream(NPT_TlsSessionImplReference& session) :
+        m_Session(session),
+        m_Position(0) {}
+    
+    // NPT_OutputStream methods
+    virtual NPT_Result Write(const void* buffer, 
+                             NPT_Size    bytes_to_write, 
+                             NPT_Size*   bytes_written = NULL);
+    virtual NPT_Result Seek(NPT_Position) { return NPT_ERROR_NOT_SUPPORTED; }
+    virtual NPT_Result Tell(NPT_Position& offset) { offset = m_Position; return NPT_SUCCESS; }
+
+private:
+    NPT_TlsSessionImplReference m_Session;
+    NPT_Position                m_Position;
+};
+
+/*----------------------------------------------------------------------
+|   NPT_TlsOutputStream::Write
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsOutputStream::Write(const void* buffer, 
+                           NPT_Size    bytes_to_write, 
+                           NPT_Size*   bytes_written)
+{
+    // setup default values
+    if (bytes_written) *bytes_written = 0;
+    
+    // quick check for edge case 
+    if (bytes_to_write == 0) return NPT_SUCCESS;
+    
+    // write some data
+    int ssl_result;
+    do {
+        ssl_result = ssl_write(m_Session->m_SSL, (const uint8_t*)buffer, bytes_to_write);
+    } while (ssl_result == 0);
+    if (ssl_result < 0) {
+        return NPT_Tls_MapResult(ssl_result);
+    }
+    m_Position += ssl_result;
+    if (bytes_written) *bytes_written = (NPT_Size)ssl_result;
+    
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsContext::NPT_TlsContext
++---------------------------------------------------------------------*/
+NPT_TlsContext::NPT_TlsContext() :
+    m_Impl(new NPT_TlsContextImpl())
+{
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsContext::~NPT_TlsContext
++---------------------------------------------------------------------*/
+NPT_TlsContext::~NPT_TlsContext()
+{
+    delete m_Impl;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsContext::LoadKey
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsContext::LoadKey(NPT_TlsKeyFormat     key_format, 
+                        const unsigned char* key_data,
+                        NPT_Size             key_data_size,
+                        const char*          password)
+{
+    return m_Impl->LoadKey(key_format, key_data, key_data_size, password);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::NPT_TlsClientSession
++---------------------------------------------------------------------*/
+NPT_TlsClientSession::NPT_TlsClientSession(NPT_TlsContextReference&   context,
+                                           NPT_InputStreamReference&  input,
+                                           NPT_OutputStreamReference& output) :
+    m_Context(context),
+    m_Impl(new NPT_TlsSessionImpl(context->m_Impl->m_SSL_CTX, input, output))
+{
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::~NPT_TlsClientSession
++---------------------------------------------------------------------*/
+NPT_TlsClientSession::~NPT_TlsClientSession()
+{
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::Handshake
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsClientSession::Handshake()
+{
+    return m_Impl->Handshake();
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::GetSessionId
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsClientSession::GetSessionId(NPT_DataBuffer& session_id)
+{
+    return m_Impl->GetSessionId(session_id);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::GetCipherSuiteId
++---------------------------------------------------------------------*/
+NPT_UInt32
+NPT_TlsClientSession::GetCipherSuiteId()
+{
+    return m_Impl->GetCipherSuiteId();
+}
+              
+/*----------------------------------------------------------------------
+|   NPT_TlsSession::GetPeerCertificateInfo
++---------------------------------------------------------------------*/
+NPT_Result
+NPT_TlsClientSession::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info)
+{
+    return m_Impl->GetPeerCertificateInfo(cert_info);
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::GetInputStream
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsClientSession::GetInputStream(NPT_InputStreamReference& stream)
+{
+    stream = new NPT_TlsInputStream(m_Impl);
+    return NPT_SUCCESS;
+}
+
+/*----------------------------------------------------------------------
+|   NPT_TlsClientSession::GetOutputStream
++---------------------------------------------------------------------*/
+NPT_Result 
+NPT_TlsClientSession::GetOutputStream(NPT_OutputStreamReference& stream)
+{
+    stream = new NPT_TlsOutputStream(m_Impl);
+    return NPT_SUCCESS;
+}
+
+#endif // NPT_CONFIG_ENABLE_TLS