1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.traffic;
21
22 import java.util.concurrent.ScheduledExecutorService;
23 import java.util.concurrent.TimeUnit;
24
25 import org.apache.mina.common.AttributeKey;
26 import org.apache.mina.common.IoBuffer;
27 import org.apache.mina.common.IoFilter;
28 import org.apache.mina.common.IoFilterAdapter;
29 import org.apache.mina.common.IoFilterChain;
30 import org.apache.mina.common.IoSession;
31 import org.apache.mina.common.TrafficMask;
32 import org.apache.mina.common.WriteRequest;
33
34
35
36
37
38
39
40
41
42
43
44 public class TrafficShapingFilter extends IoFilterAdapter {
45
46 private final AttributeKey STATE = new AttributeKey(getClass(), "state");
47
48 private final ScheduledExecutorService scheduledExecutor;
49 private final MessageSizeEstimator messageSizeEstimator;
50 private volatile int maxReadThroughput;
51 private volatile int maxWriteThroughput;
52
53 public TrafficShapingFilter(
54 ScheduledExecutorService scheduledExecutor,
55 int maxReadThroughput, int maxWriteThroughput) {
56 this(scheduledExecutor, null, maxReadThroughput, maxWriteThroughput);
57 }
58
59 public TrafficShapingFilter(
60 ScheduledExecutorService scheduledExecutor,
61 MessageSizeEstimator messageSizeEstimator,
62 int maxReadThroughput, int maxWriteThroughput) {
63 if (scheduledExecutor == null) {
64 throw new NullPointerException("scheduledExecutor");
65 }
66
67 if (messageSizeEstimator == null) {
68 messageSizeEstimator = new DefaultMessageSizeEstimator() {
69 @Override
70 public int estimateSize(Object message) {
71 if (message instanceof IoBuffer) {
72 return ((IoBuffer) message).remaining();
73 }
74 return super.estimateSize(message);
75 }
76 };
77 }
78
79 this.scheduledExecutor = scheduledExecutor;
80 this.messageSizeEstimator = messageSizeEstimator;
81 setMaxReadThroughput(maxReadThroughput);
82 setMaxWriteThroughput(maxWriteThroughput);
83 }
84
85 public ScheduledExecutorService getScheduledExecutor() {
86 return scheduledExecutor;
87 }
88
89 public MessageSizeEstimator getMessageSizeEstimator() {
90 return messageSizeEstimator;
91 }
92
93 public int getMaxReadThroughput() {
94 return maxReadThroughput;
95 }
96
97 public void setMaxReadThroughput(int maxReadThroughput) {
98 if (maxReadThroughput < 0) {
99 maxReadThroughput = 0;
100 }
101 this.maxReadThroughput = maxReadThroughput;
102 }
103
104 public int getMaxWriteThroughput() {
105 return maxWriteThroughput;
106 }
107
108 public void setMaxWriteThroughput(int maxWriteThroughput) {
109 if (maxWriteThroughput < 0) {
110 maxWriteThroughput = 0;
111 }
112 this.maxWriteThroughput = maxWriteThroughput;
113 }
114
115 @Override
116 public void onPreAdd(IoFilterChain parent, String name,
117 NextFilter nextFilter) throws Exception {
118 if (parent.contains(this)) {
119 throw new IllegalArgumentException(
120 "You can't add the same filter instance more than once. Create another instance and add it.");
121 }
122 parent.getSession().setAttribute(STATE, new State());
123 adjustReadBufferSize(parent.getSession());
124 }
125
126 @Override
127 public void onPostRemove(IoFilterChain parent, String name,
128 NextFilter nextFilter) throws Exception {
129 parent.getSession().removeAttribute(STATE);
130 }
131
132 @Override
133 public void messageReceived(NextFilter nextFilter, final IoSession session,
134 Object message) throws Exception {
135
136 int maxReadThroughput = this.maxReadThroughput;
137 if (maxReadThroughput == 0) {
138 nextFilter.messageReceived(session, message);
139 }
140
141 final State state = (State) session.getAttribute(STATE);
142 long currentTime = System.currentTimeMillis();
143
144 long suspendTime = 0;
145 boolean firstRead = false;
146 synchronized (state) {
147 state.readBytes += messageSizeEstimator.estimateSize(message);
148
149 if (!state.suspendedRead) {
150 if (state.readStartTime == 0) {
151 firstRead = true;
152 state.readStartTime = currentTime - 1000;
153 }
154
155 long throughput =
156 (state.readBytes * 1000 / (currentTime - state.readStartTime));
157 if (throughput >= maxReadThroughput) {
158 suspendTime = Math.max(
159 0,
160 state.readBytes * 1000 / maxReadThroughput -
161 (firstRead? 0 : currentTime - state.readStartTime));
162
163 state.readBytes = 0;
164 state.readStartTime = 0;
165 state.suspendedRead = suspendTime != 0;
166
167 adjustReadBufferSize(session);
168 }
169 }
170 }
171
172 if (suspendTime != 0) {
173 session.suspendRead();
174 scheduledExecutor.schedule(new Runnable() {
175 public void run() {
176 synchronized (state) {
177 state.suspendedRead = false;
178 }
179 session.resumeRead();
180 }
181 }, suspendTime, TimeUnit.MILLISECONDS);
182 }
183
184 nextFilter.messageReceived(session, message);
185 }
186
187 private void adjustReadBufferSize(IoSession session) {
188 int maxReadThroughput = this.maxReadThroughput;
189 if (maxReadThroughput == 0) {
190 return;
191 }
192
193 if (session.getConfig().getReadBufferSize() > maxReadThroughput) {
194 session.getConfig().setReadBufferSize(maxReadThroughput);
195 }
196 if (session.getConfig().getMaxReadBufferSize() > maxReadThroughput) {
197 session.getConfig().setMaxReadBufferSize(maxReadThroughput);
198 }
199 }
200
201 @Override
202 public void messageSent(NextFilter nextFilter, final IoSession session,
203 WriteRequest writeRequest) throws Exception {
204
205 int maxWriteThroughput = this.maxWriteThroughput;
206 if (maxWriteThroughput == 0) {
207 nextFilter.messageSent(session, writeRequest);
208 }
209
210 final State state = (State) session.getAttribute(STATE);
211 long currentTime = System.currentTimeMillis();
212
213 long suspendTime = 0;
214 boolean firstWrite = false;
215 synchronized (state) {
216 state.writtenBytes += messageSizeEstimator.estimateSize(writeRequest.getMessage());
217 if (!state.suspendedWrite) {
218 if (state.writeStartTime == 0) {
219 firstWrite = true;
220 state.writeStartTime = currentTime - 1000;
221 }
222
223 long throughput =
224 (state.writtenBytes * 1000 / (currentTime - state.writeStartTime));
225 if (throughput >= maxWriteThroughput) {
226 suspendTime = Math.max(
227 0,
228 state.writtenBytes * 1000 / maxWriteThroughput -
229 (firstWrite? 0 : currentTime - state.writeStartTime));
230
231 state.writtenBytes = 0;
232 state.writeStartTime = 0;
233 state.suspendedWrite = suspendTime != 0;
234 }
235 }
236 }
237
238 if (suspendTime != 0) {
239 session.suspendWrite();
240 scheduledExecutor.schedule(new Runnable() {
241 public void run() {
242 synchronized (state) {
243 state.suspendedWrite = false;
244 }
245 session.resumeWrite();
246 }
247 }, suspendTime, TimeUnit.MILLISECONDS);
248 }
249
250 nextFilter.messageSent(session, writeRequest);
251 }
252
253 @Override
254 public void filterSetTrafficMask(NextFilter nextFilter, IoSession session,
255 TrafficMask trafficMask) throws Exception {
256 State state = (State) session.getAttribute(STATE);
257 boolean suspendedRead;
258 boolean suspendedWrite;
259 synchronized (state) {
260 suspendedRead = state.suspendedRead;
261 suspendedWrite = state.suspendedWrite;
262 }
263
264 if (suspendedRead) {
265 trafficMask = trafficMask.and(TrafficMask.WRITE);
266 }
267
268 if (suspendedWrite) {
269 trafficMask = trafficMask.and(TrafficMask.READ);
270 }
271
272 nextFilter.filterSetTrafficMask(session, trafficMask);
273 }
274
275 private static class State {
276 private long readStartTime;
277 private long writeStartTime;
278 private boolean suspendedRead;
279 private boolean suspendedWrite;
280 private long readBytes;
281 private long writtenBytes;
282 }
283 }