1   package org.apache.mina.example.echoserver.ssl;
2   
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import junit.framework.TestCase;

import org.apache.mina.common.IoAcceptor;
import org.apache.mina.common.IoHandlerAdapter;
import org.apache.mina.common.IoSession;
import org.apache.mina.filter.SSLFilter;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
import org.apache.mina.transport.socket.nio.SocketAcceptor;
24  
25  public class SSLFilterTest extends TestCase {
26  
27      private static final int PORT = 17887;
28  
29      private IoAcceptor acceptor;
30  
31      SocketAddress socketAddress = new InetSocketAddress(PORT);
32  
33      protected void setUp() throws Exception {
34          super.setUp();
35          acceptor = new SocketAcceptor();
36      }
37  
38      protected void tearDown() throws Exception {
39          acceptor.unbindAll();
40          super.tearDown();
41      }
42  
43      public void testMessageSentIsCalled() throws Exception {
44          testMessageSentIsCalled(false);
45      }
46  
47      public void testMessageSentIsCalled_With_SSL() throws Exception {
48          testMessageSentIsCalled(true);
49      }
50  
51      private void testMessageSentIsCalled(boolean useSSL) throws Exception {
52  
53          if (useSSL) {
54              SSLFilter sslFilter = new SSLFilter(BogusSSLContextFactory
55                      .getInstance(true));
56              acceptor.getFilterChain().addLast("sslFilter", sslFilter);
57          }
58          acceptor.getFilterChain().addLast(
59                  "codec",
60                  new ProtocolCodecFilter(new TextLineCodecFactory(Charset
61                          .forName("UTF-8"))));
62  
63          EchoHandler handler = new EchoHandler();
64          acceptor.bind(socketAddress, handler);
65          System.out.println("MINA server started.");
66  
67          Socket socket = getClientSocket(useSSL);
68          int bytesSent = 0;
69          bytesSent += writeMessage(socket, "test-1\n");
70          bytesSent += writeMessage(socket, "test-2\n");
71          byte[] response = new byte[bytesSent];
72          for (int i = 0; i < response.length; i++) {
73              response[i] = (byte) socket.getInputStream().read();
74          }
75          long millis = System.currentTimeMillis();
76          while (handler.sentMessages.size() < 2
77                  && System.currentTimeMillis() < millis + 5000) {
78              Thread.sleep(200);
79          }
80          assertEquals("received what we sent", "test-1\ntest-2\n", new String(
81                  response, "UTF-8"));
82  
83          System.out.println("handler: " + handler.sentMessages);
84          assertEquals("handler should have sent 2 messages:", 2,
85                  handler.sentMessages.size());
86          assertTrue(handler.sentMessages.contains("test-1"));
87          assertTrue(handler.sentMessages.contains("test-2"));
88      }
89  
90      private int writeMessage(Socket socket, String message) throws Exception {
91          byte request[] = message.getBytes("UTF-8");
92          socket.getOutputStream().write(request);
93          return request.length;
94      }
95  
96      private Socket getClientSocket(boolean ssl) throws Exception {
97          if (ssl) {
98              SSLContext ctx = SSLContext.getInstance("TLS");
99              ctx.init(null, trustManagers, null);
100             return ctx.getSocketFactory().createSocket("localhost", PORT);
101         }
102         return new Socket("localhost", PORT);
103     }
104 
105     private static class EchoHandler extends IoHandlerAdapter {
106 
107         List sentMessages = new ArrayList();
108 
109         public void exceptionCaught(IoSession session, Throwable cause)
110                 throws Exception {
111         }
112 
113         public void messageReceived(IoSession session, Object message)
114                 throws Exception {
115             session.write(message);
116         }
117 
118         public void messageSent(IoSession session, Object message)
119                 throws Exception {
120             sentMessages.add(message.toString());
121             if (sentMessages.size() >= 2) {
122                 session.close();
123             }
124         }
125 
126     }
127 
128     TrustManager[] trustManagers = new TrustManager[] { new TrustAnyone() };
129 
130     private static class TrustAnyone implements X509TrustManager {
131         public void checkClientTrusted(
132                 java.security.cert.X509Certificate[] x509Certificates, String s)
133                 throws CertificateException {
134         }
135 
136         public void checkServerTrusted(
137                 java.security.cert.X509Certificate[] x509Certificates, String s)
138                 throws CertificateException {
139         }
140 
141         public java.security.cert.X509Certificate[] getAcceptedIssuers() {
142             return new java.security.cert.X509Certificate[0];
143         }
144     }
145 
146 }