1 module hunt.net.secure.conscrypt.ClientSessionContext;
2 
3 // dfmt off
4 version(WITH_HUNT_SECURITY):
5 // dfmt on
6 
7 import hunt.net.secure.conscrypt.AbstractSessionContext;
8 import hunt.net.secure.conscrypt.NativeSslSession;
9 import hunt.net.secure.conscrypt.SSLClientSessionCache;
10 import hunt.net.secure.conscrypt.SSLServerSessionCache;
11 import hunt.net.secure.conscrypt.SSLParametersImpl;
12 
13 import hunt.net.ssl.SSLSession;
14 import hunt.net.ssl.SSLSessionContext;
15 
16 import hunt.collection;
17 import hunt.Exceptions;
18 
19 
/**
 * Caches client sessions, indexed by host and port. Callers are typically
 * looking to reuse any session previously stored for a given host and port.
 *
 * Lookups consult an in-memory map first and, when one has been installed via
 * setPersistentCache(), fall back to an SSLClientSessionCache. Sessions that
 * are added are also written to that persistent cache.
 */
final class ClientSessionContext : AbstractSessionContext {
    /**
     * Sessions indexed by host and port. Protect from concurrent
     * access by holding a lock on sessionsByHostAndPort.
     */
    private Map!(HostAndPort, NativeSslSession) sessionsByHostAndPort;

    /** Optional secondary cache; null until setPersistentCache() is called. */
    private SSLClientSessionCache persistentCache;

    /**
     * Creates an empty context.
     */
    this() {
        // NOTE(review): 10 is forwarded to AbstractSessionContext unchanged;
        // presumably the maximum number of cached sessions -- confirm there.
        super(10);
        sessionsByHostAndPort = new HashMap!(HostAndPort, NativeSslSession)();
    }

    /**
     * Applications should not use this method. Instead use {@link
     * Conscrypt#setClientSessionCache(SSLContext, SSLClientSessionCache)}.
     *
     * @param persistentCache cache consulted on lookup misses and written to
     *        when sessions are added; may be null to disable both
     */
    void setPersistentCache(SSLClientSessionCache persistentCache) {
        this.persistentCache = persistentCache;
    }

    /**
     * Gets the suitable session reference from the session cache container.
     *
     * @param hostName peer host; a null value yields null
     * @param port peer port
     * @param sslParameters currently unused, because the protocol and cipher
     *        suite filtering below is still commented out
     * @return the cached session, or null if hostName is null or no valid
     *         session is cached for the host and port
     */
    NativeSslSession getCachedSession(string hostName, int port, SSLParametersImpl sslParameters) {
        if (hostName is null) {
            return null;
        }

        NativeSslSession session = getSession(hostName, port);
        if (session is null) {
            return null;
        }

        // NOTE(review): implementationMissing() is defined outside this file
        // (presumably hunt.Exceptions). If it throws, a cache hit never reaches
        // the return below; if it does not, the session is returned without the
        // protocol / cipher-suite checks that are commented out -- confirm.
        implementationMissing();

        // string protocol = session.getProtocol();
        // bool protocolFound = false;
        // foreach (string enabledProtocol ; sslParameters.enabledProtocols) {
        //     if (protocol.equals(enabledProtocol)) {
        //         protocolFound = true;
        //         break;
        //     }
        // }
        // if (!protocolFound) {
        //     return null;
        // }

        // string cipherSuite = session.getCipherSuite();
        // bool cipherSuiteFound = false;
        // foreach (string enabledCipherSuite ; sslParameters.enabledCipherSuites) {
        //     if (cipherSuite.equals(enabledCipherSuite)) {
        //         cipherSuiteFound = true;
        //         break;
        //     }
        // }
        // if (!cipherSuiteFound) {
        //     return null;
        // }

        return session;
    }

    /**
     * Returns the number of entries in the in-memory host/port index.
     *
     * The map is read under the same lock as every other access to it, as
     * required by the field's documentation.
     */
    int size() {
        synchronized (sessionsByHostAndPort) {
            return sessionsByHostAndPort.size();
        }
    }

    /**
     * Finds a cached session for the given host name and port.
     *
     * @param host of server
     * @param port of server
     * @return cached session or null if none found
     */
    private NativeSslSession getSession(string host, int port) {
        if (host is null) {
            return null;
        }

        HostAndPort key = new HostAndPort(host, port);
        NativeSslSession session;
        synchronized (sessionsByHostAndPort) {
            session = sessionsByHostAndPort.get(key);
        }
        if (session !is null && session.isValid()) {
            return session;
        }

        // Look in persistent cache. The lock is not held while calling out to
        // it; it is re-acquired only to publish a restored session.
        if (persistentCache !is null) {
            byte[] data = persistentCache.getSessionData(host, port);
            if (data !is null) {
                session = NativeSslSession.newInstance(this, data, host, port);
                if (session !is null && session.isValid()) {
                    synchronized (sessionsByHostAndPort) {
                        sessionsByHostAndPort.put(key, session);
                    }
                    return session;
                }
            }
        }

        return null;
    }

    /**
     * Indexes the session by its peer host and port and, when a persistent
     * cache is installed, stores the serialized session there as well.
     * Sessions whose peer host is null are ignored.
     *
     * @param session the session about to be added
     */
    override
    void onBeforeAddSession(NativeSslSession session) {
        string host = session.getPeerHost();
        int port = session.getPeerPort();
        if (host is null) {
            return;
        }

        HostAndPort key = new HostAndPort(host, port);
        synchronized (sessionsByHostAndPort) {
            sessionsByHostAndPort.put(key, session);
        }

        // TODO: Do this in a background thread.
        if (persistentCache !is null) {
            byte[] data = session.toBytes();
            if (data !is null) {
                persistentCache.putSessionData(session.toSSLSession(), data);
            }
        }
    }

    /**
     * Drops the host/port index entry for the session's peer. Sessions whose
     * peer host is null are ignored. The persistent cache is not touched.
     *
     * @param session the session about to be removed
     */
    override
    void onBeforeRemoveSession(NativeSslSession session) {
        string host = session.getPeerHost();
        if (host is null) {
            return;
        }
        int port = session.getPeerPort();
        HostAndPort hostAndPortKey = new HostAndPort(host, port);
        synchronized (sessionsByHostAndPort) {
            sessionsByHostAndPort.remove(hostAndPortKey);
        }
    }

    /**
     * Lookup by session id is not implemented for clients.
     *
     * @param sessionId ignored
     * @return always null
     */
    override
    NativeSslSession getSessionFromPersistentCache(byte[] sessionId) {
        // Not implemented for clients.
        return null;
    }

    /**
     * Map key combining a host name and a port. Equality and hashing are
     * both derived from the two fields.
     */
    private static final class HostAndPort {
        string host;
        int port;

        this(string host, int port) {
            this.host = host;
            this.port = port;
        }

        /** Combines the hash of the host string with the port. */
        override size_t toHash() @trusted nothrow {
            return hashOf(host) * 31 + port;
        }

        /**
         * Returns true only when o is a HostAndPort with the same host and port.
         */
        override
        bool opEquals(Object o) {
            // HostAndPort is final, so a successful downcast means o has exactly
            // this type. The cast yields null for a null argument, whereas
            // typeid(o) would dereference it.
            HostAndPort other = cast(HostAndPort) o;
            if (other is null) {
                return false;
            }
            return host == other.host && port == other.port;
        }
    }
}
195