1 module hunt.net.secure.AbstractSecureSession;
2 
3 // dfmt off
4 version(WITH_HUNT_SECURITY):
5 // dfmt on
6 
7 import hunt.net.secure.ProtocolSelector;
8 import hunt.net.secure.SecureSession;
9 import hunt.net.Exceptions;
10 import hunt.net.Connection;
11 import hunt.net.ssl;
12 
13 
14 import hunt.collection;
15 import hunt.concurrency.CountingCallback;
16 import hunt.Exceptions;
17 import hunt.stream.Common;
18 import hunt.text.Common;
19 import hunt.util.Common;
20 
21 import hunt.logging;
22 
23 import std.array;
24 import std.conv;
25 import std.format;
26 
27 
/**
 * Base implementation of a TLS session layered over a `Connection`.
 *
 * It drives the initial handshake through an `SSLEngine`, buffers encrypted
 * input until it can be unwrapped, and encrypts outgoing application data
 * before handing it to the underlying connection.
 */
abstract class AbstractSecureSession : SecureSession {

    /// Empty source buffer passed to `SSLEngine.wrap` while handshaking, when
    /// there is no application data to encrypt.
    protected __gshared ByteBuffer hsBuffer;

    protected Connection session;
    protected SSLEngine sslEngine;
    protected ProtocolSelector applicationProtocolSelector;
    protected SecureSessionHandshakeListener handshakeListener;

    /// Encrypted bytes received from the peer that have not been unwrapped yet
    /// (kept in read mode; null until the first non-empty input is merged).
    protected ByteBuffer receivedPacketBuf;
    /// Destination of `SSLEngine.unwrap`; kept in write mode between calls and
    /// flipped by `getReceivedAppBuf`.
    protected ByteBuffer receivedAppBuf;

    protected bool closed = false;
    protected HandshakeStatus initialHSStatus;
    protected bool initialHSComplete;

    shared static this() {
        hsBuffer = new HeapByteBuffer(0,0); 
    }

    /**
     * Creates the session and starts the TLS handshake.
     *
     * In client mode the first handshake flight is written immediately.
     *
     * Params:
     *  session = the underlying connection carrying encrypted bytes
     *  sslEngine = the engine performing encryption and handshaking
     *  applicationProtocolSelector = source of the negotiated application protocol
     *  handshakeListener = invoked once the handshake completes; may be null
     */
    this(Connection session, SSLEngine sslEngine,
            ProtocolSelector applicationProtocolSelector,
            SecureSessionHandshakeListener handshakeListener) {
        this.session = session;
        this.sslEngine = sslEngine;
        this.applicationProtocolSelector = applicationProtocolSelector;
        this.handshakeListener = handshakeListener;

        SSLSession ses = sslEngine.getSession();
        receivedAppBuf = newBuffer(ses.getApplicationBufferSize());
        initialHSComplete = false;

        // start tls
        version(HUNT_DEBUG) info("Starting TLS ...");
        this.sslEngine.beginHandshake();
        initialHSStatus = sslEngine.getHandshakeStatus();
        if (sslEngine.getUseClientMode()) {
            doHandshakeResponse();
        }
    }

    /**
     * Advances the initial handshake with newly received encrypted bytes.
     *
     * The two peers exchange communication parameters until the secure
     * session is established. Application data cannot be sent during this phase.
     *
     * Params:
     *  receiveBuffer = encrypted bytes received from the peer
     * Returns: true once the handshake has completed
     * Throws: SSLHandshakeException wrapping any failure raised while handshaking
     */
    protected bool doHandshake(ByteBuffer receiveBuffer) {
        try {
            return _doHandshake(receiveBuffer);
        } catch(Exception ex) {
            debug error(ex.msg);
            version(HUNT_NET_DEBUG) warning(ex);
            
            throw new SSLHandshakeException(ex.msg);
        }
    }
    
