View Javadoc

1   /*
2    *   @(#) $Id: SSLFilter.java 170524 2005-05-17 07:04:49Z trustin $
3    *
4    *   Copyright 2004 The Apache Software Foundation
5    *
6    *   Licensed under the Apache License, Version 2.0 (the "License");
7    *   you may not use this file except in compliance with the License.
8    *   You may obtain a copy of the License at
9    *
10   *       http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *   Unless required by applicable law or agreed to in writing, software
13   *   distributed under the License is distributed on an "AS IS" BASIS,
14   *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   *   See the License for the specific language governing permissions and
16   *   limitations under the License.
17   *
18   */
19  package org.apache.mina.io.filter;
20  
21  import java.util.IdentityHashMap;
22  import java.util.Map;
23  import java.util.logging.Level;
24  import java.util.logging.Logger;
25  
26  import javax.net.ssl.SSLContext;
27  import javax.net.ssl.SSLEngine;
28  import javax.net.ssl.SSLException;
29  import javax.net.ssl.SSLHandshakeException;
30  
31  import org.apache.mina.common.ByteBuffer;
32  import org.apache.mina.io.IoFilterAdapter;
33  import org.apache.mina.io.IoHandler;
34  import org.apache.mina.io.IoSession;
35  
36  /***
37   * An SSL filter that encrypts and decrypts the data exchanged in the session.
38   * This filter uses an {@link SSLEngine} which was introduced in Java 5, so 
39   * Java version 5 or above is mandatory to use this filter. And please note that
40   * this filter only works for TCP/IP connections.
41   * <p>
42   * This filter logs debug information in {@link Level#FINEST} using {@link Logger}.
43   * 
44   * @author The Apache Directory Project (dev@directory.apache.org)
45   * @version $Rev: 170524 $, $Date: 2005-05-17 16:04:49 +0900 (?, 17  5? 2005) $
46   */
47  public class SSLFilter extends IoFilterAdapter
48  {
49      private static final Logger log = Logger.getLogger( SSLFilter.class.getName() );
50  
51      /***
52       * A marker which is passed with {@link IoHandler#dataWritten(IoSession, Object)}
53       * when <tt>SSLFilter</tt> writes data other then user actually requested.
54       */
55      private static final Object SSL_MARKER = new Object()
56      {
57          public String toString()
58          {
59              return "SSL_MARKER";
60          }
61      };
62      
63      // SSL Context
64      private SSLContext sslContext;
65  
66      // Map used to map SSLHandler objects per session (key is IoSession)
67      private Map sslSessionHandlerMap = new IdentityHashMap();
68  
69      private boolean client;
70      private boolean needClientAuth;
71      private boolean wantClientAuth;
72      private String[] enabledCipherSuites;
73      private String[] enabledProtocols;
74  
75      /***
76       * Creates a new SSL filter using the specified {@link SSLContext}.
77       */
78      public SSLFilter( SSLContext sslContext )
79      {
80          if( sslContext == null )
81          {
82              throw new NullPointerException( "sslContext" );
83          }
84  
85          this.sslContext = sslContext;
86      }
87  
88      /***
89       * Returns <tt>true</tt> if the engine is set to use client mode
90       * when handshaking.
91       */
92      public boolean isUseClientMode()
93      {
94          return client;
95      }
96      
97      /***
98       * Configures the engine to use client (or server) mode when handshaking.
99       */
100     public void setUseClientMode( boolean clientMode )
101     {
102         this.client = clientMode;
103     }
104     
105     /***
106      * Returns <tt>true</tt> if the engine will <em>require</em> client authentication.
107      * This option is only useful to engines in the server mode.
108      */
109     public boolean isNeedClientAuth()
110     {
111         return needClientAuth;
112     }
113 
114     /***
115      * Configures the engine to <em>require</em> client authentication.
116      * This option is only useful for engines in the server mode.
117      */
118     public void setNeedClientAuth( boolean needClientAuth )
119     {
120         this.needClientAuth = needClientAuth;
121     }
122     
123     
124     /***
125      * Returns <tt>true</tt> if the engine will <em>request</em> client authentication.
126      * This option is only useful to engines in the server mode.
127      */
128     public boolean isWantClientAuth()
129     {
130         return wantClientAuth;
131     }
132     
133     /***
134      * Configures the engine to <em>request</em> client authentication.
135      * This option is only useful for engines in the server mode.
136      */
137     public void setWantClientAuth( boolean wantClientAuth )
138     {
139         this.wantClientAuth = wantClientAuth;
140     }
141     
142     /***
143      * Returns the list of cipher suites to be enabled when {@link SSLEngine}
144      * is initialized.
145      * 
146      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
147      */
148     public String[] getEnabledCipherSuites()
149     {
150         return enabledCipherSuites;
151     }
152     
153     /***
154      * Sets the list of cipher suites to be enabled when {@link SSLEngine}
155      * is initialized.
156      * 
157      * @param cipherSuites <tt>null</tt> means 'use {@link SSLEngine}'s default.'
158      */
159     public void setEnabledCipherSuites( String[] cipherSuites )
160     {
161         this.enabledCipherSuites = cipherSuites;
162     }
163 
164     /***
165      * Returns the list of protocols to be enabled when {@link SSLEngine}
166      * is initialized.
167      * 
168      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
169      */
170     public String[] getEnabledProtocols()
171     {
172         return enabledProtocols;
173     }
174     
175     /***
176      * Sets the list of protocols to be enabled when {@link SSLEngine}
177      * is initialized.
178      * 
179      * @param protocols <tt>null</tt> means 'use {@link SSLEngine}'s default.'
180      */
181     public void setEnabledProtocols( String[] protocols )
182     {
183         this.enabledProtocols = protocols;
184     }
185 
186     // IoFilter impl.
187 
188     public void sessionOpened( NextFilter nextFilter, IoSession session ) throws SSLException
189     {
190         // Create an SSL handler
191         createSSLSessionHandler( nextFilter, session );
192         nextFilter.sessionOpened( session );
193     }
194 
195     public void sessionClosed( NextFilter nextFilter, IoSession session )
196     {
197         SSLHandler sslHandler = getSSLSessionHandler( session );
198         if( log.isLoggable( Level.FINEST ) )
199         {
200             log.log( Level.FINEST, session + " Closed: " + sslHandler );
201         }
202         if( sslHandler != null )
203         {
204             synchronized( sslHandler )
205             {
206                // Start SSL shutdown process
207                try
208                {
209                   // shut down
210                   sslHandler.shutdown();
211                   
212                   // there might be data to write out here?
213                   writeNetBuffer( nextFilter, session, sslHandler );
214                }
215                catch( SSLException ssle )
216                {
217                   nextFilter.exceptionCaught( session, ssle );
218                }
219                finally
220                {
221                   // notify closed session
222                   nextFilter.sessionClosed( session );
223                   
224                   // release buffers
225                   sslHandler.release();
226                   removeSSLSessionHandler( session );
227                }
228             }
229         }
230     }
231    
232     public void dataRead( NextFilter nextFilter, IoSession session,
233                           ByteBuffer buf ) throws SSLException
234     {
235         SSLHandler sslHandler = createSSLSessionHandler( nextFilter, session );
236         if( sslHandler != null )
237         {
238             if( log.isLoggable( Level.FINEST ) )
239             {
240                 log.log( Level.FINEST, session + " Data Read: " + sslHandler + " (" + buf+ ')' );
241             }
242             synchronized( sslHandler )
243             {
244                 try
245                 {
246                     // forward read encrypted data to SSL handler
247                     sslHandler.dataRead( nextFilter, buf.buf() );
248 
249                     // Handle data to be forwarded to application or written to net
250                     handleSSLData( nextFilter, session, sslHandler );
251 
252                     if( sslHandler.isClosed() )
253                     {
254                         if( log.isLoggable( Level.FINEST ) )
255                         {
256                             log.log( Level.FINEST,
257                                      session + " SSL Session closed. Closing connection.." );
258                         }
259                         session.close();
260                     }
261                 }
262                 catch( SSLException ssle )
263                 {
264                     if( !sslHandler.isInitialHandshakeComplete() )
265                     {
266                         SSLException newSSLE = new SSLHandshakeException(
267                                 "Initial SSL handshake failed." );
268                         newSSLE.initCause( ssle );
269                         ssle = newSSLE;
270                     }
271 
272                     nextFilter.exceptionCaught( session, ssle );
273                 }
274             }
275         }
276         else
277         {
278             nextFilter.dataRead( session, buf );
279         }
280     }
281 
282     public void dataWritten( NextFilter nextFilter, IoSession session,
283                             Object marker )
284     {
285         if( marker != SSL_MARKER )
286         {
287             nextFilter.dataWritten( session, marker );
288         }
289     }
290 
291     public void filterWrite( NextFilter nextFilter, IoSession session, ByteBuffer buf, Object marker ) throws SSLException
292     {
293 
294         SSLHandler handler = createSSLSessionHandler( nextFilter, session );
295         if( log.isLoggable( Level.FINEST ) )
296         {
297             log.log( Level.FINEST, session + " Filtered Write: " + handler );
298         }
299 
300         synchronized( handler )
301         {
302             if( handler.isWritingEncryptedData() )
303             {
304                 // data already encrypted; simply return buffer
305                 if( log.isLoggable( Level.FINEST ) )
306                 {
307                     log.log( Level.FINEST, session + "   already encrypted: " + buf );
308                 }
309                 nextFilter.filterWrite( session, buf, marker );
310                 return;
311             }
312             
313             if( handler.isInitialHandshakeComplete() )
314             {
315                 // SSL encrypt
316                 if( log.isLoggable( Level.FINEST ) )
317                 {
318                     log.log( Level.FINEST, session + " encrypt: " + buf );
319                 }
320                 handler.encrypt( buf.buf() );
321                 ByteBuffer encryptedBuffer = copy( handler
322                         .getOutNetBuffer() );
323 
324                 if( log.isLoggable( Level.FINEST ) )
325                 {
326                     log.log( Level.FINEST, session + " encrypted buf: " + encryptedBuffer);
327                 }
328                 buf.release();
329                 nextFilter.filterWrite( session, encryptedBuffer, marker );
330                 return;
331             }
332             else
333             {
334                 if( !session.isConnected() )
335                 {
336                     if( log.isLoggable( Level.FINEST ) )
337                     {
338                         log.log( Level.FINEST, session + " Write request on closed session." );
339                     }
340                 }
341                 else
342                 {
343                     if( log.isLoggable( Level.FINEST ) )
344                     {
345                         log.log( Level.FINEST, session + " Handshaking is not complete yet. Buffering write request." );
346                     }
347                     handler.scheduleWrite( nextFilter, buf, marker );
348                 }
349             }
350         }
351     }
352 
353     // Utiliities
354 
355     private void handleSSLData( NextFilter nextFilter, IoSession session,
356                                SSLHandler handler ) throws SSLException
357     {
358         // Flush any buffered write requests occurred before handshaking.
359         if( handler.isInitialHandshakeComplete() )
360         {
361             handler.flushScheduledWrites();
362         }
363 
364         // Write encrypted data to be written (if any)
365         writeNetBuffer( nextFilter, session, handler );
366 
367         // handle app. data read (if any)
368         handleAppDataRead( nextFilter, session, handler );
369     }
370 
371     private void handleAppDataRead( NextFilter nextFilter, IoSession session,
372                                    SSLHandler sslHandler )
373     {
374         if( log.isLoggable( Level.FINEST ) )
375         {
376             log.log( Level.FINEST, session + " appBuffer: " + sslHandler.getAppBuffer() );
377         }
378         if( sslHandler.getAppBuffer().hasRemaining() )
379         {
380             // forward read app data
381             ByteBuffer readBuffer = copy( sslHandler.getAppBuffer() );
382             if( log.isLoggable( Level.FINEST ) )
383             {
384                 log.log( Level.FINEST, session + " app data read: " + readBuffer + " (" + readBuffer.getHexDump() + ')' );
385             }
386             nextFilter.dataRead( session, readBuffer );
387         }
388     }
389 
390     void writeNetBuffer( NextFilter nextFilter, IoSession session, SSLHandler sslHandler )
391             throws SSLException
392     {
393         // Check if any net data needed to be writen
394         if( !sslHandler.getOutNetBuffer().hasRemaining() )
395         {
396             // no; bail out
397             return;
398         }
399 
400         // write net data
401 
402         // set flag that we are writing encrypted data
403         // (used in filterWrite() above)
404         synchronized( sslHandler )
405         {
406             sslHandler.setWritingEncryptedData( true );
407         }
408 
409         try
410         {
411             if( log.isLoggable( Level.FINEST ) )
412             {
413                 log.log( Level.FINEST, session + " write outNetBuffer: " +
414                                    sslHandler.getOutNetBuffer() );
415             }
416             ByteBuffer writeBuffer = copy( sslHandler.getOutNetBuffer() );
417             if( log.isLoggable( Level.FINEST ) )
418             {
419                 log.log( Level.FINEST, session + " session write: " + writeBuffer );
420             }
421             //debug("outNetBuffer (after copy): {0}", sslHandler.getOutNetBuffer());
422             filterWrite( nextFilter, session, writeBuffer, SSL_MARKER );
423 
424             // loop while more writes required to complete handshake
425             while( sslHandler.needToCompleteInitialHandshake() )
426             {
427                 try
428                 {
429                     sslHandler.continueHandshake( nextFilter );
430                 }
431                 catch( SSLException ssle )
432                 {
433                     SSLException newSSLE = new SSLHandshakeException(
434                             "Initial SSL handshake failed." );
435                     newSSLE.initCause( ssle );
436                     throw newSSLE;
437                 }
438                 if( sslHandler.getOutNetBuffer().hasRemaining() )
439                 {
440                     if( log.isLoggable( Level.FINEST ) )
441                     {
442                         log.log( Level.FINEST, session + " write outNetBuffer2: " +
443                                            sslHandler.getOutNetBuffer() );
444                     }
445                     ByteBuffer writeBuffer2 = copy( sslHandler
446                             .getOutNetBuffer() );
447                     filterWrite( nextFilter, session, writeBuffer2, SSL_MARKER );
448                 }
449             }
450         }
451         finally
452         {
453             synchronized( sslHandler )
454             {
455                 sslHandler.setWritingEncryptedData( false );
456             }
457         }
458     }
459 
460     /***
461      * Creates a new Mina byte buffer that is a deep copy of the remaining bytes
462      * in the given buffer (between index buf.position() and buf.limit())
463      *
464      * @param src the buffer to copy
465      * @return the new buffer, ready to read from
466      */
467     private static ByteBuffer copy( java.nio.ByteBuffer src )
468     {
469         ByteBuffer copy = ByteBuffer.allocate( src.remaining() );
470         copy.put( src );
471         copy.flip();
472         return copy;
473     }
474 
475     // Utilities to mainpulate SSLHandler based on IoSession
476 
477     private SSLHandler createSSLSessionHandler( NextFilter nextFilter, IoSession session ) throws SSLException
478     {
479         SSLHandler handler = ( SSLHandler ) sslSessionHandlerMap.get( session );
480         if( handler == null )
481         {
482             synchronized( sslSessionHandlerMap )
483             {
484                 handler = ( SSLHandler ) sslSessionHandlerMap.get( session );
485                 if( handler == null )
486                 {
487                     boolean done = false;
488                     try
489                     {
490                         handler =
491                             new SSLHandler( this, sslContext, session );
492                         sslSessionHandlerMap.put( session, handler );
493                         handler.doHandshake( nextFilter );
494                         done = true;
495                     }
496                     finally 
497                     {
498                         if( !done )
499                         {
500                             sslSessionHandlerMap.remove( session );
501                         }
502                     }
503                 }
504             }
505         }
506         
507         return handler;
508     }
509 
510     private SSLHandler getSSLSessionHandler( IoSession session )
511     {
512         return ( SSLHandler ) sslSessionHandlerMap.get( session );
513     }
514 
515     private void removeSSLSessionHandler( IoSession session )
516     {
517         synchronized( sslSessionHandlerMap )
518         {
519             sslSessionHandlerMap.remove( session );
520         }
521     }
522 }