View Javadoc

1   /*
2    *   @(#) $Id: SSLFilter.java 264677 2005-08-30 02:44:35Z trustin $
3    *
4    *   Copyright 2004 The Apache Software Foundation
5    *
6    *   Licensed under the Apache License, Version 2.0 (the "License");
7    *   you may not use this file except in compliance with the License.
8    *   You may obtain a copy of the License at
9    *
10   *       http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *   Unless required by applicable law or agreed to in writing, software
13   *   distributed under the License is distributed on an "AS IS" BASIS,
14   *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   *   See the License for the specific language governing permissions and
16   *   limitations under the License.
17   *
18   */
19  package org.apache.mina.io.filter;
20  
21  import java.util.logging.Level;
22  import java.util.logging.Logger;
23  
24  import javax.net.ssl.SSLContext;
25  import javax.net.ssl.SSLEngine;
26  import javax.net.ssl.SSLException;
27  import javax.net.ssl.SSLHandshakeException;
28  import javax.net.ssl.SSLSession;
29  
30  import org.apache.mina.common.ByteBuffer;
31  import org.apache.mina.io.IoFilterAdapter;
32  import org.apache.mina.io.IoHandler;
33  import org.apache.mina.io.IoSession;
34  
35  /***
36   * An SSL filter that encrypts and decrypts the data exchanged in the session.
37   * This filter uses an {@link SSLEngine} which was introduced in Java 5, so 
38   * Java version 5 or above is mandatory to use this filter. And please note that
39   * this filter only works for TCP/IP connections.
40   * <p>
41   * This filter logs debug information in {@link Level#FINEST} using {@link Logger}.
42   * 
43   * @author The Apache Directory Project (dev@directory.apache.org)
44   * @version $Rev: 264677 $, $Date: 2005-08-30 11:44:35 +0900 $
45   */
46  public class SSLFilter extends IoFilterAdapter
47  {
48      /***
49       * Session attribute key that stores underlying {@link SSLSession}
50       * for each session.
51       */
52      public static final String SSL_SESSION = SSLFilter.class.getName() + ".SSLSession";
53      
54      private static final String SSL_HANDLER = SSLFilter.class.getName() + ".SSLHandler";
55  
56      private static final Logger log = Logger.getLogger( SSLFilter.class.getName() );
57  
58      /***
59       * A marker which is passed with {@link IoHandler#dataWritten(IoSession, Object)}
60       * when <tt>SSLFilter</tt> writes data other then user actually requested.
61       */
62      private static final Object SSL_MARKER = new Object()
63      {
64          public String toString()
65          {
66              return "SSL_MARKER";
67          }
68      };
69      
70      // SSL Context
71      private SSLContext sslContext;
72  
73      private boolean client;
74      private boolean needClientAuth;
75      private boolean wantClientAuth;
76      private String[] enabledCipherSuites;
77      private String[] enabledProtocols;
78  
79      /***
80       * Creates a new SSL filter using the specified {@link SSLContext}.
81       */
82      public SSLFilter( SSLContext sslContext )
83      {
84          if( sslContext == null )
85          {
86              throw new NullPointerException( "sslContext" );
87          }
88  
89          this.sslContext = sslContext;
90      }
91      
92      /***
93       * Returns the underlying {@link SSLSession} for the specified session.
94       * 
95       * @return <tt>null</tt> if no {@link SSLSession} is initialized yet.
96       */
97      public SSLSession getSSLSession( IoSession session )
98      {
99          return ( SSLSession ) session.getAttribute( SSL_SESSION );
100     }
101 
102     /***
103      * Returns <tt>true</tt> if the engine is set to use client mode
104      * when handshaking.
105      */
106     public boolean isUseClientMode()
107     {
108         return client;
109     }
110     
111     /***
112      * Configures the engine to use client (or server) mode when handshaking.
113      */
114     public void setUseClientMode( boolean clientMode )
115     {
116         this.client = clientMode;
117     }
118     
119     /***
120      * Returns <tt>true</tt> if the engine will <em>require</em> client authentication.
121      * This option is only useful to engines in the server mode.
122      */
123     public boolean isNeedClientAuth()
124     {
125         return needClientAuth;
126     }
127 
128     /***
129      * Configures the engine to <em>require</em> client authentication.
130      * This option is only useful for engines in the server mode.
131      */
132     public void setNeedClientAuth( boolean needClientAuth )
133     {
134         this.needClientAuth = needClientAuth;
135     }
136     
137     
138     /***
139      * Returns <tt>true</tt> if the engine will <em>request</em> client authentication.
140      * This option is only useful to engines in the server mode.
141      */
142     public boolean isWantClientAuth()
143     {
144         return wantClientAuth;
145     }
146     
147     /***
148      * Configures the engine to <em>request</em> client authentication.
149      * This option is only useful for engines in the server mode.
150      */
151     public void setWantClientAuth( boolean wantClientAuth )
152     {
153         this.wantClientAuth = wantClientAuth;
154     }
155     
156     /***
157      * Returns the list of cipher suites to be enabled when {@link SSLEngine}
158      * is initialized.
159      * 
160      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
161      */
162     public String[] getEnabledCipherSuites()
163     {
164         return enabledCipherSuites;
165     }
166     
167     /***
168      * Sets the list of cipher suites to be enabled when {@link SSLEngine}
169      * is initialized.
170      * 
171      * @param cipherSuites <tt>null</tt> means 'use {@link SSLEngine}'s default.'
172      */
173     public void setEnabledCipherSuites( String[] cipherSuites )
174     {
175         this.enabledCipherSuites = cipherSuites;
176     }
177 
178     /***
179      * Returns the list of protocols to be enabled when {@link SSLEngine}
180      * is initialized.
181      * 
182      * @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
183      */
184     public String[] getEnabledProtocols()
185     {
186         return enabledProtocols;
187     }
188     
189     /***
190      * Sets the list of protocols to be enabled when {@link SSLEngine}
191      * is initialized.
192      * 
193      * @param protocols <tt>null</tt> means 'use {@link SSLEngine}'s default.'
194      */
195     public void setEnabledProtocols( String[] protocols )
196     {
197         this.enabledProtocols = protocols;
198     }
199 
200     // IoFilter impl.
201 
202     public void sessionOpened( NextFilter nextFilter, IoSession session ) throws SSLException
203     {
204         // Create an SSL handler
205         createSSLSessionHandler( nextFilter, session );
206         nextFilter.sessionOpened( session );
207     }
208 
209     public void sessionClosed( NextFilter nextFilter, IoSession session ) throws SSLException
210     {
211         SSLHandler sslHandler = getSSLSessionHandler( session );
212         if( log.isLoggable( Level.FINEST ) )
213         {
214             log.log( Level.FINEST, session + " Closed: " + sslHandler );
215         }
216         if( sslHandler != null )
217         {
218             synchronized( sslHandler )
219             {
220                // Start SSL shutdown process
221                try
222                {
223                   // shut down
224                   sslHandler.shutdown();
225                   
226                   // there might be data to write out here?
227                   writeNetBuffer( nextFilter, session, sslHandler );
228                }
229                finally
230                {
231                   // notify closed session
232                   nextFilter.sessionClosed( session );
233                   
234                   // release buffers
235                   sslHandler.release();
236                   removeSSLSessionHandler( session );
237                }
238             }
239         }
240     }
241    
242     public void dataRead( NextFilter nextFilter, IoSession session,
243                           ByteBuffer buf ) throws SSLException
244     {
245         SSLHandler sslHandler = createSSLSessionHandler( nextFilter, session );
246         if( sslHandler != null )
247         {
248             if( log.isLoggable( Level.FINEST ) )
249             {
250                 log.log( Level.FINEST, session + " Data Read: " + sslHandler + " (" + buf+ ')' );
251             }
252             synchronized( sslHandler )
253             {
254                 try
255                 {
256                     // forward read encrypted data to SSL handler
257                     sslHandler.dataRead( nextFilter, buf.buf() );
258 
259                     // Handle data to be forwarded to application or written to net
260                     handleSSLData( nextFilter, session, sslHandler );
261 
262                     if( sslHandler.isClosed() )
263                     {
264                         if( log.isLoggable( Level.FINEST ) )
265                         {
266                             log.log( Level.FINEST,
267                                      session + " SSL Session closed. Closing connection.." );
268                         }
269                         session.close();
270                     }
271                 }
272                 catch( SSLException ssle )
273                 {
274                     if( !sslHandler.isInitialHandshakeComplete() )
275                     {
276                         SSLException newSSLE = new SSLHandshakeException(
277                                 "Initial SSL handshake failed." );
278                         newSSLE.initCause( ssle );
279                         ssle = newSSLE;
280                     }
281 
282                     throw ssle;
283                 }
284             }
285         }
286         else
287         {
288             nextFilter.dataRead( session, buf );
289         }
290     }
291 
292     public void dataWritten( NextFilter nextFilter, IoSession session,
293                             Object marker )
294     {
295         if( marker != SSL_MARKER )
296         {
297             nextFilter.dataWritten( session, marker );
298         }
299     }
300 
301     public void filterWrite( NextFilter nextFilter, IoSession session, ByteBuffer buf, Object marker ) throws SSLException
302     {
303 
304         SSLHandler handler = createSSLSessionHandler( nextFilter, session );
305         if( log.isLoggable( Level.FINEST ) )
306         {
307             log.log( Level.FINEST, session + " Filtered Write: " + handler );
308         }
309 
310         synchronized( handler )
311         {
312             if( handler.isWritingEncryptedData() )
313             {
314                 // data already encrypted; simply return buffer
315                 if( log.isLoggable( Level.FINEST ) )
316                 {
317                     log.log( Level.FINEST, session + "   already encrypted: " + buf );
318                 }
319                 nextFilter.filterWrite( session, buf, marker );
320                 return;
321             }
322             
323             if( handler.isInitialHandshakeComplete() )
324             {
325                 // SSL encrypt
326                 if( log.isLoggable( Level.FINEST ) )
327                 {
328                     log.log( Level.FINEST, session + " encrypt: " + buf );
329                 }
330                 handler.encrypt( buf.buf() );
331                 ByteBuffer encryptedBuffer = copy( handler
332                         .getOutNetBuffer() );
333 
334                 if( log.isLoggable( Level.FINEST ) )
335                 {
336                     log.log( Level.FINEST, session + " encrypted buf: " + encryptedBuffer);
337                 }
338                 buf.release();
339                 nextFilter.filterWrite( session, encryptedBuffer, marker );
340                 return;
341             }
342             else
343             {
344                 if( !session.isConnected() )
345                 {
346                     if( log.isLoggable( Level.FINEST ) )
347                     {
348                         log.log( Level.FINEST, session + " Write request on closed session." );
349                     }
350                 }
351                 else
352                 {
353                     if( log.isLoggable( Level.FINEST ) )
354                     {
355                         log.log( Level.FINEST, session + " Handshaking is not complete yet. Buffering write request." );
356                     }
357                     handler.scheduleWrite( nextFilter, buf, marker );
358                 }
359             }
360         }
361     }
362 
363     // Utiliities
364 
365     private void handleSSLData( NextFilter nextFilter, IoSession session,
366                                SSLHandler handler ) throws SSLException
367     {
368         // Flush any buffered write requests occurred before handshaking.
369         if( handler.isInitialHandshakeComplete() )
370         {
371             handler.flushScheduledWrites();
372         }
373 
374         // Write encrypted data to be written (if any)
375         writeNetBuffer( nextFilter, session, handler );
376 
377         // handle app. data read (if any)
378         handleAppDataRead( nextFilter, session, handler );
379     }
380 
381     private void handleAppDataRead( NextFilter nextFilter, IoSession session,
382                                    SSLHandler sslHandler )
383     {
384         if( log.isLoggable( Level.FINEST ) )
385         {
386             log.log( Level.FINEST, session + " appBuffer: " + sslHandler.getAppBuffer() );
387         }
388         if( sslHandler.getAppBuffer().hasRemaining() )
389         {
390             // forward read app data
391             ByteBuffer readBuffer = copy( sslHandler.getAppBuffer() );
392             if( log.isLoggable( Level.FINEST ) )
393             {
394                 log.log( Level.FINEST, session + " app data read: " + readBuffer + " (" + readBuffer.getHexDump() + ')' );
395             }
396             nextFilter.dataRead( session, readBuffer );
397         }
398     }
399 
400     void writeNetBuffer( NextFilter nextFilter, IoSession session, SSLHandler sslHandler )
401             throws SSLException
402     {
403         // Check if any net data needed to be writen
404         if( !sslHandler.getOutNetBuffer().hasRemaining() )
405         {
406             // no; bail out
407             return;
408         }
409 
410         // write net data
411 
412         // set flag that we are writing encrypted data
413         // (used in filterWrite() above)
414         synchronized( sslHandler )
415         {
416             sslHandler.setWritingEncryptedData( true );
417         }
418 
419         try
420         {
421             if( log.isLoggable( Level.FINEST ) )
422             {
423                 log.log( Level.FINEST, session + " write outNetBuffer: " +
424                                    sslHandler.getOutNetBuffer() );
425             }
426             ByteBuffer writeBuffer = copy( sslHandler.getOutNetBuffer() );
427             if( log.isLoggable( Level.FINEST ) )
428             {
429                 log.log( Level.FINEST, session + " session write: " + writeBuffer );
430             }
431             //debug("outNetBuffer (after copy): {0}", sslHandler.getOutNetBuffer());
432             filterWrite( nextFilter, session, writeBuffer, SSL_MARKER );
433 
434             // loop while more writes required to complete handshake
435             while( sslHandler.needToCompleteInitialHandshake() )
436             {
437                 try
438                 {
439                     sslHandler.continueHandshake( nextFilter );
440                 }
441                 catch( SSLException ssle )
442                 {
443                     SSLException newSSLE = new SSLHandshakeException(
444                             "Initial SSL handshake failed." );
445                     newSSLE.initCause( ssle );
446                     throw newSSLE;
447                 }
448                 if( sslHandler.getOutNetBuffer().hasRemaining() )
449                 {
450                     if( log.isLoggable( Level.FINEST ) )
451                     {
452                         log.log( Level.FINEST, session + " write outNetBuffer2: " +
453                                            sslHandler.getOutNetBuffer() );
454                     }
455                     ByteBuffer writeBuffer2 = copy( sslHandler
456                             .getOutNetBuffer() );
457                     filterWrite( nextFilter, session, writeBuffer2, SSL_MARKER );
458                 }
459             }
460         }
461         finally
462         {
463             synchronized( sslHandler )
464             {
465                 sslHandler.setWritingEncryptedData( false );
466             }
467         }
468     }
469 
470     /***
471      * Creates a new Mina byte buffer that is a deep copy of the remaining bytes
472      * in the given buffer (between index buf.position() and buf.limit())
473      *
474      * @param src the buffer to copy
475      * @return the new buffer, ready to read from
476      */
477     private static ByteBuffer copy( java.nio.ByteBuffer src )
478     {
479         ByteBuffer copy = ByteBuffer.allocate( src.remaining() );
480         copy.put( src );
481         copy.flip();
482         return copy;
483     }
484 
485     // Utilities to mainpulate SSLHandler based on IoSession
486 
487     private SSLHandler createSSLSessionHandler( NextFilter nextFilter, IoSession session ) throws SSLException
488     {
489         SSLHandler handler = getSSLSessionHandler( session );
490         if( handler == null )
491         {
492             synchronized( session )
493             {
494                 handler = getSSLSessionHandler( session );
495                 if( handler == null )
496                 {
497                     boolean done = false;
498                     try
499                     {
500                         handler =
501                             new SSLHandler( this, sslContext, session );
502                         session.setAttribute( SSL_HANDLER, handler );
503                         handler.doHandshake( nextFilter );
504                         done = true;
505                     }
506                     finally 
507                     {
508                         if( !done )
509                         {
510                             session.removeAttribute( SSL_HANDLER );
511                         }
512                     }
513                 }
514             }
515         }
516         
517         return handler;
518     }
519 
520     private SSLHandler getSSLSessionHandler( IoSession session )
521     {
522         return ( SSLHandler ) session.getAttribute( SSL_HANDLER );
523     }
524 
525     private void removeSSLSessionHandler( IoSession session )
526     {
527         session.removeAttribute( SSL_HANDLER );
528     }
529 }