    /// Dispatches on the current handshake status; see `doHandshake`.
    protected bool _doHandshake(ByteBuffer receiveBuffer) {
        if (!session.isConnected()) {
            close();
            return (initialHSComplete = false);
        }

        if (initialHSComplete) {
            return true;
        }

        switch (initialHSStatus) {
            case HandshakeStatus.NOT_HANDSHAKING:
            case HandshakeStatus.FINISHED: {
                handshakeFinish();
                return initialHSComplete;
            }

            case HandshakeStatus.NEED_UNWRAP:
                doHandshakeReceive(receiveBuffer);
                if (initialHSStatus == HandshakeStatus.NEED_WRAP)
                    doHandshakeResponse();
                break;

            case HandshakeStatus.NEED_WRAP:
                doHandshakeResponse();
                break;

            default: // NEED_TASK
                throw new SecureNetException("Invalid Handshaking State: " ~ initialHSStatus.to!string());
        }
        return initialHSComplete;
    }

    /**
     * Buffers `receiveBuffer` and unwraps handshake records for as long as
     * the engine reports NEED_UNWRAP and enough input is available.
     */
    protected void doHandshakeReceive(ByteBuffer receiveBuffer) {
        merge(receiveBuffer);
        if (receivedPacketBuf is null) {
            // Nothing has been received yet; wait for more input instead of
            // unwrapping a null buffer.
            return;
        }

        needIO:
        while (initialHSStatus == HandshakeStatus.NEED_UNWRAP) {

            unwrapLabel:
            while (true) {
                SSLEngineResult result = unwrap();
                initialHSStatus = result.getHandshakeStatus();

                version(HUNT_NET_DEBUG_MORE) {
                    tracef("Connection %s handshake result -> %s, initialHSStatus -> %s, inNetRemain -> %s", 
                        session.getId(), result.toString(), initialHSStatus, receivedPacketBuf.remaining());
                }

                switch (result.getStatus()) {
                    case SSLEngineResult.Status.OK: {
                        switch (initialHSStatus) {
                            case HandshakeStatus.NEED_TASK:
                                initialHSStatus = doTasks();
                                break unwrapLabel;
                            case HandshakeStatus.NOT_HANDSHAKING:
                            case HandshakeStatus.FINISHED:
                                handshakeFinish();
                                break needIO;
                            default:
                                break unwrapLabel;
                        }
                    }

                    case SSLEngineResult.Status.BUFFER_UNDERFLOW: {
                        switch (initialHSStatus) {
                            case HandshakeStatus.NOT_HANDSHAKING:
                            case HandshakeStatus.FINISHED:
                                handshakeFinish();
                                break needIO;
                            default:
                                break;
                        }

                        int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
                        if (receivedPacketBuf.remaining() >= packetBufferSize) {
                            break; // retry the operation.
                        } else {
                            break needIO;
                        }
                    }

                    case SSLEngineResult.Status.BUFFER_OVERFLOW: {
                        resizeAppBuffer();
                        // retry the operation.
                    }
                    break;

                    case SSLEngineResult.Status.CLOSED: {
                        infof("Connection %s handshake failure. SSLEngine will close inbound", session.getId());
                        closeInbound();
                    }
                    break needIO;

                    default:
                        throw new SecureNetException(format("Connection %s handshake exception. status -> %s", session.getId(), result.getStatus()));

                }
            }
        }
    }

    /// Marks the handshake complete and notifies `handshakeListener` (if any).
    protected void handshakeFinish() {
        version(HUNT_NET_DEBUG) infof("Connection %s handshake success. The application protocol is %s", 
            session.getId(), getApplicationProtocol());
        initialHSComplete = true;
        // FIXME: Needing refactor or cleanup -@zhangxueping at 2020-05-12T18:28:36+08:00
        // There may be some remaining data which is not consumed yet.
        receivedAppBuf.clear();
        if(handshakeListener !is null)
            handshakeListener(this);
    }

