1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */
20  package org.apache.mina.filter.logging;
21  
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.log4j.AppenderSkeleton;
import org.apache.log4j.Level;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.filterchain.DefaultIoFilterChainBuilder;
import org.apache.mina.core.filterchain.IoFilterAdapter;
import org.apache.mina.core.future.ConnectFuture;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IdleStatus;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.ProtocolCodecFactory;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.ProtocolDecoder;
import org.apache.mina.filter.codec.ProtocolDecoderAdapter;
import org.apache.mina.filter.codec.ProtocolDecoderOutput;
import org.apache.mina.filter.codec.ProtocolEncoder;
import org.apache.mina.filter.codec.ProtocolEncoderAdapter;
import org.apache.mina.filter.codec.ProtocolEncoderOutput;
import org.apache.mina.filter.executor.ExecutorFilter;
import org.apache.mina.filter.statistic.ProfilerTimerFilter;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.mina.transport.socket.nio.NioSocketConnector;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
64  
65  /**
 * Tests {@link MdcInjectionFilter} in various scenarios.
67   *
68   * @author <a href="http://mina.apache.org">Apache MINA Project</a>
69   */
70  public class MdcInjectionFilterTest {
71  
72      static Logger LOGGER = LoggerFactory.getLogger(MdcInjectionFilterTest.class);
73      private static final int TIMEOUT = 5000;
74  
75      final MyAppender appender = new MyAppender();
76      private int port;
77      private NioSocketAcceptor acceptor;
78  
79      private Level previousLevelRootLogger;
80  
81      @Before
82      public void setUp() throws Exception {
83          // comment out next line if you want to see normal logging
84          org.apache.log4j.Logger.getRootLogger().removeAllAppenders();
85          previousLevelRootLogger = org.apache.log4j.Logger.getRootLogger().getLevel();
86          org.apache.log4j.Logger.getRootLogger().setLevel(Level.DEBUG);
87          org.apache.log4j.Logger.getRootLogger().addAppender(appender);
88          acceptor = new NioSocketAcceptor();
89      }
90  
91  
92      @After
93      public void tearDown() throws Exception {
94          acceptor.dispose();
95          org.apache.log4j.Logger.getRootLogger().setLevel(previousLevelRootLogger);
96      }
97  
98      @Test
99      public void testSimpleChain() throws IOException, InterruptedException {
100         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
101         chain.addFirst("mdc-injector", new MdcInjectionFilter());
102         chain.addLast("dummy", new DummyIoFilter());
103         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
104         test(chain);
105     }
106 
107     @Test
108     public void testExecutorFilterAtTheEnd() throws IOException, InterruptedException {
109         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
110         MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
111         chain.addFirst("mdc-injector1", mdcInjectionFilter);
112         chain.addLast("dummy", new DummyIoFilter());
113         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
114         chain.addLast("executor" , new ExecutorFilter());
115         chain.addLast("mdc-injector2", mdcInjectionFilter);
116         test(chain);
117     }
118 
119     @Test
120     public void testExecutorFilterAtBeginning() throws IOException, InterruptedException {
121         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
122         MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
123         chain.addLast("executor" , new ExecutorFilter());
124         chain.addLast("mdc-injector", mdcInjectionFilter);
125         chain.addLast("dummy", new DummyIoFilter());
126         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
127         test(chain);
128     }
129 
130     @Test
131     public void testExecutorFilterBeforeProtocol() throws IOException, InterruptedException {
132         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
133         MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
134         chain.addLast("executor" , new ExecutorFilter());
135         chain.addLast("mdc-injector", mdcInjectionFilter);
136         chain.addLast("dummy", new DummyIoFilter());
137         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
138         test(chain);
139     }
140 
141     @Test
142     public void testMultipleFilters() throws IOException, InterruptedException {
143         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
144         MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
145         chain.addLast("executor" , new ExecutorFilter());
146         chain.addLast("mdc-injector", mdcInjectionFilter);
147         chain.addLast("profiler", new ProfilerTimerFilter());
148         chain.addLast("dummy", new DummyIoFilter());
149         chain.addLast("logger", new LoggingFilter());
150         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
151         test(chain);
152     }
153 
154     @Test
155     public void testTwoExecutorFilters() throws IOException, InterruptedException {
156         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
157         MdcInjectionFilter mdcInjectionFilter = new MdcInjectionFilter();
158         chain.addLast("executor1" , new ExecutorFilter());
159         chain.addLast("mdc-injector1", mdcInjectionFilter);
160         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
161         chain.addLast("dummy", new DummyIoFilter());
162         chain.addLast("executor2" , new ExecutorFilter());
163         // add the MdcInjectionFilter instance after every ExecutorFilter
164         // it's important to use the same MdcInjectionFilter instance
165         chain.addLast("mdc-injector2",  mdcInjectionFilter);
166         test(chain);
167     }
168 
169     @Test
170     public void testOnlyRemoteAddress() throws IOException, InterruptedException {
171         DefaultIoFilterChainBuilder chain = new DefaultIoFilterChainBuilder();
172         chain.addFirst("mdc-injector", new MdcInjectionFilter(
173             MdcInjectionFilter.MdcKey.remoteAddress));
174         chain.addLast("dummy", new DummyIoFilter());
175         chain.addLast("protocol", new ProtocolCodecFilter(new DummyProtocolCodecFactory()));
176         SimpleIoHandler simpleIoHandler = new SimpleIoHandler();
177         acceptor.setHandler(simpleIoHandler);
178         acceptor.bind(new InetSocketAddress(0));
179         port = acceptor.getLocalAddress().getPort();
180         acceptor.setFilterChainBuilder(chain);
181         // create some clients
182         NioSocketConnector connector = new NioSocketConnector();
183         connector.setHandler(new IoHandlerAdapter());
184         connectAndWrite(connector,0);
185         connectAndWrite(connector,1);
186         // wait until Iohandler has received all events
187         simpleIoHandler.messageSentLatch.await();
188         simpleIoHandler.sessionIdleLatch.await();
189         simpleIoHandler.sessionClosedLatch.await();
190         connector.dispose();
191 
192         // make a copy to prevent ConcurrentModificationException
193         List<LoggingEvent> events = new ArrayList<LoggingEvent>(appender.events);
194         // verify that all logging events have correct MDC
195         for (LoggingEvent event : events) {
196             for (MdcInjectionFilter.MdcKey mdcKey : MdcInjectionFilter.MdcKey.values()) {
197               String key = mdcKey.name();
198               Object value = event.getMDC(key);
199               if (mdcKey == MdcInjectionFilter.MdcKey.remoteAddress) {
200                   assertNotNull(
201                       "MDC[remoteAddress] not set for [" + event.getMessage() + "]", value);
202               } else {
203                   assertNull("MDC[" + key + "] set for [" + event.getMessage() + "]", value);
204               }
205             }
206         }
207     }
208 
209     private void test(DefaultIoFilterChainBuilder chain) throws IOException, InterruptedException {
210         // configure the server
211         SimpleIoHandler simpleIoHandler = new SimpleIoHandler();
212         acceptor.setHandler(simpleIoHandler);
213         acceptor.bind(new InetSocketAddress(0));
214         port = acceptor.getLocalAddress().getPort();
215         acceptor.setFilterChainBuilder(chain);
216         // create some clients
217         NioSocketConnector connector = new NioSocketConnector();
218         connector.setHandler(new IoHandlerAdapter());
219         SocketAddress remoteAddressClients[] = new SocketAddress[2];
220         remoteAddressClients[0] = connectAndWrite(connector,0);
221         remoteAddressClients[1] = connectAndWrite(connector,1);
222         // wait until Iohandler has received all events
223         simpleIoHandler.messageSentLatch.await();
224         simpleIoHandler.sessionIdleLatch.await();
225         simpleIoHandler.sessionClosedLatch.await();
226         connector.dispose();
227 
228         // make a copy to prevent ConcurrentModificationException
229         List<LoggingEvent> events = new ArrayList<LoggingEvent>(appender.events);
230 
231         Set<String> loggersToCheck = new HashSet<String>();
232         loggersToCheck.add(MdcInjectionFilterTest.class.getName());
233         loggersToCheck.add(ProtocolCodecFilter.class.getName());
234         loggersToCheck.add(LoggingFilter.class.getName());
235 
236         // verify that all logging events have correct MDC
237         for (LoggingEvent event : events) {
238              
239             if (loggersToCheck.contains(event.getLoggerName())) {
240                 Object remoteAddress = event.getMDC("remoteAddress");
241                 assertNotNull("MDC[remoteAddress] not set for [" + event.getMessage() + "]", remoteAddress);
242                 assertNotNull("MDC[remotePort] not set for [" + event.getMessage() + "]", event.getMDC("remotePort"));
243                 assertEquals(
244                     "every event should have MDC[handlerClass]",
245                     SimpleIoHandler.class.getName(),
246                     event.getMDC("handlerClass") );
247             }
248         }
249         // assert we have received all expected logging events for each client
250         for (int i = 0; i < remoteAddressClients.length; i++) {
251             SocketAddress remoteAddressClient = remoteAddressClients[i];
252             assertEventExists(events, "sessionCreated", remoteAddressClient, null);
253             assertEventExists(events, "sessionOpened", remoteAddressClient, null);
254             assertEventExists(events, "decode", remoteAddressClient, null);
255             assertEventExists(events, "messageReceived-1", remoteAddressClient, null);
256             assertEventExists(events, "messageReceived-2", remoteAddressClient, "user-" + i);
257             assertEventExists(events, "encode", remoteAddressClient, null);
258             assertEventExists(events, "exceptionCaught", remoteAddressClient, "user-" + i);
259             assertEventExists(events, "messageSent-1", remoteAddressClient, "user-" + i);
260             assertEventExists(events, "messageSent-2", remoteAddressClient, null);
261             assertEventExists(events, "sessionIdle", remoteAddressClient, "user-" + i);
262             assertEventExists(events, "sessionClosed", remoteAddressClient, "user-" + i);
263             assertEventExists(events, "sessionClosed", remoteAddressClient, "user-" + i);
264             assertEventExists(events, "DummyIoFilter.sessionOpened", remoteAddressClient, "user-" + i);
265         }
266     }
267 
268     private SocketAddress connectAndWrite(NioSocketConnector connector, int clientNr) {
269         ConnectFuture connectFuture = connector.connect(new InetSocketAddress("localhost", port));
270         connectFuture.awaitUninterruptibly(TIMEOUT);
271         IoBuffer message = IoBuffer.allocate(4).putInt(clientNr).flip();
272         IoSession session = connectFuture.getSession();
273         session.write(message).awaitUninterruptibly(TIMEOUT);
274         return session.getLocalAddress();
275     }
276 
277     private void assertEventExists(List<LoggingEvent> events,
278                                    String message,
279                                    SocketAddress address,
280                                    String user) {
281         InetSocketAddress remoteAddress = (InetSocketAddress) address;
282         for (LoggingEvent event : events) {
283             if (event.getMessage().equals(message) &&
284                 event.getMDC("remoteAddress").equals(remoteAddress.toString()) &&
285                 event.getMDC("remoteIp").equals(remoteAddress.getAddress().getHostAddress()) &&
286                 event.getMDC("remotePort").equals(remoteAddress.getPort()+"") ) {
287                 if (user == null && event.getMDC("user") == null) {
288                     return;
289                 }
290                 if (user != null && user.equals(event.getMDC("user"))) {
291                     return;
292                 }
293                 return;
294             }
295         }
296         fail("No LoggingEvent found from [" + remoteAddress +"] with message [" + message + "]");
297     }
298 
299     private static class SimpleIoHandler extends IoHandlerAdapter {
300         CountDownLatch sessionIdleLatch = new CountDownLatch(2);
301         CountDownLatch sessionClosedLatch = new CountDownLatch(2);
302         CountDownLatch messageSentLatch = new CountDownLatch(2);
303 
304         /**
305          * Default constructor
306          */
307         public SimpleIoHandler() {
308             super();
309         }
310         
311         @Override
312         public void sessionCreated(IoSession session) throws Exception {
313             LOGGER.info("sessionCreated");
314             session.getConfig().setIdleTime(IdleStatus.BOTH_IDLE, 1);
315         }
316 
317         @Override
318         public void sessionOpened(IoSession session) throws Exception {
319             LOGGER.info("sessionOpened");
320         }
321 
322         @Override
323         public void sessionClosed(IoSession session) throws Exception {
324             LOGGER.info("sessionClosed");
325             sessionClosedLatch.countDown();
326         }
327 
328         @Override
329         public void sessionIdle(IoSession session, IdleStatus status) throws Exception {
330             LOGGER.info("sessionIdle");
331             sessionIdleLatch.countDown();
332             session.close(true);
333         }
334 
335         @Override
336         public void exceptionCaught(IoSession session, Throwable cause) throws Exception {
337             LOGGER.info("exceptionCaught", cause);
338         }
339 
340         @Override
341         public void messageReceived(IoSession session, Object message) throws Exception {
342             LOGGER.info("messageReceived-1");
343             // adding a custom property to the context
344             String user = "user-" + message;
345             MdcInjectionFilter.setProperty(session, "user", user);
346             LOGGER.info("messageReceived-2");
347             session.getService().broadcast(message);
348             throw new RuntimeException("just a test, forcing exceptionCaught");
349         }
350 
351         @Override
352         public void messageSent(IoSession session, Object message) throws Exception {
353             LOGGER.info("messageSent-1");
354             MdcInjectionFilter.removeProperty(session, "user");
355             LOGGER.info("messageSent-2");
356             messageSentLatch.countDown();
357         }
358     }
359 
360     private static class DummyProtocolCodecFactory implements ProtocolCodecFactory {
361         /**
362          * Default constructor
363          */
364         public DummyProtocolCodecFactory() {
365             super();
366         }
367         
368         public ProtocolEncoder getEncoder(IoSession session) throws Exception {
369             return new ProtocolEncoderAdapter() {
370                 public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
371                     LOGGER.info("encode");
372                     IoBuffer buffer = IoBuffer.allocate(4).putInt(123).flip();
373                     out.write(buffer);
374                 }
375             };
376         }
377 
378         public ProtocolDecoder getDecoder(IoSession session) throws Exception {
379             return new ProtocolDecoderAdapter() {
380                 public void decode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) throws Exception {
381                     if (in.remaining() >= 4) {
382                         int value = in.getInt();
383                         LOGGER.info("decode");
384                         out.write(value);
385                     }
386                 }
387             };
388         }
389     }
390 
391     private static class MyAppender extends AppenderSkeleton {
392         List<LoggingEvent> events = Collections.synchronizedList(new ArrayList<LoggingEvent>());
393 
394         /**
395          * Default constructor
396          */
397         public MyAppender() {
398             super();
399         }
400         
401         @Override
402         protected void append(final LoggingEvent loggingEvent) {
403             loggingEvent.getMDCCopy();
404             events.add(loggingEvent);
405         }
406 
407         @Override
408         public boolean requiresLayout() {
409             return false;
410         }
411 
412         @Override
413         public void close() {
414             // Do nothing
415         }
416     }
417 
418     static class DummyIoFilter extends IoFilterAdapter {
419         @Override
420         public void sessionOpened(NextFilter nextFilter, IoSession session) throws Exception {
421             LOGGER.info("DummyIoFilter.sessionOpened");
422             nextFilter.sessionOpened(session);
423         }
424     }
425 
426 }