1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.reqres;
21
22 import java.util.ArrayList;
23 import java.util.Date;
24 import java.util.HashMap;
25 import java.util.Iterator;
26 import java.util.LinkedHashSet;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Set;
30 import java.util.concurrent.ScheduledExecutorService;
31 import java.util.concurrent.ScheduledFuture;
32 import java.util.concurrent.TimeUnit;
33
34 import org.apache.mina.common.AttributeKey;
35 import org.apache.mina.common.IoFilterChain;
36 import org.apache.mina.common.IoSession;
37 import org.apache.mina.common.WriteRequest;
38 import org.apache.mina.filter.util.WriteRequestFilter;
39 import org.slf4j.Logger;
40 import org.slf4j.LoggerFactory;
41
42
43
44
45
46
47 public class RequestResponseFilter extends WriteRequestFilter {
48
49 private final AttributeKey RESPONSE_INSPECTOR = new AttributeKey(getClass(), "responseInspector");
50 private final AttributeKey REQUEST_STORE = new AttributeKey(getClass(), "requestStore");
51 private final AttributeKey UNRESPONDED_REQUEST_STORE = new AttributeKey(getClass(), "unrespondedRequestStore");
52
53 private final ResponseInspectorFactory responseInspectorFactory;
54 private final ScheduledExecutorService timeoutScheduler;
55
56 private final Logger logger = LoggerFactory.getLogger(getClass());
57
58 public RequestResponseFilter(final ResponseInspector responseInspector,
59 ScheduledExecutorService timeoutScheduler) {
60 if (responseInspector == null) {
61 throw new NullPointerException("responseInspector");
62 }
63 if (timeoutScheduler == null) {
64 throw new NullPointerException("timeoutScheduler");
65 }
66 this.responseInspectorFactory = new ResponseInspectorFactory() {
67 public ResponseInspector getResponseInspector() {
68 return responseInspector;
69 }
70 };
71 this.timeoutScheduler = timeoutScheduler;
72 }
73
74 public RequestResponseFilter(
75 ResponseInspectorFactory responseInspectorFactory,
76 ScheduledExecutorService timeoutScheduler) {
77 if (responseInspectorFactory == null) {
78 throw new NullPointerException("responseInspectorFactory");
79 }
80 if (timeoutScheduler == null) {
81 throw new NullPointerException("timeoutScheduler");
82 }
83 this.responseInspectorFactory = responseInspectorFactory;
84 this.timeoutScheduler = timeoutScheduler;
85 }
86
87 @Override
88 public void onPreAdd(IoFilterChain parent, String name,
89 NextFilter nextFilter) throws Exception {
90 if (parent.contains(this)) {
91 throw new IllegalArgumentException(
92 "You can't add the same filter instance more than once. Create another instance and add it.");
93 }
94
95 IoSession session = parent.getSession();
96 session.setAttribute(RESPONSE_INSPECTOR, responseInspectorFactory
97 .getResponseInspector());
98 session.setAttribute(REQUEST_STORE, createRequestStore(session));
99 session.setAttribute(UNRESPONDED_REQUEST_STORE, createUnrespondedRequestStore(session));
100 }
101
102 @Override
103 public void onPostRemove(IoFilterChain parent, String name,
104 NextFilter nextFilter) throws Exception {
105 IoSession session = parent.getSession();
106
107 destroyUnrespondedRequestStore(getUnrespondedRequestStore(session));
108 destroyRequestStore(getRequestStore(session));
109
110 session.removeAttribute(UNRESPONDED_REQUEST_STORE);
111 session.removeAttribute(REQUEST_STORE);
112 session.removeAttribute(RESPONSE_INSPECTOR);
113 }
114
115 @Override
116 public void messageReceived(NextFilter nextFilter, IoSession session,
117 Object message) throws Exception {
118 ResponseInspector responseInspector = (ResponseInspector) session
119 .getAttribute(RESPONSE_INSPECTOR);
120 Object requestId = responseInspector.getRequestId(message);
121 if (requestId == null) {
122
123 nextFilter.messageReceived(session, message);
124 return;
125 }
126
127
128 ResponseType type = responseInspector.getResponseType(message);
129 if (type == null) {
130 nextFilter.exceptionCaught(session, new IllegalStateException(
131 responseInspector.getClass().getName()
132 + "#getResponseType() may not return null."));
133 }
134
135 Map<Object, Request> requestStore = getRequestStore(session);
136
137 Request request;
138 switch (type) {
139 case WHOLE:
140 case PARTIAL_LAST:
141 synchronized (requestStore) {
142 request = requestStore.remove(requestId);
143 }
144 break;
145 case PARTIAL:
146 synchronized (requestStore) {
147 request = requestStore.get(requestId);
148 }
149 break;
150 default:
151 throw new InternalError();
152 }
153
154 if (request == null) {
155
156
157 if (logger.isWarnEnabled()) {
158 logger.warn("Unknown request ID '" + requestId
159 + "' for the response message. Timed out already?: "
160 + message);
161 }
162 } else {
163
164
165 if (type != ResponseType.PARTIAL) {
166 ScheduledFuture<?> scheduledFuture = request.getTimeoutFuture();
167 if (scheduledFuture != null) {
168 scheduledFuture.cancel(false);
169 Set<Request> unrespondedRequests = getUnrespondedRequestStore(session);
170 synchronized (unrespondedRequests) {
171 unrespondedRequests.remove(request);
172 }
173 }
174 }
175
176
177 Response response = new Response(request, message, type);
178 request.signal(response);
179 nextFilter.messageReceived(session, response);
180 }
181 }
182
183 @Override
184 protected Object doFilterWrite(
185 final NextFilter nextFilter, IoSession session, WriteRequest writeRequest) throws Exception {
186 Object message = writeRequest.getMessage();
187 if (!(message instanceof Request)) {
188 return null;
189 }
190
191 final Request request = (Request) message;
192 if (request.getTimeoutFuture() != null) {
193 throw new IllegalArgumentException("Request can not be reused.");
194 }
195
196 Map<Object, Request> requestStore = getRequestStore(session);
197 Object oldValue = null;
198 Object requestId = request.getId();
199 synchronized (requestStore) {
200 oldValue = requestStore.get(requestId);
201 if (oldValue == null) {
202 requestStore.put(requestId, request);
203 }
204 }
205 if (oldValue != null) {
206 throw new IllegalStateException(
207 "Duplicate request ID: " + request.getId());
208 }
209
210
211
212 Date timeoutDate = new Date(System.currentTimeMillis());
213 if (Long.MAX_VALUE - request.getTimeoutMillis() < timeoutDate
214 .getTime()) {
215 timeoutDate.setTime(Long.MAX_VALUE);
216 } else {
217 timeoutDate.setTime(timeoutDate.getTime()
218 + request.getTimeoutMillis());
219 }
220
221 TimeoutTask timeoutTask = new TimeoutTask(
222 nextFilter, request, session);
223
224
225 ScheduledFuture<?> timeoutFuture = timeoutScheduler.schedule(
226 timeoutTask, request.getTimeoutMillis(),
227 TimeUnit.MILLISECONDS);
228 request.setTimeoutTask(timeoutTask);
229 request.setTimeoutFuture(timeoutFuture);
230
231
232 Set<Request> unrespondedRequests = getUnrespondedRequestStore(session);
233 synchronized (unrespondedRequests) {
234 unrespondedRequests.add(request);
235 }
236
237 return request.getMessage();
238 }
239
240 @Override
241 public void sessionClosed(NextFilter nextFilter, IoSession session)
242 throws Exception {
243
244
245 Set<Request> unrespondedRequests = getUnrespondedRequestStore(session);
246 List<Request> unrespondedRequestsCopy;
247 synchronized (unrespondedRequests) {
248 unrespondedRequestsCopy = new ArrayList<Request>(
249 unrespondedRequests);
250 unrespondedRequests.clear();
251 }
252
253
254 for (Request r : unrespondedRequestsCopy) {
255 if (r.getTimeoutFuture().cancel(false)) {
256 r.getTimeoutTask().run();
257 }
258 }
259
260
261 Map<Object, Request> requestStore = getRequestStore(session);
262 synchronized (requestStore) {
263 requestStore.clear();
264 }
265
266
267 nextFilter.sessionClosed(session);
268 }
269
270 @SuppressWarnings("unchecked")
271 private Map<Object, Request> getRequestStore(IoSession session) {
272 return (Map<Object, Request>) session.getAttribute(REQUEST_STORE);
273 }
274
275 @SuppressWarnings("unchecked")
276 private Set<Request> getUnrespondedRequestStore(IoSession session) {
277 return (Set<Request>) session.getAttribute(UNRESPONDED_REQUEST_STORE);
278 }
279
280
281
282
283
284
285
286 protected Map<Object, Request> createRequestStore(
287 @SuppressWarnings("unused") IoSession session) {
288 return new HashMap<Object, Request>();
289 }
290
291
292
293
294
295
296
297
298
299
300
301
302
303 protected Set<Request> createUnrespondedRequestStore(
304 @SuppressWarnings("unused") IoSession session) {
305 return new LinkedHashSet<Request>();
306 }
307
308
309
310
311
312
313
314
315 protected void destroyRequestStore(
316 @SuppressWarnings("unused")
317 Map<Object, Request> requestStore) {
318 }
319
320
321
322
323
324
325
326
327 protected void destroyUnrespondedRequestStore(
328 @SuppressWarnings("unused")
329 Set<Request> unrespondedRequestStore) {
330 }
331
332 private class TimeoutTask implements Runnable {
333 private final NextFilter filter;
334
335 private final Request request;
336
337 private final IoSession session;
338
339 private TimeoutTask(NextFilter filter, Request request,
340 IoSession session) {
341 this.filter = filter;
342 this.request = request;
343 this.session = session;
344 }
345
346 public void run() {
347 Set<Request> unrespondedRequests = getUnrespondedRequestStore(session);
348 if (unrespondedRequests != null) {
349 synchronized (unrespondedRequests) {
350 unrespondedRequests.remove(request);
351 }
352 }
353
354 Map<Object, Request> requestStore = getRequestStore(session);
355 Object requestId = request.getId();
356 boolean timedOut;
357 synchronized (requestStore) {
358 if (requestStore.get(requestId) == request) {
359 requestStore.remove(requestId);
360 timedOut = true;
361 } else {
362 timedOut = false;
363 }
364 }
365
366 if (timedOut) {
367
368 RequestTimeoutException e = new RequestTimeoutException(request);
369 request.signal(e);
370 filter.exceptionCaught(session, e);
371 }
372 }
373 }
374 }