    /**
     * Wraps and writes handshake records for as long as the engine reports
     * NEED_WRAP, growing the packet buffer on BUFFER_OVERFLOW.
     */
    protected void doHandshakeResponse() {

        outer:
        while (initialHSStatus == HandshakeStatus.NEED_WRAP) {
            SSLEngineResult result;
            ByteBuffer packetBuffer = newBuffer(sslEngine.getSession().getPacketBufferSize());

            wrap:
            while (true) {
                result = sslEngine.wrap(hsBuffer, packetBuffer);
                initialHSStatus = result.getHandshakeStatus();
                version(HUNT_NET_DEBUG) {
                    infof("session %s handshake response, init: %s | ret: %s | complete: %s ",
                            session.getId(), initialHSStatus, result.getStatus(), initialHSComplete);
                }

                switch (result.getStatus()) {
                    case SSLEngineResult.Status.OK: {
                        packetBuffer.flip();
                        version(HUNT_NET_DEBUG) {
                            tracef("session %s handshake response %s bytes", 
                                session.getId(), packetBuffer.remaining());
                        }

                        switch (initialHSStatus) {
                            case HandshakeStatus.NEED_TASK: {
                                initialHSStatus = doTasks();
                                if (packetBuffer.hasRemaining()) {
                                    session.write(packetBuffer);
                                }
                            }
                            break;

                            case HandshakeStatus.FINISHED: {
                                // Flush the final flight before reporting completion.
                                if (packetBuffer.hasRemaining()) {
                                    session.write(packetBuffer);
                                }
                                handshakeFinish();
                            }
                            break;

                            default: {
                                if (packetBuffer.hasRemaining()) {
                                    session.write(packetBuffer);
                                }
                            }
                        }
                    }
                    break wrap;

                    case SSLEngineResult.Status.BUFFER_OVERFLOW: {
                        // Grow the buffer, keeping what has been produced so far.
                        ByteBuffer b = newBuffer(packetBuffer.position() + sslEngine.getSession().getPacketBufferSize());
                        packetBuffer.flip();
                        b.put(packetBuffer);
                        packetBuffer = b;
                    }
                    break;

                    case SSLEngineResult.Status.CLOSED:
                        warningf("Connection %s handshake failure. SSLEngine will close inbound", session.getId());
                        packetBuffer.flip();
                        if (packetBuffer.hasRemaining()) {
                            session.write(packetBuffer);
                        }
                        closeOutbound();
                        break outer;

                    default: // BUFFER_UNDERFLOW
                        throw new SecureNetException(format("Connection %s handshake exception. status -> %s", 
                            session.getId(), result.getStatus()));
                }
            }
        }
    }

    /// Grows `receivedAppBuf` by one application-buffer size, keeping its content.
    protected void resizeAppBuffer() {
        int applicationBufferSize = sslEngine.getSession().getApplicationBufferSize();
        ByteBuffer b = newBuffer(receivedAppBuf.position() + applicationBufferSize);
        receivedAppBuf.flip();
        b.put(receivedAppBuf);
        receivedAppBuf = b;
    }

    /**
     * Appends the remaining bytes of `now` to `receivedPacketBuf`.
     *
     * Unconsumed buffered bytes are preserved. A fully consumed buffer is
     * reused when `now` fits in its capacity.
     */
    protected void merge(ByteBuffer now) {
        if (now is null || !now.hasRemaining()) {
            return;
        }

        if (receivedPacketBuf !is null) {
            if (receivedPacketBuf.hasRemaining()) {
                version(HUNT_NET_DEBUG_MORE) {
                    tracef("Connection %s read data, merge buffer -> buffered: %d, incoming: %d", session.getId(),
                            receivedPacketBuf.remaining(), now.remaining());
                }
                ByteBuffer ret = newBuffer(receivedPacketBuf.remaining() + now.remaining());
                ret.put(receivedPacketBuf).put(now).flip();
                receivedPacketBuf = ret;
            } else {
                version(HUNT_NET_DEBUG)  {
                    tracef("buffering data: %s, current buffer: %s", 
                        now.toString(), receivedPacketBuf.toString());
                }

                // The old buffer is fully consumed (remaining() == 0), so compare
                // against its capacity to decide whether it can be reused.
                if(now.remaining() <= receivedPacketBuf.capacity()) {
                    receivedPacketBuf.clear();
                    receivedPacketBuf.put(now).flip();
                } else {
                    ByteBuffer ret = newBuffer(now.remaining());
                    ret.put(now).flip();
                    receivedPacketBuf = ret;
                }
            }
        } else {
            version(HUNT_NET_DEBUG) tracef("buffering data: %d", now.remaining());
            ByteBuffer ret = newBuffer(now.remaining());
            ret.put(now).flip();
            receivedPacketBuf = ret;
        }
    }

