comparison deps/Platinum/ThirdParty/Neptune/Source/Core/NptTls.cpp @ 0:3425707ddbf6

Initial import (hopefully this mercurial stuff works...)
author fraserofthenight
date Mon, 06 Jul 2009 08:06:28 -0700
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:3425707ddbf6
1 /*****************************************************************
2 |
3 | Neptune - TLS/SSL Support
4 |
5 | Copyright (c) 2002-2008, Axiomatic Systems, LLC.
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 | * Redistributions of source code must retain the above copyright
11 | notice, this list of conditions and the following disclaimer.
12 | * Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in the
14 | documentation and/or other materials provided with the distribution.
15 | * Neither the name of Axiomatic Systems nor the
16 | names of its contributors may be used to endorse or promote products
17 | derived from this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY AXIOMATIC SYSTEMS ''AS IS'' AND ANY
20 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL AXIOMATIC SYSTEMS BE LIABLE FOR ANY
23 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 ****************************************************************/
31
32 #if defined(NPT_CONFIG_ENABLE_TLS)
33
34 /*----------------------------------------------------------------------
35 | includes
36 +---------------------------------------------------------------------*/
37 #include "NptConfig.h"
38 #include "NptTls.h"
39 #include "NptLogging.h"
40 #include "NptUtils.h"
41 #include "NptSockets.h"
42
43 #include "ssl.h"
44
45 /*----------------------------------------------------------------------
46 | logging
47 +---------------------------------------------------------------------*/
48 //NPT_SET_LOCAL_LOGGER("neptune.tls")
49
50 /*----------------------------------------------------------------------
51 | constants
52 +---------------------------------------------------------------------*/
// Value passed as the second argument of ssl_ctx_new() when a
// NPT_TlsContextImpl is constructed.
const unsigned int NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE = 16U;
54
55 /*----------------------------------------------------------------------
56 | types
57 +---------------------------------------------------------------------*/
58 typedef NPT_Reference<NPT_TlsSessionImpl> NPT_TlsSessionImplReference;
59
60 /*----------------------------------------------------------------------
61 | NPT_Tls_MapResult
62 +---------------------------------------------------------------------*/
63 static NPT_Result
64 NPT_Tls_MapResult(int err)
65 {
66 switch (err) {
67 case SSL_ERROR_CONN_LOST: return NPT_ERROR_CONNECTION_ABORTED;
68 case SSL_ERROR_TIMEOUT: return NPT_ERROR_TIMEOUT;
69 case SSL_ERROR_EOS: return NPT_ERROR_EOS;
70 case SSL_ERROR_NOT_SUPPORTED: return NPT_ERROR_NOT_SUPPORTED;
71 case SSL_ERROR_INVALID_HANDSHAKE: return NPT_ERROR_TLS_INVALID_HANDSHAKE;
72 case SSL_ERROR_INVALID_PROT_MSG: return NPT_ERROR_TLS_INVALID_PROTOCOL_MESSAGE;
73 case SSL_ERROR_INVALID_HMAC: return NPT_ERROR_TLS_INVALID_HMAC;
74 case SSL_ERROR_INVALID_VERSION: return NPT_ERROR_TLS_INVALID_VERSION;
75 case SSL_ERROR_INVALID_SESSION: return NPT_ERROR_TLS_INVALID_SESSION;
76 case SSL_ERROR_NO_CIPHER: return NPT_ERROR_TLS_NO_CIPHER;
77 case SSL_ERROR_BAD_CERTIFICATE: return NPT_ERROR_TLS_BAD_CERTIFICATE;
78 case SSL_ERROR_INVALID_KEY: return NPT_ERROR_INVALID_KEY;
79 case 0: return NPT_SUCCESS;
80 default: return NPT_FAILURE;
81 }
82 }
83
84 /*----------------------------------------------------------------------
85 | NPT_TlsContextImpl
86 +---------------------------------------------------------------------*/
87 class NPT_TlsContextImpl {
88 public:
89 NPT_TlsContextImpl() :
90 m_SSL_CTX(ssl_ctx_new(SSL_SERVER_VERIFY_LATER, NPT_TLS_CONTEXT_DEFAULT_SESSION_CACHE)) {};
91 ~NPT_TlsContextImpl() { ssl_ctx_free(m_SSL_CTX); }
92
93 NPT_Result LoadKey(NPT_TlsKeyFormat key_format,
94 const unsigned char* key_data,
95 NPT_Size key_data_size,
96 const char* password);
97
98 SSL_CTX* m_SSL_CTX;
99 };
100
101 /*----------------------------------------------------------------------
102 | NPT_TlsContextImpl::LoadKey
103 +---------------------------------------------------------------------*/
104 NPT_Result
105 NPT_TlsContextImpl::LoadKey(NPT_TlsKeyFormat key_format,
106 const unsigned char* key_data,
107 NPT_Size key_data_size,
108 const char* password)
109 {
110 int object_type;
111 switch (key_format) {
112 case NPT_TLS_KEY_FORMAT_RSA_PRIVATE: object_type = SSL_OBJ_RSA_KEY; break;
113 case NPT_TLS_KEY_FORMAT_PKCS8: object_type = SSL_OBJ_PKCS8; break;
114 case NPT_TLS_KEY_FORMAT_PKCS12: object_type = SSL_OBJ_PKCS12; break;
115 default: return NPT_ERROR_INVALID_PARAMETERS;
116 }
117
118 int result = ssl_obj_memory_load(m_SSL_CTX, object_type, key_data, key_data_size, password);
119 return NPT_Tls_MapResult(result);
120 }
121
122 /*----------------------------------------------------------------------
123 | NPT_TlsStreamAdapter
124 +---------------------------------------------------------------------*/
125 struct NPT_TlsStreamAdapter
126 {
127 static int Read(SSL_SOCKET* _self, void* buffer, unsigned int size) {
128 NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self;
129 NPT_Size bytes_read = 0;
130 NPT_Result result = self->m_Input->Read(buffer, size, &bytes_read);
131 if (NPT_FAILED(result)) {
132 switch (result) {
133 case NPT_ERROR_EOS: return SSL_ERROR_EOS;
134 case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT;
135 default: return SSL_ERROR_CONN_LOST;
136 }
137 }
138 return bytes_read;
139 }
140
141 static int Write(SSL_SOCKET* _self, const void* buffer, unsigned int size) {
142 NPT_TlsStreamAdapter* self = (NPT_TlsStreamAdapter*)_self;
143 NPT_Size bytes_written = 0;
144 NPT_Result result = self->m_Output->Write(buffer, size, &bytes_written);
145 if (NPT_FAILED(result)) {
146 switch (result) {
147 case NPT_ERROR_EOS: return SSL_ERROR_EOS;
148 case NPT_ERROR_TIMEOUT: return SSL_ERROR_TIMEOUT;
149 default: return SSL_ERROR_CONN_LOST;
150 }
151 }
152 return bytes_written;
153 }
154
155 NPT_TlsStreamAdapter(NPT_InputStreamReference input,
156 NPT_OutputStreamReference output) :
157 m_Input(input), m_Output(output) {
158 m_Base.Read = Read;
159 m_Base.Write = Write;
160 }
161
162 SSL_SOCKET m_Base;
163 NPT_InputStreamReference m_Input;
164 NPT_OutputStreamReference m_Output;
165 };
166
167
168 /*----------------------------------------------------------------------
169 | NPT_TlsSessionImpl
170 +---------------------------------------------------------------------*/
171 class NPT_TlsSessionImpl {
172 public:
173 NPT_TlsSessionImpl(SSL_CTX* context,
174 NPT_InputStreamReference& input,
175 NPT_OutputStreamReference& output) :
176 m_SSL_CTX(context),
177 m_SSL(NULL),
178 m_StreamAdapter(input, output) {}
179 ~NPT_TlsSessionImpl() { ssl_free(m_SSL); }
180
181 // methods
182 NPT_Result Handshake();
183 NPT_Result GetSessionId(NPT_DataBuffer& session_id);
184 NPT_UInt32 GetCipherSuiteId();
185 NPT_Result GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info);
186
187 // members
188 SSL_CTX* m_SSL_CTX;
189 SSL* m_SSL;
190 NPT_TlsStreamAdapter m_StreamAdapter;
191 };
192
193 /*----------------------------------------------------------------------
194 | NPT_TlsSessionImpl::Handshake
195 +---------------------------------------------------------------------*/
196 NPT_Result
197 NPT_TlsSessionImpl::Handshake()
198 {
199 if (m_SSL == NULL) {
200 // we have not performed the handshake yet
201 m_SSL = ssl_client_new(m_SSL_CTX, &m_StreamAdapter.m_Base, NULL, 0);
202 }
203
204 int result = ssl_handshake_status(m_SSL);
205 return NPT_Tls_MapResult(result);
206 }
207
208 /*----------------------------------------------------------------------
209 | NPT_TlsSessionImpl::GetSessionId
210 +---------------------------------------------------------------------*/
211 NPT_Result
212 NPT_TlsSessionImpl::GetSessionId(NPT_DataBuffer& session_id)
213 {
214 if (m_SSL == NULL) {
215 // no handshake done
216 session_id.SetDataSize(0);
217 return NPT_ERROR_INVALID_STATE;
218 }
219
220 // return the session id
221 session_id.SetData(ssl_get_session_id(m_SSL),
222 ssl_get_session_id_size(m_SSL));
223 return NPT_SUCCESS;
224 }
225
226 /*----------------------------------------------------------------------
227 | NPT_TlsSessionImpl::GetCipherSuiteId
228 +---------------------------------------------------------------------*/
229 NPT_UInt32
230 NPT_TlsSessionImpl::GetCipherSuiteId()
231 {
232 if (m_SSL == NULL) {
233 // no handshake done
234 return 0;
235 }
236
237 return ssl_get_cipher_id(m_SSL);
238 }
239
240 /*----------------------------------------------------------------------
241 | NPT_TlsSessionImpl::GetPeerCertificateInfo
242 +---------------------------------------------------------------------*/
243 NPT_Result
244 NPT_TlsSessionImpl::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info)
245 {
246 cert_info.subject.common_name = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_COMMON_NAME);
247 cert_info.subject.organization = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATION);
248 cert_info.subject.organizational_name = ssl_get_cert_dn(m_SSL, SSL_X509_CERT_ORGANIZATIONAL_NAME);
249 cert_info.issuer.common_name = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_COMMON_NAME);
250 cert_info.issuer.organization = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATION);
251 cert_info.issuer.organizational_name = ssl_get_cert_dn(m_SSL, SSL_X509_CA_CERT_ORGANIZATIONAL_NAME);
252
253 ssl_get_cert_fingerprints(m_SSL, cert_info.fingerprint.md5, cert_info.fingerprint.sha1);
254
255 return NPT_SUCCESS;
256 }
257
258 /*----------------------------------------------------------------------
259 | NPT_TlsInputStream
260 +---------------------------------------------------------------------*/
261 class NPT_TlsInputStream : public NPT_InputStream {
262 public:
263 NPT_TlsInputStream(NPT_TlsSessionImplReference& session) :
264 m_Session(session),
265 m_Position(0),
266 m_RecordCacheData(NULL),
267 m_RecordCacheSize(0) {}
268
269 // NPT_InputStream methods
270 virtual NPT_Result Read(void* buffer,
271 NPT_Size bytes_to_read,
272 NPT_Size* bytes_read = NULL);
273 virtual NPT_Result Seek(NPT_Position) { return NPT_ERROR_NOT_SUPPORTED; }
274 virtual NPT_Result Tell(NPT_Position& offset) { offset = m_Position; return NPT_SUCCESS; }
275 virtual NPT_Result GetSize(NPT_LargeSize& size) { size=0; return NPT_ERROR_NOT_SUPPORTED; }
276 virtual NPT_Result GetAvailable(NPT_LargeSize& available);
277
278 private:
279 NPT_TlsSessionImplReference m_Session;
280 NPT_Position m_Position;
281 uint8_t* m_RecordCacheData;
282 NPT_Size m_RecordCacheSize;
283 };
284
285 /*----------------------------------------------------------------------
286 | NPT_TlsInputStream::Read
287 +---------------------------------------------------------------------*/
288 NPT_Result
289 NPT_TlsInputStream::Read(void* buffer,
290 NPT_Size bytes_to_read,
291 NPT_Size* bytes_read)
292 {
293 // setup default values
294 if (bytes_read) *bytes_read = 0;
295
296 // quick check for edge case
297 if (bytes_to_read == 0) return NPT_SUCCESS;
298
299 // read a new record if we don't have one cached
300 if (m_RecordCacheData == NULL) {
301 int ssl_result;
302 do {
303 ssl_result = ssl_read(m_Session->m_SSL, &m_RecordCacheData);
304 } while (ssl_result == 0);
305 if (ssl_result < 0) {
306 return NPT_Tls_MapResult(ssl_result);
307 }
308 m_RecordCacheSize = ssl_result;
309 }
310
311 // we now have data in cache
312 if (bytes_to_read > m_RecordCacheSize) {
313 // read at most what's in the cache
314 bytes_to_read = m_RecordCacheSize;
315 }
316 NPT_CopyMemory(buffer, m_RecordCacheData, bytes_to_read);
317 if (bytes_read) *bytes_read = bytes_to_read;
318
319 // update the record cache
320 m_RecordCacheSize -= bytes_to_read;
321 if (m_RecordCacheSize == 0) {
322 // nothing left in the cache
323 m_RecordCacheData = NULL;
324 } else {
325 // move the cache pointer
326 m_RecordCacheData += bytes_to_read;
327 }
328
329 return NPT_SUCCESS;
330 }
331
332 /*----------------------------------------------------------------------
333 | NPT_TlsInputStream::GetAvailable
334 +---------------------------------------------------------------------*/
335 NPT_Result
336 NPT_TlsInputStream::GetAvailable(NPT_LargeSize& /*available*/)
337 {
338 return NPT_SUCCESS;
339 }
340
341 /*----------------------------------------------------------------------
342 | NPT_TlsOutputStream
343 +---------------------------------------------------------------------*/
344 class NPT_TlsOutputStream : public NPT_OutputStream {
345 public:
346 NPT_TlsOutputStream(NPT_TlsSessionImplReference& session) :
347 m_Session(session),
348 m_Position(0) {}
349
350 // NPT_OutputStream methods
351 virtual NPT_Result Write(const void* buffer,
352 NPT_Size bytes_to_write,
353 NPT_Size* bytes_written = NULL);
354 virtual NPT_Result Seek(NPT_Position) { return NPT_ERROR_NOT_SUPPORTED; }
355 virtual NPT_Result Tell(NPT_Position& offset) { offset = m_Position; return NPT_SUCCESS; }
356
357 private:
358 NPT_TlsSessionImplReference m_Session;
359 NPT_Position m_Position;
360 };
361
362 /*----------------------------------------------------------------------
363 | NPT_TlsOutputStream::Write
364 +---------------------------------------------------------------------*/
365 NPT_Result
366 NPT_TlsOutputStream::Write(const void* buffer,
367 NPT_Size bytes_to_write,
368 NPT_Size* bytes_written)
369 {
370 // setup default values
371 if (bytes_written) *bytes_written = 0;
372
373 // quick check for edge case
374 if (bytes_to_write == 0) return NPT_SUCCESS;
375
376 // write some data
377 int ssl_result;
378 do {
379 ssl_result = ssl_write(m_Session->m_SSL, (const uint8_t*)buffer, bytes_to_write);
380 } while (ssl_result == 0);
381 if (ssl_result < 0) {
382 return NPT_Tls_MapResult(ssl_result);
383 }
384 m_Position += ssl_result;
385 if (bytes_written) *bytes_written = (NPT_Size)ssl_result;
386
387 return NPT_SUCCESS;
388 }
389
390 /*----------------------------------------------------------------------
391 | NPT_TlsContext::NPT_TlsContext
392 +---------------------------------------------------------------------*/
393 NPT_TlsContext::NPT_TlsContext() :
394 m_Impl(new NPT_TlsContextImpl())
395 {
396 }
397
398 /*----------------------------------------------------------------------
399 | NPT_TlsContext::~NPT_TlsContext
400 +---------------------------------------------------------------------*/
401 NPT_TlsContext::~NPT_TlsContext()
402 {
403 delete m_Impl;
404 }
405
406 /*----------------------------------------------------------------------
407 | NPT_TlsContext::LoadKey
408 +---------------------------------------------------------------------*/
409 NPT_Result
410 NPT_TlsContext::LoadKey(NPT_TlsKeyFormat key_format,
411 const unsigned char* key_data,
412 NPT_Size key_data_size,
413 const char* password)
414 {
415 return m_Impl->LoadKey(key_format, key_data, key_data_size, password);
416 }
417
418 /*----------------------------------------------------------------------
419 | NPT_TlsClientSession::NPT_TlsClientSession
420 +---------------------------------------------------------------------*/
421 NPT_TlsClientSession::NPT_TlsClientSession(NPT_TlsContextReference& context,
422 NPT_InputStreamReference& input,
423 NPT_OutputStreamReference& output) :
424 m_Context(context),
425 m_Impl(new NPT_TlsSessionImpl(context->m_Impl->m_SSL_CTX, input, output))
426 {
427 }
428
429 /*----------------------------------------------------------------------
430 | NPT_TlsClientSession::~NPT_TlsClientSession
431 +---------------------------------------------------------------------*/
432 NPT_TlsClientSession::~NPT_TlsClientSession()
433 {
434 }
435
436 /*----------------------------------------------------------------------
437 | NPT_TlsClientSession::Handshake
438 +---------------------------------------------------------------------*/
439 NPT_Result
440 NPT_TlsClientSession::Handshake()
441 {
442 return m_Impl->Handshake();
443 }
444
445 /*----------------------------------------------------------------------
446 | NPT_TlsClientSession::GetSessionId
447 +---------------------------------------------------------------------*/
448 NPT_Result
449 NPT_TlsClientSession::GetSessionId(NPT_DataBuffer& session_id)
450 {
451 return m_Impl->GetSessionId(session_id);
452 }
453
454 /*----------------------------------------------------------------------
455 | NPT_TlsClientSession::GetCipherSuiteId
456 +---------------------------------------------------------------------*/
457 NPT_UInt32
458 NPT_TlsClientSession::GetCipherSuiteId()
459 {
460 return m_Impl->GetCipherSuiteId();
461 }
462
463 /*----------------------------------------------------------------------
464 | NPT_TlsClientSession::GetPeerCertificateInfo
465 +---------------------------------------------------------------------*/
466 NPT_Result
467 NPT_TlsClientSession::GetPeerCertificateInfo(NPT_TlsCertificateInfo& cert_info)
468 {
469 return m_Impl->GetPeerCertificateInfo(cert_info);
470 }
471
472 /*----------------------------------------------------------------------
473 | NPT_TlsClientSession::GetInputStream
474 +---------------------------------------------------------------------*/
475 NPT_Result
476 NPT_TlsClientSession::GetInputStream(NPT_InputStreamReference& stream)
477 {
478 stream = new NPT_TlsInputStream(m_Impl);
479 return NPT_SUCCESS;
480 }
481
482 /*----------------------------------------------------------------------
483 | NPT_TlsClientSession::GetOutputStream
484 +---------------------------------------------------------------------*/
485 NPT_Result
486 NPT_TlsClientSession::GetOutputStream(NPT_OutputStreamReference& stream)
487 {
488 stream = new NPT_TlsOutputStream(m_Impl);
489 return NPT_SUCCESS;
490 }
491
492 #endif // NPT_CONFIG_ENABLE_TLS