1   package org.apache.mina.example.echoserver.ssl;
2   
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.Charset;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import junit.framework.TestCase;

import org.apache.mina.common.IoAcceptor;
import org.apache.mina.common.IoHandlerAdapter;
import org.apache.mina.common.IoSession;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
import org.apache.mina.filter.ssl.SslFilter;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
24  
25  public class SslFilterTest extends TestCase {
26  
27      private static final int PORT = 17887;
28  
29      private IoAcceptor acceptor;
30  
31      @Override
32      protected void setUp() throws Exception {
33          super.setUp();
34          acceptor = new NioSocketAcceptor();
35      }
36  
37      @Override
38      protected void tearDown() throws Exception {
39          acceptor.setCloseOnDeactivation(true);
40          acceptor.dispose();
41          super.tearDown();
42      }
43  
44      public void testMessageSentIsCalled() throws Exception {
45          testMessageSentIsCalled(false);
46      }
47  
48      public void testMessageSentIsCalled_With_SSL() throws Exception {
49          testMessageSentIsCalled(true);
50      }
51  
52      private void testMessageSentIsCalled(boolean useSSL) throws Exception {
53          SslFilter sslFilter = null;
54          if (useSSL) {
55              sslFilter = new SslFilter(BogusSslContextFactory.getInstance(true));
56              acceptor.getFilterChain().addLast("sslFilter", sslFilter);
57          }
58          acceptor.getFilterChain().addLast(
59                  "codec",
60                  new ProtocolCodecFilter(new TextLineCodecFactory(Charset
61                          .forName("UTF-8"))));
62  
63          EchoHandler handler = new EchoHandler();
64          acceptor.setHandler(handler);
65          acceptor.bind(new InetSocketAddress(PORT));
66          System.out.println("MINA server started.");
67  
68          Socket socket = getClientSocket(useSSL);
69          int bytesSent = 0;
70          bytesSent += writeMessage(socket, "test-1\n");
71  
72          if (useSSL) {
73              // Test renegotiation
74              SSLSocket ss = (SSLSocket) socket;
75              //ss.getSession().invalidate();
76              ss.startHandshake();
77          }
78  
79          bytesSent += writeMessage(socket, "test-2\n");
80  
81          int[] response = new int[bytesSent];
82          for (int i = 0; i < response.length; i++) {
83              response[i] = socket.getInputStream().read();
84          }
85  
86          if (useSSL) {
87              // Read SSL close notify.
88              while (socket.getInputStream().read() >= 0) {
89                  continue;
90              }
91          }
92  
93          socket.close();
94          while (acceptor.getManagedSessions().size() != 0) {
95              Thread.sleep(100);
96          }
97  
98          System.out.println("handler: " + handler.sentMessages);
99          assertEquals("handler should have sent 2 messages:", 2,
100                 handler.sentMessages.size());
101         assertTrue(handler.sentMessages.contains("test-1"));
102         assertTrue(handler.sentMessages.contains("test-2"));
103     }
104 
105     private int writeMessage(Socket socket, String message) throws Exception {
106         byte request[] = message.getBytes("UTF-8");
107         socket.getOutputStream().write(request);
108         return request.length;
109     }
110 
111     private Socket getClientSocket(boolean ssl) throws Exception {
112         if (ssl) {
113             SSLContext ctx = SSLContext.getInstance("TLS");
114             ctx.init(null, trustManagers, null);
115             return ctx.getSocketFactory().createSocket("localhost", PORT);
116         }
117         return new Socket("localhost", PORT);
118     }
119 
120     private static class EchoHandler extends IoHandlerAdapter {
121 
122         List<String> sentMessages = new ArrayList<String>();
123 
124         @Override
125         public void exceptionCaught(IoSession session, Throwable cause)
126                 throws Exception {
127             cause.printStackTrace();
128         }
129 
130         @Override
131         public void messageReceived(IoSession session, Object message)
132                 throws Exception {
133             session.write(message);
134         }
135 
136         @Override
137         public void messageSent(IoSession session, Object message)
138                 throws Exception {
139             sentMessages.add(message.toString());
140             System.out.println(message);
141             if (sentMessages.size() >= 2) {
142                 session.close();
143             }
144         }
145     }
146 
147     TrustManager[] trustManagers = new TrustManager[] { new TrustAnyone() };
148 
149     private static class TrustAnyone implements X509TrustManager {
150         public void checkClientTrusted(
151                 java.security.cert.X509Certificate[] x509Certificates, String s)
152                 throws CertificateException {
153         }
154 
155         public void checkServerTrusted(
156                 java.security.cert.X509Certificate[] x509Certificates, String s)
157                 throws CertificateException {
158         }
159 
160         public java.security.cert.X509Certificate[] getAcceptedIssuers() {
161             return new java.security.cert.X509Certificate[0];
162         }
163     }
164 
165 }