    /**
     * Detaches the decrypted bytes accumulated in `receivedAppBuf`.
     *
     * Returns: a new buffer (in read mode) with the decrypted bytes, or null if
     *          none are available; `receivedAppBuf` is replaced by a fresh buffer
     *          when data is returned.
     */
    protected ByteBuffer getReceivedAppBuf() {
        receivedAppBuf.flip();
        version(HUNT_NET_DEBUG_MORE) {
            tracef("Connection %s read data, get app buf -> %s, %s", 
                session.getId(), receivedAppBuf.position(), receivedAppBuf.limit());
        }

        if (receivedAppBuf.hasRemaining()) {
            ByteBuffer buf = newBuffer(receivedAppBuf.remaining());
            buf.put(receivedAppBuf).flip();
            receivedAppBuf = newBuffer(sslEngine.getSession().getApplicationBufferSize());
            version(HUNT_NET_DEBUG_MORE) {
                tracef("SSL session %s unwrap, app buffer -> %s", session.getId(), buf.remaining());
            }
            return buf;
        } else {
            return null;
        }
    }

    /**
     * Runs the outstanding delegated handshake tasks.
     *
     * Not implemented yet: it reports via `implementationMissing` and returns
     * FINISHED. The intended implementation would run each task returned by
     * `sslEngine.getDelegatedTask()` and then return
     * `sslEngine.getHandshakeStatus()`.
     *
     * Returns: the handshake status after running the tasks
     */
    protected HandshakeStatus doTasks() {
        implementationMissing(false);
        return HandshakeStatus.FINISHED;
    }

    /// Closes the outbound side once; subsequent calls are no-ops.
    // override
    void close() {
        if (!closed) {
            closed = true;
            closeOutbound();
        }
    }

    /// Closes the engine's inbound side, always shutting down connection input.
    protected void closeInbound() {
        try {
            sslEngine.closeInbound();
        } catch (SSLException e) {
            warning("close inbound exception", e);
        } finally {
            session.shutdownInput();
        }
    }

    /// Closes the engine's outbound side and the underlying connection.
    protected void closeOutbound() {
        sslEngine.closeOutbound();
        session.close();
    }

    override
    string getApplicationProtocol() {
        string protocol = applicationProtocolSelector.getApplicationProtocol();
        version(HUNT_NET_DEBUG) tracef("selected protocol -> %s", protocol);
        return protocol;
    }

    override
    string[] getSupportedApplicationProtocols() {
        return applicationProtocolSelector.getSupportedApplicationProtocols();
    }

    override
    bool isOpen() {
        return !closed;
    }

    /**
     * Returns at most `netSize` bytes of `receivedPacketBuf` without moving
     * its position; the caller advances the position by the consumed count.
     */
    protected ByteBuffer splitBuffer(int netSize) {
        ByteBuffer buf = receivedPacketBuf.duplicate();
        if (buf.remaining() <= netSize) {
            return buf;
        } else {
            ByteBuffer splitBuf = newBuffer(netSize);
            byte[] data = new byte[netSize];
            buf.get(data);
            splitBuf.put(data).flip();
            return splitBuf;
        }
    }

