1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.ssl;
21
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.util.Queue;
25 import java.util.concurrent.ConcurrentLinkedQueue;
26
27 import javax.net.ssl.SSLContext;
28 import javax.net.ssl.SSLEngine;
29 import javax.net.ssl.SSLEngineResult;
30 import javax.net.ssl.SSLException;
31 import javax.net.ssl.SSLHandshakeException;
32
33 import org.apache.mina.common.DefaultWriteFuture;
34 import org.apache.mina.common.DefaultWriteRequest;
35 import org.apache.mina.common.IoBuffer;
36 import org.apache.mina.common.IoEventType;
37 import org.apache.mina.common.IoFilterEvent;
38 import org.apache.mina.common.IoSession;
39 import org.apache.mina.common.WriteFuture;
40 import org.apache.mina.common.WriteRequest;
41 import org.apache.mina.common.IoFilter.NextFilter;
42 import org.apache.mina.util.CircularQueue;
43 import org.slf4j.Logger;
44 import org.slf4j.LoggerFactory;
45
46
47
48
49
50
51
52
53
54
55
56
57 class SslHandler {
58
59 private final Logger logger = LoggerFactory.getLogger(getClass());
60 private final SslFilter parent;
61 private final SSLContext ctx;
62 private final IoSession session;
63 private final Queue<IoFilterEvent> preHandshakeEventQueue = new CircularQueue<IoFilterEvent>();
64 private final Queue<IoFilterEvent> filterWriteEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
65 private final Queue<IoFilterEvent> messageReceivedEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
66 private SSLEngine sslEngine;
67
68
69
70
71 private IoBuffer inNetBuffer;
72
73
74
75
76 private IoBuffer outNetBuffer;
77
78
79
80
81 private IoBuffer appBuffer;
82
83
84
85
86 private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
87
88 private SSLEngineResult.HandshakeStatus handshakeStatus;
89 private boolean initialHandshakeComplete;
90 private boolean handshakeComplete;
91 private boolean writingEncryptedData;
92
93
94
95
96
97
98
99 public SslHandler(SslFilter parent, SSLContext sslc, IoSession session)
100 throws SSLException {
101 this.parent = parent;
102 this.session = session;
103 this.ctx = sslc;
104 init();
105 }
106
107 public void init() throws SSLException {
108 if (sslEngine != null) {
109 return;
110 }
111
112 InetSocketAddress peer = (InetSocketAddress) session
113 .getAttribute(SslFilter.PEER_ADDRESS);
114 if (peer == null) {
115 sslEngine = ctx.createSSLEngine();
116 } else {
117 sslEngine = ctx.createSSLEngine(peer.getHostName(), peer.getPort());
118 }
119 sslEngine.setUseClientMode(parent.isUseClientMode());
120
121 if (parent.isWantClientAuth()) {
122 sslEngine.setWantClientAuth(true);
123 }
124
125 if (parent.isNeedClientAuth()) {
126 sslEngine.setNeedClientAuth(true);
127 }
128
129 if (parent.getEnabledCipherSuites() != null) {
130 sslEngine.setEnabledCipherSuites(parent.getEnabledCipherSuites());
131 }
132
133 if (parent.getEnabledProtocols() != null) {
134 sslEngine.setEnabledProtocols(parent.getEnabledProtocols());
135 }
136
137 sslEngine.beginHandshake();
138 handshakeStatus = sslEngine.getHandshakeStatus();
139
140 handshakeComplete = false;
141 initialHandshakeComplete = false;
142 writingEncryptedData = false;
143 }
144
145
146
147
148 public void destroy() {
149 if (sslEngine == null) {
150 return;
151 }
152
153
154 try {
155 sslEngine.closeInbound();
156 } catch (SSLException e) {
157 logger.debug(
158 "Unexpected exception from SSLEngine.closeInbound().", e);
159 }
160
161
162 if (outNetBuffer != null) {
163 outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
164 } else {
165 createOutNetBuffer(0);
166 }
167 try {
168 do {
169 outNetBuffer.clear();
170 } while (sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf()).bytesProduced() > 0);
171 } catch (SSLException e) {
172
173 } finally {
174 destroyOutNetBuffer();
175 }
176
177 sslEngine.closeOutbound();
178 sslEngine = null;
179
180 preHandshakeEventQueue.clear();
181 }
182
183 private void destroyOutNetBuffer() {
184 outNetBuffer.free();
185 outNetBuffer = null;
186 }
187
188 public SslFilter getParent() {
189 return parent;
190 }
191
192 public IoSession getSession() {
193 return session;
194 }
195
196
197
198
199 public boolean isWritingEncryptedData() {
200 return writingEncryptedData;
201 }
202
203
204
205
206 public boolean isHandshakeComplete() {
207 return handshakeComplete;
208 }
209
210 public boolean isInboundDone() {
211 return sslEngine == null || sslEngine.isInboundDone();
212 }
213
214 public boolean isOutboundDone() {
215 return sslEngine == null || sslEngine.isOutboundDone();
216 }
217
218
219
220
221 public boolean needToCompleteHandshake() {
222 return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
223 }
224
225 public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
226 WriteRequest writeRequest) {
227 preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
228 IoEventType.WRITE, session, writeRequest));
229 }
230
231 public void flushPreHandshakeEvents() throws SSLException {
232 IoFilterEvent scheduledWrite;
233
234 while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
235 parent.filterWrite(scheduledWrite.getNextFilter(), session,
236 (WriteRequest) scheduledWrite.getParameter());
237 }
238 }
239
240 public void scheduleFilterWrite(NextFilter nextFilter, WriteRequest writeRequest) {
241 filterWriteEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
242 }
243
244 public void scheduleMessageReceived(NextFilter nextFilter, Object message) {
245 messageReceivedEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.MESSAGE_RECEIVED, session, message));
246 }
247
248 public void flushScheduledEvents() {
249
250 if (Thread.holdsLock(this)) {
251 return;
252 }
253
254 IoFilterEvent e;
255
256
257
258 synchronized (this) {
259 while ((e = filterWriteEventQueue.poll()) != null) {
260 e.getNextFilter().filterWrite(session, (WriteRequest) e.getParameter());
261 }
262 }
263
264 while ((e = messageReceivedEventQueue.poll()) != null) {
265 e.getNextFilter().messageReceived(session, e.getParameter());
266 }
267 }
268
269
270
271
272
273
274
275
276
277
278 public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
279
280 if (inNetBuffer == null) {
281 inNetBuffer = IoBuffer.allocate(buf.remaining()).setAutoExpand(true);
282 }
283
284 inNetBuffer.put(buf);
285 if (!handshakeComplete) {
286 handshake(nextFilter);
287 } else {
288 decrypt(nextFilter);
289 }
290
291 if (isInboundDone()) {
292
293 int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
294 buf.position(buf.position() - inNetBufferPosition);
295 inNetBuffer = null;
296 }
297 }
298
299
300
301
302
303
304 public IoBuffer fetchAppBuffer() {
305 IoBuffer appBuffer = this.appBuffer.flip();
306 this.appBuffer = null;
307 return appBuffer;
308 }
309
310
311
312
313
314
315 public IoBuffer fetchOutNetBuffer() {
316 IoBuffer answer = outNetBuffer;
317 if (answer == null) {
318 return emptyBuffer;
319 }
320
321 outNetBuffer = null;
322 return answer.shrink();
323 }
324
325
326
327
328
329
330
331 public void encrypt(ByteBuffer src) throws SSLException {
332 if (!handshakeComplete) {
333 throw new IllegalStateException();
334 }
335
336 if (!src.hasRemaining()) {
337 if (outNetBuffer == null) {
338 outNetBuffer = emptyBuffer;
339 }
340 return;
341 }
342
343 createOutNetBuffer(src.remaining());
344
345
346 while (src.hasRemaining()) {
347
348 SSLEngineResult result = sslEngine.wrap(src, outNetBuffer.buf());
349 if (result.getStatus() == SSLEngineResult.Status.OK) {
350 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
351 doTasks();
352 }
353 } else if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
354 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
355 outNetBuffer.limit(outNetBuffer.capacity());
356 } else {
357 throw new SSLException("SSLEngine error during encrypt: "
358 + result.getStatus() + " src: " + src
359 + "outNetBuffer: " + outNetBuffer);
360 }
361 }
362
363 outNetBuffer.flip();
364 }
365
366
367
368
369
370
371
372
373 public boolean closeOutbound() throws SSLException {
374 if (sslEngine == null || sslEngine.isOutboundDone()) {
375 return false;
376 }
377
378 sslEngine.closeOutbound();
379
380 createOutNetBuffer(0);
381 SSLEngineResult result;
382 for (;;) {
383 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
384 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
385 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
386 outNetBuffer.limit(outNetBuffer.capacity());
387 } else {
388 break;
389 }
390 }
391
392 if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
393 throw new SSLException("Improper close state: " + result);
394 }
395 outNetBuffer.flip();
396 return true;
397 }
398
399
400
401
402
403
404 private void decrypt(NextFilter nextFilter) throws SSLException {
405
406 if (!handshakeComplete) {
407 throw new IllegalStateException();
408 }
409
410 unwrap(nextFilter);
411 }
412
413
414
415
416
417 private void checkStatus(SSLEngineResult res)
418 throws SSLException {
419
420 SSLEngineResult.Status status = res.getStatus();
421
422
423
424
425
426
427
428
429
430 if (status != SSLEngineResult.Status.OK
431 && status != SSLEngineResult.Status.CLOSED
432 && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
433 throw new SSLException("SSLEngine error during decrypt: " + status
434 + " inNetBuffer: " + inNetBuffer + "appBuffer: "
435 + appBuffer);
436 }
437 }
438
439
440
441
442 public void handshake(NextFilter nextFilter) throws SSLException {
443 for (; ;) {
444 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED) {
445 session.setAttribute(
446 SslFilter.SSL_SESSION, sslEngine.getSession());
447 handshakeComplete = true;
448 if (!initialHandshakeComplete
449 && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
450
451
452 initialHandshakeComplete = true;
453 scheduleMessageReceived(nextFilter,
454 SslFilter.SESSION_SECURED);
455 }
456 break;
457 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
458 handshakeStatus = doTasks();
459 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
460
461 SSLEngineResult.Status status = unwrapHandshake(nextFilter);
462 if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW
463 || isInboundDone()) {
464
465 break;
466 }
467 } else if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
468
469
470 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
471 break;
472 }
473
474 SSLEngineResult result;
475 createOutNetBuffer(0);
476 for (;;) {
477 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
478 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
479 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
480 outNetBuffer.limit(outNetBuffer.capacity());
481 } else {
482 break;
483 }
484 }
485
486 outNetBuffer.flip();
487 handshakeStatus = result.getHandshakeStatus();
488 writeNetBuffer(nextFilter);
489 } else {
490 throw new IllegalStateException("Invalid Handshaking State"
491 + handshakeStatus);
492 }
493 }
494 }
495
496 private void createOutNetBuffer(int expectedRemaining) {
497
498
499 int capacity = Math.max(
500 expectedRemaining,
501 sslEngine.getSession().getPacketBufferSize());
502
503 if (outNetBuffer != null) {
504 outNetBuffer.capacity(capacity);
505 } else {
506 outNetBuffer = IoBuffer.allocate(capacity).minimumCapacity(0);
507 }
508 }
509
510 public WriteFuture writeNetBuffer(NextFilter nextFilter)
511 throws SSLException {
512
513 if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
514
515 return null;
516 }
517
518
519
520 writingEncryptedData = true;
521
522
523 WriteFuture writeFuture = null;
524
525 try {
526 IoBuffer writeBuffer = fetchOutNetBuffer();
527 writeFuture = new DefaultWriteFuture(session);
528 parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
529 writeBuffer, writeFuture));
530
531
532 while (needToCompleteHandshake()) {
533 try {
534 handshake(nextFilter);
535 } catch (SSLException ssle) {
536 SSLException newSsle = new SSLHandshakeException(
537 "SSL handshake failed.");
538 newSsle.initCause(ssle);
539 throw newSsle;
540 }
541
542 IoBuffer outNetBuffer = fetchOutNetBuffer();
543 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
544 writeFuture = new DefaultWriteFuture(session);
545 parent.filterWrite(nextFilter, session,
546 new DefaultWriteRequest(outNetBuffer, writeFuture));
547 }
548 }
549 } finally {
550 writingEncryptedData = false;
551 }
552
553 return writeFuture;
554 }
555
556 private void unwrap(NextFilter nextFilter) throws SSLException {
557
558 if (inNetBuffer != null) {
559 inNetBuffer.flip();
560 }
561
562 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
563 return;
564 }
565
566 SSLEngineResult res = unwrap0();
567
568
569 if (inNetBuffer.hasRemaining()) {
570 inNetBuffer.compact();
571 } else {
572 inNetBuffer = null;
573 }
574
575 checkStatus(res);
576
577 renegotiateIfNeeded(nextFilter, res);
578 }
579
580 private SSLEngineResult.Status unwrapHandshake(NextFilter nextFilter) throws SSLException {
581
582 if (inNetBuffer != null) {
583 inNetBuffer.flip();
584 }
585
586 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
587
588 return SSLEngineResult.Status.BUFFER_UNDERFLOW;
589 }
590
591 SSLEngineResult res = unwrap0();
592 handshakeStatus = res.getHandshakeStatus();
593
594 checkStatus(res);
595
596
597
598 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
599 && res.getStatus() == SSLEngineResult.Status.OK
600 && inNetBuffer.hasRemaining()) {
601 res = unwrap0();
602
603
604 if (inNetBuffer.hasRemaining()) {
605 inNetBuffer.compact();
606 } else {
607 inNetBuffer = null;
608 }
609
610 renegotiateIfNeeded(nextFilter, res);
611 } else {
612
613 if (inNetBuffer.hasRemaining()) {
614 inNetBuffer.compact();
615 } else {
616 inNetBuffer = null;
617 }
618 }
619
620 return res.getStatus();
621 }
622
623 private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
624 throws SSLException {
625 if (res.getStatus() != SSLEngineResult.Status.CLOSED
626 && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
627 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
628
629 handshakeComplete = false;
630 handshakeStatus = res.getHandshakeStatus();
631 handshake(nextFilter);
632 }
633 }
634
635 private SSLEngineResult unwrap0() throws SSLException {
636 if (appBuffer == null) {
637 appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
638 } else {
639 appBuffer.expand(inNetBuffer.remaining());
640 }
641
642 SSLEngineResult res;
643 do {
644 res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
645 if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
646 appBuffer.capacity(appBuffer.capacity() << 1);
647 appBuffer.limit(appBuffer.capacity());
648 continue;
649 }
650 } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
651 (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
652 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
653
654 return res;
655 }
656
657
658
659
660 private SSLEngineResult.HandshakeStatus doTasks() {
661
662
663
664
665 Runnable runnable;
666 while ((runnable = sslEngine.getDelegatedTask()) != null) {
667 runnable.run();
668 }
669 return sslEngine.getHandshakeStatus();
670 }
671
672
673
674
675
676
677
678
679 public static IoBuffer copy(ByteBuffer src) {
680 IoBuffer copy = IoBuffer.allocate(src.remaining());
681 copy.put(src);
682 copy.flip();
683 return copy;
684 }
685 }