1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */
20  package org.apache.mina.filter;
21  
22  import java.io.ByteArrayInputStream;
23  import java.io.IOException;
24  import java.io.InputStream;
25  import java.net.InetSocketAddress;
26  import java.net.SocketAddress;
27  import java.security.MessageDigest;
28  import java.util.LinkedList;
29  import java.util.Queue;
30  import java.util.Random;
31  import java.util.concurrent.CountDownLatch;
32  
33  import org.apache.mina.common.ByteBuffer;
34  import org.apache.mina.common.IdleStatus;
35  import org.apache.mina.common.IoAcceptor;
36  import org.apache.mina.common.IoConnector;
37  import org.apache.mina.common.IoFilter.NextFilter;
38  import org.apache.mina.common.IoFilter.WriteRequest;
39  import org.apache.mina.common.IoFutureListener;
40  import org.apache.mina.common.IoHandlerAdapter;
41  import org.apache.mina.common.IoSession;
42  import org.apache.mina.common.WriteFuture;
43  import org.apache.mina.common.support.DefaultWriteFuture;
44  import org.apache.mina.transport.socket.nio.SocketAcceptor;
45  import org.apache.mina.transport.socket.nio.SocketAcceptorConfig;
46  import org.apache.mina.transport.socket.nio.SocketConnector;
47  import org.apache.mina.util.AvailablePortFinder;
48  import org.easymock.AbstractMatcher;
49  import org.easymock.MockControl;
50  
51  import junit.framework.TestCase;
52  
53  /**
54   * Tests {@link StreamWriteFilter}.
55   *
56   * @author The Apache Directory Project (mina-dev@directory.apache.org)
57   * @version $Rev$, $Date$
58   */
59  public class StreamWriteFilterTest extends TestCase {
60      MockControl mockSession;
61  
62      MockControl mockNextFilter;
63  
64      IoSession session;
65  
66      NextFilter nextFilter;
67  
68      @Override
69      protected void setUp() throws Exception {
70          super.setUp();
71  
72          /*
73           * Create the mocks.
74           */
75          mockSession = MockControl.createControl(IoSession.class);
76          mockNextFilter = MockControl.createControl(NextFilter.class);
77          session = (IoSession) mockSession.getMock();
78          nextFilter = (NextFilter) mockNextFilter.getMock();
79  
80          session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
81          mockSession.setReturnValue(null);
82      }
83  
84      public void testWriteEmptyStream() throws Exception {
85          StreamWriteFilter filter = new StreamWriteFilter();
86  
87          InputStream stream = new ByteArrayInputStream(new byte[0]);
88          WriteRequest writeRequest = new WriteRequest(stream,
89                  new DummyWriteFuture());
90  
91          /*
92           * Record expectations
93           */
94          nextFilter.messageSent(session, stream);
95  
96          /*
97           * Replay.
98           */
99          mockNextFilter.replay();
100         mockSession.replay();
101 
102         filter.filterWrite(nextFilter, session, writeRequest);
103 
104         /*
105          * Verify.
106          */
107         mockNextFilter.verify();
108         mockSession.verify();
109 
110         assertTrue(writeRequest.getFuture().isWritten());
111     }
112 
113     /**
114      * Tests that the filter just passes objects which aren't InputStreams
115      * through to the next filter.
116      */
117     public void testWriteNonStreamMessage() throws Exception {
118         StreamWriteFilter filter = new StreamWriteFilter();
119 
120         Object message = new Object();
121         WriteRequest writeRequest = new WriteRequest(message,
122                 new DummyWriteFuture());
123 
124         /*
125          * Record expectations
126          */
127         nextFilter.filterWrite(session, writeRequest);
128         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
129         mockSession.setReturnValue(null);
130         nextFilter.messageSent(session, message);
131 
132         /*
133          * Replay.
134          */
135         mockNextFilter.replay();
136         mockSession.replay();
137 
138         filter.filterWrite(nextFilter, session, writeRequest);
139         filter.messageSent(nextFilter, session, message);
140 
141         /*
142          * Verify.
143          */
144         mockNextFilter.verify();
145         mockSession.verify();
146     }
147 
148     /**
149      * Tests when the contents of the stream fits into one write buffer.
150      */
151     public void testWriteSingleBufferStream() throws Exception {
152         StreamWriteFilter filter = new StreamWriteFilter();
153 
154         byte[] data = new byte[] { 1, 2, 3, 4 };
155 
156         InputStream stream = new ByteArrayInputStream(data);
157         WriteRequest writeRequest = new WriteRequest(stream,
158                 new DummyWriteFuture());
159 
160         /*
161          * Record expectations
162          */
163         session.setAttribute(StreamWriteFilter.CURRENT_STREAM, stream);
164         mockSession.setReturnValue(null);
165         session.setAttribute(StreamWriteFilter.INITIAL_WRITE_FUTURE,
166                 writeRequest.getFuture());
167         mockSession.setReturnValue(null);
168         nextFilter
169                 .filterWrite(session, new WriteRequest(ByteBuffer.wrap(data)));
170         mockNextFilter.setMatcher(new WriteRequestMatcher());
171 
172         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
173         mockSession.setReturnValue(stream);
174         session.removeAttribute(StreamWriteFilter.CURRENT_STREAM);
175         mockSession.setReturnValue(stream);
176         session.removeAttribute(StreamWriteFilter.INITIAL_WRITE_FUTURE);
177         mockSession.setReturnValue(writeRequest.getFuture());
178         session.removeAttribute(StreamWriteFilter.WRITE_REQUEST_QUEUE);
179         mockSession.setReturnValue(null);
180         nextFilter.messageSent(session, stream);
181 
182         /*
183          * Replay.
184          */
185         mockNextFilter.replay();
186         mockSession.replay();
187 
188         filter.filterWrite(nextFilter, session, writeRequest);
189         filter.messageSent(nextFilter, session, data);
190 
191         /*
192          * Verify.
193          */
194         mockNextFilter.verify();
195         mockSession.verify();
196 
197         assertTrue(writeRequest.getFuture().isWritten());
198     }
199 
200     /**
201      * Tests when the contents of the stream doesn't fit into one write buffer.
202      */
203     public void testWriteSeveralBuffersStream() throws Exception {
204         StreamWriteFilter filter = new StreamWriteFilter();
205         filter.setWriteBufferSize(4);
206 
207         byte[] data = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
208         byte[] chunk1 = new byte[] { 1, 2, 3, 4 };
209         byte[] chunk2 = new byte[] { 5, 6, 7, 8 };
210         byte[] chunk3 = new byte[] { 9, 10 };
211 
212         InputStream stream = new ByteArrayInputStream(data);
213         WriteRequest writeRequest = new WriteRequest(stream,
214                 new DummyWriteFuture());
215 
216         /*
217          * Record expectations
218          */
219         session.setAttribute(StreamWriteFilter.CURRENT_STREAM, stream);
220         mockSession.setReturnValue(null);
221         session.setAttribute(StreamWriteFilter.INITIAL_WRITE_FUTURE,
222                 writeRequest.getFuture());
223         mockSession.setReturnValue(null);
224         nextFilter.filterWrite(session, new WriteRequest(ByteBuffer
225                 .wrap(chunk1)));
226         mockNextFilter.setMatcher(new WriteRequestMatcher());
227 
228         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
229         mockSession.setReturnValue(stream);
230         nextFilter.filterWrite(session, new WriteRequest(ByteBuffer
231                 .wrap(chunk2)));
232 
233         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
234         mockSession.setReturnValue(stream);
235         nextFilter.filterWrite(session, new WriteRequest(ByteBuffer
236                 .wrap(chunk3)));
237 
238         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
239         mockSession.setReturnValue(stream);
240         session.removeAttribute(StreamWriteFilter.CURRENT_STREAM);
241         mockSession.setReturnValue(stream);
242         session.removeAttribute(StreamWriteFilter.INITIAL_WRITE_FUTURE);
243         mockSession.setReturnValue(writeRequest.getFuture());
244         session.removeAttribute(StreamWriteFilter.WRITE_REQUEST_QUEUE);
245         mockSession.setReturnValue(null);
246         nextFilter.messageSent(session, stream);
247 
248         /*
249          * Replay.
250          */
251         mockNextFilter.replay();
252         mockSession.replay();
253 
254         filter.filterWrite(nextFilter, session, writeRequest);
255         filter.messageSent(nextFilter, session, chunk1);
256         filter.messageSent(nextFilter, session, chunk2);
257         filter.messageSent(nextFilter, session, chunk3);
258 
259         /*
260          * Verify.
261          */
262         mockNextFilter.verify();
263         mockSession.verify();
264 
265         assertTrue(writeRequest.getFuture().isWritten());
266     }
267 
268     public void testWriteWhileWriteInProgress() throws Exception {
269         StreamWriteFilter filter = new StreamWriteFilter();
270 
271         Queue<? extends Object> queue = new LinkedList<Object>();
272         InputStream stream = new ByteArrayInputStream(new byte[5]);
273 
274         /*
275          * Record expectations
276          */
277         mockSession.reset();
278         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
279         mockSession.setReturnValue(stream);
280         session.getAttribute(StreamWriteFilter.WRITE_REQUEST_QUEUE);
281         mockSession.setReturnValue(queue);
282 
283         /*
284          * Replay.
285          */
286         mockNextFilter.replay();
287         mockSession.replay();
288 
289         WriteRequest wr = new WriteRequest(new Object(), new DummyWriteFuture());
290         filter.filterWrite(nextFilter, session, wr);
291         assertEquals(1, queue.size());
292         assertSame(wr, queue.poll());
293 
294         /*
295          * Verify.
296          */
297         mockNextFilter.verify();
298         mockSession.verify();
299     }
300 
301     public void testWritesWriteRequestQueueWhenFinished() throws Exception {
302         StreamWriteFilter filter = new StreamWriteFilter();
303 
304         WriteRequest wrs[] = new WriteRequest[] {
305                 new WriteRequest(new Object(), new DummyWriteFuture()),
306                 new WriteRequest(new Object(), new DummyWriteFuture()),
307                 new WriteRequest(new Object(), new DummyWriteFuture()) };
308         Queue<WriteRequest> queue = new LinkedList<WriteRequest>();
309         queue.add(wrs[0]);
310         queue.add(wrs[1]);
311         queue.add(wrs[2]);
312         InputStream stream = new ByteArrayInputStream(new byte[0]);
313 
314         /*
315          * Record expectations
316          */
317         mockSession.reset();
318 
319         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
320         mockSession.setReturnValue(stream);
321         session.removeAttribute(StreamWriteFilter.CURRENT_STREAM);
322         mockSession.setReturnValue(stream);
323         session.removeAttribute(StreamWriteFilter.INITIAL_WRITE_FUTURE);
324         mockSession.setReturnValue(new DefaultWriteFuture(session));
325         session.removeAttribute(StreamWriteFilter.WRITE_REQUEST_QUEUE);
326         mockSession.setReturnValue(queue);
327 
328         nextFilter.filterWrite(session, wrs[0]);
329         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
330         mockSession.setReturnValue(null);
331         nextFilter.filterWrite(session, wrs[1]);
332         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
333         mockSession.setReturnValue(null);
334         nextFilter.filterWrite(session, wrs[2]);
335         session.getAttribute(StreamWriteFilter.CURRENT_STREAM);
336         mockSession.setReturnValue(null);
337 
338         nextFilter.messageSent(session, stream);
339 
340         /*
341          * Replay.
342          */
343         mockNextFilter.replay();
344         mockSession.replay();
345 
346         filter.messageSent(nextFilter, session, new Object());
347         assertEquals(0, queue.size());
348 
349         /*
350          * Verify.
351          */
352         mockNextFilter.verify();
353         mockSession.verify();
354     }
355 
356     /**
357      * Tests that {@link StreamWriteFilter#setWriteBufferSize(int)} checks the
358      * specified size.
359      */
360     public void testSetWriteBufferSize() throws Exception {
361         StreamWriteFilter filter = new StreamWriteFilter();
362 
363         try {
364             filter.setWriteBufferSize(0);
365             fail("0 writeBuferSize specified. IllegalArgumentException expected.");
366         } catch (IllegalArgumentException iae) {
367         }
368 
369         try {
370             filter.setWriteBufferSize(-100);
371             fail("Negative writeBuferSize specified. IllegalArgumentException expected.");
372         } catch (IllegalArgumentException iae) {
373         }
374 
375         filter.setWriteBufferSize(1);
376         assertEquals(1, filter.getWriteBufferSize());
377         filter.setWriteBufferSize(1024);
378         assertEquals(1024, filter.getWriteBufferSize());
379     }
380 
381     public void testWriteUsingSocketTransport() throws Exception {
382         IoAcceptor acceptor = new SocketAcceptor();
383         ((SocketAcceptorConfig) acceptor.getDefaultConfig())
384                 .setReuseAddress(true);
385         SocketAddress address = new InetSocketAddress("localhost",
386                 AvailablePortFinder.getNextAvailable());
387 
388         IoConnector connector = new SocketConnector();
389 
390         FixedRandomInputStream stream = new FixedRandomInputStream(
391                 4 * 1024 * 1024);
392 
393         SenderHandler sender = new SenderHandler(stream);
394         ReceiverHandler receiver = new ReceiverHandler(stream.size);
395 
396         acceptor.bind(address, sender);
397 
398         connector.connect(address, receiver);
399         sender.latch.await();
400         receiver.latch.await();
401 
402         acceptor.unbind(address);
403 
404         assertEquals(stream.bytesRead, receiver.bytesRead);
405         assertEquals(stream.size, receiver.bytesRead);
406         byte[] expectedMd5 = stream.digest.digest();
407         byte[] actualMd5 = receiver.digest.digest();
408         assertEquals(expectedMd5.length, actualMd5.length);
409         for (int i = 0; i < expectedMd5.length; i++) {
410             assertEquals(expectedMd5[i], actualMd5[i]);
411         }
412     }
413 
414     private static class FixedRandomInputStream extends InputStream {
415         long size;
416 
417         long bytesRead = 0;
418 
419         Random random = new Random();
420 
421         MessageDigest digest;
422 
423         FixedRandomInputStream(long size) throws Exception {
424             this.size = size;
425             digest = MessageDigest.getInstance("MD5");
426         }
427 
428         @Override
429         public int read() throws IOException {
430             if (isAllWritten())
431                 return -1;
432             bytesRead++;
433             byte b = (byte) random.nextInt(255);
434             digest.update(b);
435             return b;
436         }
437 
438         public long getBytesRead() {
439             return bytesRead;
440         }
441 
442         public long getSize() {
443             return size;
444         }
445 
446         public boolean isAllWritten() {
447             return bytesRead >= size;
448         }
449     }
450 
451     private static class SenderHandler extends IoHandlerAdapter {
452         final CountDownLatch latch = new CountDownLatch( 1 );
453 
454         InputStream inputStream;
455 
456         StreamWriteFilter streamWriteFilter = new StreamWriteFilter();
457 
458         SenderHandler(InputStream inputStream) {
459             this.inputStream = inputStream;
460         }
461 
462         @Override
463         public void sessionCreated(IoSession session) throws Exception {
464             super.sessionCreated(session);
465             session.getFilterChain().addLast("codec", streamWriteFilter);
466         }
467 
468         @Override
469         public void sessionOpened(IoSession session) throws Exception {
470             session.write(inputStream);
471         }
472 
473         @Override
474         public void exceptionCaught(IoSession session, Throwable cause)
475                 throws Exception {
476             latch.countDown();
477         }
478 
479         @Override
480         public void sessionClosed(IoSession session) throws Exception {
481             latch.countDown();
482         }
483 
484         @Override
485         public void sessionIdle(IoSession session, IdleStatus status)
486                 throws Exception {
487             latch.countDown();
488         }
489 
490         @Override
491         public void messageSent(IoSession session, Object message)
492                 throws Exception {
493             if (message == inputStream) {
494                 latch.countDown();
495             }
496         }
497     }
498 
499     private static class ReceiverHandler extends IoHandlerAdapter {
500         final CountDownLatch latch = new CountDownLatch( 1 );
501 
502         long bytesRead = 0;
503 
504         long size = 0;
505 
506         MessageDigest digest;
507 
508         ReceiverHandler(long size) throws Exception {
509             this.size = size;
510             digest = MessageDigest.getInstance("MD5");
511         }
512 
513         @Override
514         public void sessionCreated(IoSession session) throws Exception {
515             super.sessionCreated(session);
516 
517             session.setIdleTime(IdleStatus.READER_IDLE, 5);
518         }
519 
520         @Override
521         public void sessionIdle(IoSession session, IdleStatus status)
522                 throws Exception {
523             session.close();
524         }
525 
526         @Override
527         public void exceptionCaught(IoSession session, Throwable cause)
528                 throws Exception {
529             latch.countDown();
530         }
531 
532         @Override
533         public void sessionClosed(IoSession session) throws Exception {
534             latch.countDown();
535         }
536 
537         @Override
538         public void messageReceived(IoSession session, Object message)
539                 throws Exception {
540             ByteBuffer buf = (ByteBuffer) message;
541             while (buf.hasRemaining()) {
542                 digest.update(buf.get());
543                 bytesRead++;
544             }
545             if (bytesRead >= size) {
546                 session.close();
547             }
548         }
549     }
550 
551     public static class WriteRequestMatcher extends AbstractMatcher {
552         @Override
553         protected boolean argumentMatches(Object expected, Object actual) {
554             if (expected instanceof WriteRequest
555                     && actual instanceof WriteRequest) {
556                 WriteRequest w1 = (WriteRequest) expected;
557                 WriteRequest w2 = (WriteRequest) actual;
558 
559                 return w1.getMessage().equals(w2.getMessage())
560                         && w1.getFuture().isWritten() == w2.getFuture()
561                                 .isWritten();
562             }
563             return super.argumentMatches(expected, actual);
564         }
565     }
566 
567     private static class DummyWriteFuture implements WriteFuture {
568         private boolean written;
569 
570         public boolean isWritten() {
571             return written;
572         }
573 
574         public void setWritten(boolean written) {
575             this.written = written;
576         }
577 
578         public IoSession getSession() {
579             return null;
580         }
581 
582         public Object getLock() {
583             return this;
584         }
585 
586         public void join() {
587         }
588 
589         public boolean join(long timeoutInMillis) {
590             return true;
591         }
592 
593         public boolean isReady() {
594             return true;
595         }
596 
597         public void addListener(IoFutureListener listener) {
598         }
599 
600         public void removeListener(IoFutureListener listener) {
601         }
602     }
603 }