    /**
     * Unwraps `input` into `receivedAppBuf`. When `input` is a separate view,
     * `receivedPacketBuf` is advanced by the number of bytes consumed.
     */
    protected SSLEngineResult unwrap(ByteBuffer input) {
        version(HUNT_NET_DEBUG_MORE) infof("receivedAppBuf=%s", receivedAppBuf.toString());
        SSLEngineResult result = sslEngine.unwrap(input, receivedAppBuf);
        if (input !is receivedPacketBuf) {
            int consumed = result.bytesConsumed();
            version(HUNT_NET_DEBUG_MORE) infof("receivedAppBuf=%s, consumed=%d", receivedAppBuf.toString(), consumed);
            receivedPacketBuf.position(receivedPacketBuf.position() + consumed);
        }
        return result;
    }

    protected SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) {
        return sslEngine.wrap(src, dst);
    }

    protected ByteBuffer newBuffer(int size) {
        return BufferUtils.allocate(size);
    }

    /**
     * Unwraps at most one packet-buffer worth of `receivedPacketBuf`.
     *
     * Throws: SecureNetException if the engine or its session is missing
     */
    protected SSLEngineResult unwrap() {
        if(sslEngine is null)
            throw new SecureNetException("The SSL Engine is invalid!");

        SSLSession sslSession = sslEngine.getSession();
        if(sslSession is null) {
            throw new SecureNetException("The SSL Session is invalid now.");
        }

        int packetBufferSize = sslSession.getPacketBufferSize();
        // Split the net buffer when its remaining bytes exceed the packet size.
        ByteBuffer buf = splitBuffer(packetBufferSize);
        version(HUNT_NET_DEBUG_MORE) {
            tracef("Connection %s read data, buf -> %d, packet -> %d, appBuf -> %d",
                    session.getId(), buf.remaining(), packetBufferSize, receivedAppBuf.remaining());
        }
        if (!receivedAppBuf.hasRemaining()) {
            resizeAppBuffer();
        }
        return unwrap(buf);
    }

    /**
     * Decrypts received data, first driving the handshake if it is incomplete.
     *
     * Params:
     *  receiveBuffer = encrypted bytes received from the peer
     * Returns: the decrypted bytes, or null if the handshake is still in
     *          progress or no complete record is available yet
     * Throws: SSLHandshakeException if the handshake fails; SecureNetException
     *         on an unexpected engine status
     */
    override
    ByteBuffer read(ByteBuffer receiveBuffer) {
        if (!doHandshake(receiveBuffer)) {
            return null;
        }

        if (!initialHSComplete)
            throw new IllegalStateException("The initial handshake is not complete.");

        version(HUNT_NET_DEBUG_MORE) {
            tracef("session %s read data status -> %s, initialHSComplete -> %s", session.getId(),
                    session.isConnected(), initialHSComplete);
        }

        merge(receiveBuffer);

        // receivedPacketBuf stays null if nothing has ever been buffered, e.g.
        // when the handshake completed without receiving any packet.
        if (receivedPacketBuf is null || !receivedPacketBuf.hasRemaining()) {
            return null;
        }

        needIO:
        while (true) {
            SSLEngineResult result = unwrap();

            version(HUNT_NET_DEBUG_MORE) {
                tracef("Connection %s read data result -> %s, receivedPacketBuf -> %s, receivedAppBuf -> %s",
                        session.getId(), result.toString().replace("\n", " "),
                        receivedPacketBuf.remaining(), receivedAppBuf.remaining());
            }

            switch (result.getStatus()) {
                case SSLEngineResult.Status.BUFFER_OVERFLOW: {
                    resizeAppBuffer();
                    // retry the operation.
                }
                break;
                case SSLEngineResult.Status.BUFFER_UNDERFLOW: {
                    int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
                    if (receivedPacketBuf.remaining() >= packetBufferSize) {
                        break; // retry the operation.
                    } else {
                        break needIO;
                    }
                }
                case SSLEngineResult.Status.OK: {
                    if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                        doTasks();
                    }
                    if (receivedPacketBuf.hasRemaining()) {
                        break; // retry the operation.
                    } else {
                        break needIO;
                    }
                }

                case SSLEngineResult.Status.CLOSED: {
                    infof("Connection %s read data failure. SSLEngine will close inbound", session.getId());
                    closeInbound();
                }
                break needIO;

                default:
                    throw new SecureNetException(format("Connection %s SSLEngine read data exception. status -> %s",
                            session.getId(), result.getStatus()));
            }
        }

        return getReceivedAppBuf();
    }

    /**
     * Encrypts and writes each buffer in order; `callback` is completed once
     * every buffer has been written.
     *
     * Returns: the total number of plaintext bytes consumed
     */
    override
    int write(ByteBuffer[] outputBuffers, Callback callback) {
        if (outputBuffers.length == 0) {
            // With no buffers no per-buffer write would ever complete the
            // CountingCallback, so the caller's callback would never fire.
            callback.succeeded();
            return 0;
        }

        int ret = 0;
        CountingCallback countingCallback = new CountingCallback(callback, cast(int)outputBuffers.length);
        foreach (ByteBuffer outputBuffer ; outputBuffers) {
            ret += write(outputBuffer, countingCallback);
        }
        return ret;
    }

    /**
     * Encrypts `outAppBuf` and flushes the resulting records to the connection.
     *
     * Params:
     *  outAppBuf = plaintext message
     *  callback = completed after the records are handed to the connection,
     *             or failed on error
     * Returns: the number of plaintext bytes consumed
     * Throws: IllegalStateException if the handshake is incomplete;
     *         SecureNetException on an unexpected engine status
     */
    override
    int write(ByteBuffer outAppBuf, Callback callback) {
        if (!initialHSComplete) {
            IllegalStateException ex = new IllegalStateException("The initial handshake is not complete.");
            callback.failed(ex);
            throw ex;
        }

        int ret = 0;
        if (!outAppBuf.hasRemaining()) {
            callback.succeeded();
            return ret;
        }

        int remain = outAppBuf.remaining();
        int packetBufferSize = sslEngine.getSession().getPacketBufferSize();
        List!ByteBuffer pocketBuffers = new ArrayList!ByteBuffer();
        bool closeOutput = false;

        outer:
        while (ret < remain) {
            ByteBuffer packetBuffer = newBuffer(packetBufferSize);

            wrap:
            while (true) {
                SSLEngineResult result = wrap(outAppBuf, packetBuffer);
                ret += result.bytesConsumed();

                switch (result.getStatus()) {
                    case SSLEngineResult.Status.OK: {
                        if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                            doTasks();
                        }

                        packetBuffer.flip();
                        if (packetBuffer.hasRemaining()) {
                            pocketBuffers.add(packetBuffer);
                        }
                    }
                    break wrap;

                    case SSLEngineResult.Status.BUFFER_OVERFLOW: {
                        packetBufferSize = sslEngine.getSession().getPacketBufferSize();
                        ByteBuffer b = newBuffer(packetBuffer.position() + packetBufferSize);
                        packetBuffer.flip();
                        b.put(packetBuffer);
                        packetBuffer = b;
                    }
                    break; // retry the operation.

                    case SSLEngineResult.Status.CLOSED: {
                        infof("Connection %s SSLEngine will close", session.getId());
                        packetBuffer.flip();
                        if (packetBuffer.hasRemaining()) {
                            pocketBuffers.add(packetBuffer);
                        }
                        closeOutput = true;
                    }
                    break outer;

                    default: {
                        SecureNetException ex = new SecureNetException(format("Connection %s SSLEngine writes data exception. status -> %s", session.getId(), result.getStatus()));
                        callback.failed(ex);
                        throw ex;
                    }
                }
            }
        }

        foreach(ByteBuffer pocket; pocketBuffers) {
            session.write(pocket);
        }

        callback.succeeded();
        if (closeOutput) {
            closeOutbound();
        }
        return ret;
    }

    override
    bool isHandshakeFinished() {
        return initialHSComplete;
    }

    override
    bool isClientMode() {
        return sslEngine.getUseClientMode();
    }
}