Mercurial > projects > hoofbaby
view deps/Platinum/ThirdParty/Neptune/Source/Core/NptTls.cpp @ 0:3425707ddbf6
Initial import (hopefully this mercurial stuff works...)
author | fraserofthenight |
---|---|
date | Mon, 06 Jul 2009 08:06:28 -0700 |
parents | |
children |
line wrap: on
line source
/***************************************************************** | | Neptune - TLS/SSL Support | | Copyright (c) 2002-2008, Axiomatic Systems, LLC. | All rights reserved. | | Redistribution and use in source and binary forms, with or without | modification, are permitted provided that the following conditions are met: | * Redistributions of source code must retain the above copyright | notice, this list of conditions and the following disclaimer. | * Redistributions in binary form must reproduce the above copyright | notice, this list of conditions and the following disclaimer in the | documentation and/or other materials provided with the distribution. | * Neither the name of Axiomatic Systems nor the | names of its contributors may be used to endorse or promote products | derived from this software without specific prior written permission. | | THIS SOFTWARE IS PROVIDED BY AXIOMATIC SYSTEMS ''AS IS'' AND ANY | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | DISCLAIMED. IN NO EVENT SHALL AXIOMATIC SYSTEMS BE LIABLE FOR ANY | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ****************************************************************/ #if defined(NPT_CONFIG_ENABLE_TLS) /*---------------------------------------------------------------------- | includes +---------------------------------------------------------------------*/ #include "NptConfig.h" #include "NptTls.h" #include "NptLogging.h" #include "NptUtils.h" #include "NptSockets.h" #include "ssl.h" /*---------------------------------------------------------------------- | logging +---------------------------------------------------------------------*/ //NPT_SET_LOCAL_LOGGER("neptune.tls") /*---------------------------------------------------------------------- | constants +---------------------------------------------------------------------*/ const unsigned int NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE = 16; /*---------------------------------------------------------------------- | types +---------------------------------------------------------------------*/ typedef NPT_Reference<NPT_TlsSessionImpl> NPT_TlsSessionImplReference; /*---------------------------------------------------------------------- | NPT_Tls_MapResult +---------------------------------------------------------------------*/ static NPT_Result NPT_Tls_MapResult(int err) { switch (err) { case SSL_ERROR_CONN_LOST: return NPT_ERROR_CONNECTION_ABORTED; case SSL_ERROR_TIMEOUT: return NPT_ERROR_TIMEOUT; case SSL_ERROR_EOS: return NPT_ERROR_EOS; case SSL_ERROR_NOT_SUPPORTED: return NPT_ERROR_NOT_SUPPORTED; case SSL_ERROR_INVALID_HANDSHAKE: return NPT_ERROR_TLS_INVALID_HANDSHAKE; case SSL_ERROR_INVALID_PROT_MSG: return NPT_ERROR_TLS_INVALID_PROTOCOL_MESSAGE; case SSL_ERROR_INVALID_HMAC: return NPT_ERROR_TLS_INVALID_HMAC; case SSL_ERROR_INVALID_VERSION: return NPT_ERROR_TLS_INVALID_VERSION; case SSL_ERROR_INVALID_SESSION: return NPT_ERROR_TLS_INVALID_SESSION; case SSL_ERROR_NO_CIPHER: return NPT_ERROR_TLS_NO_CIPHER; case SSL_ERROR_BAD_CERTIFICATE: return NPT_ERROR_TLS_BAD_CERTIFICATE; case SSL_ERROR_INVALID_KEY: return NPT_ERROR_INVALID_KEY; case 0: return NPT_SUCCESS; default: return NPT_FAILURE; } } /*---------------------------------------------------------------------- | NPT_TlsContextImpl +---------------------------------------------------------------------*/ class NPT_TlsContextImpl { public: NPT_TlsContextImpl() : m_SSL_CTX(ssl_ctx_new(SSL_SERVER_VERIFY_LATER, NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE)) {}; ~NPT_TlsContextImpl() { ssl_ctx_free(m_SSL_CTX); } NPT_Result LoadKey(NPT_TlsKeyFormat key_format, const unsigned char* key_data, NPT_Size key_data_size, const char* password); SSL_CTX* m_SSL_CTX; }; /*---------------------------------------------------------------------- | NPT_TlsContextImpl::LoadKey +---------------------------------------------------------------------*/ NPT_Result NPT_TlsContextImpl::LoadKey(NPT_TlsKeyFormat key_format, const unsigned char* key_data, NPT_Size key_data_size, const char* password) { int object_type; switch (key_format) { case NPT_TLS_KEY_FORMAT_RSA_PRIVATE: object_type = SSL_OBJ_RSA_KEY; break; case NPT_TLS_KEY_FORMAT_PKCS8: object_type = SSL_OBJ_PKCS8; break; case NPT_TLS_KEY_FORMAT_PKCS12: object_type = SSL_OBJ_PKCS12; break; default: return NPT_ERROR_INVALID_PARAMETERS; } int result = ssl_obj_memory_load(m_SSL_CTX, object_type, key_data, key_data_size, password); return NPT_Tls_MapResult(result); } /*---------------------------------------------------------------------- | NPT_TlsStreamAdapter +---------------------------------------------------------------------*/ struct NPT_TlsStreamAdapter { static int Read(SSL_SOCKET* _self, void* buffer, unsigned int size) { NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self; NPT_Size bytes_read = 0; NPT_Result result = self->m_Input->Read(buffer, size, &bytes_read); if (NPT_FAILED(result)) { switch (result) { case NPT_ERROR_EOS: return SSL_ERROR_EOS; case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT; default: return SSL_ERROR_CONN_LOST; } } return bytes_read; } static int Write(SSL_SOCKET* _self, const void* buffer, unsigned int size) { NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self; NPT_Size bytes_written = 0; NPT_Result result = self->m_Output->Write(buffer, size, &bytes_written); if (NPT_FAILED(result)) { switch (result) { case NPT_ERROR_EOS: return SSL_ERROR_EOS; case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT; default: return SSL_ERROR_CONN_LOST; } } return bytes_written; } NPT_TlsStreamAdapter(NPT_InputStreamReference input, NPT_OutputStreamReference output) : m_Input(input), m_Output(output) { m_Base.Read = Read; m_Base.Write = Write; } SSL_SOCKET m_Base; NPT_InputStreamReference m_Input; NPT_OutputStreamReference m_Output; }; /*---------------------------------------------------------------------- | NPT_TlsSessionImpl +---------------------------------------------------------------------*/ class NPT_TlsSessionImpl { public: NPT_TlsSessionImpl(SSL_CTX* context, NPT_InputStreamReference& input, NPT_OutputStreamReference& output) : m_SSL_CTX(context), m_SSL(NULL), m_StreamAdapter(input, output) {} ~NPT_TlsSessionImpl() { ssl_free(m_SSL); } // methods NPT_Result Handshake(); NPT_Result GetSessionId(NPT_DataBuffer& session_id); NPT_UInt32 GetCipherSuiteId(); NPT_Result GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info); // members SSL_CTX* m_SSL_CTX; SSL* m_SSL; NPT_TlsStreamAdapter m_StreamAdapter; }; /*---------------------------------------------------------------------- | NPT_TlsSessionImpl::Handshake +---------------------------------------------------------------------*/ NPT_Result NPT_TlsSessionImpl::Handshake() { if (m_SSL == NULL) { // we have not performed the handshake yet m_SSL = ssl_client_new(m_SSL_CTX, &m_StreamAdapter.m_Base, NULL, 0); } int result = ssl_handshake_status(m_SSL); return NPT_Tls_MapResult(result); } /*---------------------------------------------------------------------- | NPT_TlsSessionImpl::GetSessionId +---------------------------------------------------------------------*/ NPT_Result NPT_TlsSessionImpl::GetSessionId(NPT_DataBuffer& session_id) { if (m_SSL == NULL) { // no handshake done session_id.SetDataSize(0); return NPT_ERROR_INVALID_STATE; } // return the session id session_id.SetData(ssl_get_session_id(m_SSL), ssl_get_session_id_size(m_SSL)); return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsSessionImpl::GetCipherSuiteId +---------------------------------------------------------------------*/ NPT_UInt32 NPT_TlsSessionImpl::GetCipherSuiteId() { if (m_SSL == NULL) { // no handshake done return 0; } return ssl_get_cipher_id(m_SSL); } /*---------------------------------------------------------------------- | NPT_TlsSessionImpl::GetPeerCertificateInfo +---------------------------------------------------------------------*/ NPT_Result NPT_TlsSessionImpl::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info) { cert_info.subject.common_name = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_COMMON_NAME); cert_info.subject.organization = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATION); cert_info.subject.organizational_name = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATIONAL_NAME); cert_info.issuer.common_name = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_COMMON_NAME); cert_info.issuer.organization = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATION); cert_info.issuer.organizational_name = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATIONAL_NAME); ssl_get_cert_fingerprints(m_SSL, cert_info.fingerprint.md5, cert_info.fingerprint.sha1); return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsInputStream +---------------------------------------------------------------------*/ class NPT_TlsInputStream : public NPT_InputStream { public: NPT_TlsInputStream(NPT_TlsSessionImplReference& session) : m_Session(session), m_Position(0), m_RecordCacheData(NULL), m_RecordCacheSize(0) {} // NPT_InputStream methods virtual NPT_Result Read(void* buffer, NPT_Size bytes_to_read, NPT_Size* bytes_read = NULL); virtual NPT_Result Seek(NPT_Position) { return NPT_ERROR_NOT_SUPPORTED; } virtual NPT_Result Tell(NPT_Position& offset) { offset = m_Position; return NPT_SUCCESS; } virtual NPT_Result GetSize(NPT_LargeSize& size) { size=0; return NPT_ERROR_NOT_SUPPORTED; } virtual NPT_Result GetAvailable(NPT_LargeSize& available); private: NPT_TlsSessionImplReference m_Session; NPT_Position m_Position; uint8_t* m_RecordCacheData; NPT_Size m_RecordCacheSize; }; /*---------------------------------------------------------------------- | NPT_TlsInputStream::Read +---------------------------------------------------------------------*/ NPT_Result NPT_TlsInputStream::Read(void* buffer, NPT_Size bytes_to_read, NPT_Size* bytes_read) { // setup default values if (bytes_read) *bytes_read = 0; // quick check for edge case if (bytes_to_read == 0) return NPT_SUCCESS; // read a new record if we don't have one cached if (m_RecordCacheData == NULL) { int ssl_result; do { ssl_result = ssl_read(m_Session->m_SSL, &m_RecordCacheData); } while (ssl_result == 0); if (ssl_result < 0) { return NPT_Tls_MapResult(ssl_result); } m_RecordCacheSize = ssl_result; } // we now have data in cache if (bytes_to_read > m_RecordCacheSize) { // read at most what's in the cache bytes_to_read = m_RecordCacheSize; } NPT_CopyMemory(buffer, m_RecordCacheData, bytes_to_read); if (bytes_read) *bytes_read = bytes_to_read; // update the record cache m_RecordCacheSize -= bytes_to_read; if (m_RecordCacheSize == 0) { // nothing left in the cache m_RecordCacheData = NULL; } else { // move the cache pointer m_RecordCacheData += bytes_to_read; } return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsInputStream::GetAvailable +---------------------------------------------------------------------*/ NPT_Result NPT_TlsInputStream::GetAvailable(NPT_LargeSize& /*available*/) { return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsOutputStream +---------------------------------------------------------------------*/ class NPT_TlsOutputStream : public NPT_OutputStream { public: NPT_TlsOutputStream(NPT_TlsSessionImplReference& session) : m_Session(session), m_Position(0) {} // NPT_OutputStream methods virtual NPT_Result Write(const void* buffer, NPT_Size bytes_to_write, NPT_Size* bytes_written = NULL); virtual NPT_Result Seek(NPT_Position) { return NPT_ERROR_NOT_SUPPORTED; } virtual NPT_Result Tell(NPT_Position& offset) { offset = m_Position; return NPT_SUCCESS; } private: NPT_TlsSessionImplReference m_Session; NPT_Position m_Position; }; /*---------------------------------------------------------------------- | NPT_TlsOutputStream::Write +---------------------------------------------------------------------*/ NPT_Result NPT_TlsOutputStream::Write(const void* buffer, NPT_Size bytes_to_write, NPT_Size* bytes_written) { // setup default values if (bytes_written) *bytes_written = 0; // quick check for edge case if (bytes_to_write == 0) return NPT_SUCCESS; // write some data int ssl_result; do { ssl_result = ssl_write(m_Session->m_SSL, (const uint8_t*)buffer, bytes_to_write); } while (ssl_result == 0); if (ssl_result < 0) { return NPT_Tls_MapResult(ssl_result); } m_Position += ssl_result; if (bytes_written) *bytes_written = (NPT_Size)ssl_result; return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsContext::NPT_TlsContext +---------------------------------------------------------------------*/ NPT_TlsContext::NPT_TlsContext() : m_Impl(new NPT_TlsContextImpl()) { } /*---------------------------------------------------------------------- | NPT_TlsContext::~NPT_TlsContext +---------------------------------------------------------------------*/ NPT_TlsContext::~NPT_TlsContext() { delete m_Impl; } /*---------------------------------------------------------------------- | NPT_TlsContext::LoadKey +---------------------------------------------------------------------*/ NPT_Result NPT_TlsContext::LoadKey(NPT_TlsKeyFormat key_format, const unsigned char* key_data, NPT_Size key_data_size, const char* password) { return m_Impl->LoadKey(key_format, key_data, key_data_size, password); } /*---------------------------------------------------------------------- | NPT_TlsClientSession::NPT_TlsClientSession +---------------------------------------------------------------------*/ NPT_TlsClientSession::NPT_TlsClientSession(NPT_TlsContextReference& context, NPT_InputStreamReference& input, NPT_OutputStreamReference& output) : m_Context(context), m_Impl(new NPT_TlsSessionImpl(context->m_Impl->m_SSL_CTX, input, output)) { } /*---------------------------------------------------------------------- | NPT_TlsClientSession::~NPT_TlsClientSession +---------------------------------------------------------------------*/ NPT_TlsClientSession::~NPT_TlsClientSession() { } /*---------------------------------------------------------------------- | NPT_TlsClientSession::Handshake +---------------------------------------------------------------------*/ NPT_Result NPT_TlsClientSession::Handshake() { return m_Impl->Handshake(); } /*---------------------------------------------------------------------- | NPT_TlsClientSession::GetSessionId +---------------------------------------------------------------------*/ NPT_Result NPT_TlsClientSession::GetSessionId(NPT_DataBuffer& session_id) { return m_Impl->GetSessionId(session_id); } /*---------------------------------------------------------------------- | NPT_TlsClientSession::GetCipherSuiteId +---------------------------------------------------------------------*/ NPT_UInt32 NPT_TlsClientSession::GetCipherSuiteId() { return m_Impl->GetCipherSuiteId(); } /*---------------------------------------------------------------------- | NPT_TlsSession::GetPeerCertificateInfo +---------------------------------------------------------------------*/ NPT_Result NPT_TlsClientSession::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info) { return m_Impl->GetPeerCertificateInfo(cert_info); } /*---------------------------------------------------------------------- | NPT_TlsClientSession::GetInputStream +---------------------------------------------------------------------*/ NPT_Result NPT_TlsClientSession::GetInputStream(NPT_InputStreamReference& stream) { stream = new NPT_TlsInputStream(m_Impl); return NPT_SUCCESS; } /*---------------------------------------------------------------------- | NPT_TlsClientSession::GetOutputStream +---------------------------------------------------------------------*/ NPT_Result NPT_TlsClientSession::GetOutputStream(NPT_OutputStreamReference& stream) { stream = new NPT_TlsOutputStream(m_Impl); return NPT_SUCCESS; } #endif // NPT_CONFIG_ENABLE_TLS