001    /****************************************************************
002     * Licensed to the Apache Software Foundation (ASF) under one   *
003     * or more contributor license agreements.  See the NOTICE file *
004     * distributed with this work for additional information        *
005     * regarding copyright ownership.  The ASF licenses this file   *
006     * to you under the Apache License, Version 2.0 (the            *
007     * "License"); you may not use this file except in compliance   *
008     * with the License.  You may obtain a copy of the License at   *
009     *                                                              *
010     *   http://www.apache.org/licenses/LICENSE-2.0                 *
011     *                                                              *
012     * Unless required by applicable law or agreed to in writing,   *
013     * software distributed under the License is distributed on an  *
014     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY       *
015     * KIND, either express or implied.  See the License for the    *
016     * specific language governing permissions and limitations      *
017     * under the License.                                           *
018     ****************************************************************/
019    
020    package org.apache.james.mime4j.codec;
021    
022    import java.io.IOException;
023    import java.io.InputStream;
024    
025    import org.apache.james.mime4j.util.ByteArrayBuffer;
026    
027    /**
028     * Performs Base-64 decoding on an underlying stream.
029     */
030    public class Base64InputStream extends InputStream {
031        private static final int ENCODED_BUFFER_SIZE = 1536;
032    
033        private static final int[] BASE64_DECODE = new int[256];
034    
035        static {
036            for (int i = 0; i < 256; i++)
037                BASE64_DECODE[i] = -1;
038            for (int i = 0; i < Base64OutputStream.BASE64_TABLE.length; i++)
039                BASE64_DECODE[Base64OutputStream.BASE64_TABLE[i] & 0xff] = i;
040        }
041    
042        private static final byte BASE64_PAD = '=';
043    
044        private static final int EOF = -1;
045    
046        private final byte[] singleByte = new byte[1];
047    
048        private final InputStream in;
049        private final byte[] encoded;
050        private final ByteArrayBuffer decodedBuf;
051    
052        private int position = 0; // current index into encoded buffer
053        private int size = 0; // current size of encoded buffer
054    
055        private boolean closed = false;
056        private boolean eof; // end of file or pad character reached
057    
058        private final DecodeMonitor monitor;
059    
060        public Base64InputStream(InputStream in, DecodeMonitor monitor) {
061            this(ENCODED_BUFFER_SIZE, in, monitor);
062        }
063    
064        protected Base64InputStream(int bufsize, InputStream in, DecodeMonitor monitor) {
065            if (in == null)
066                throw new IllegalArgumentException();
067            this.encoded = new byte[bufsize];
068            this.decodedBuf = new ByteArrayBuffer(512);
069            this.in = in;
070            this.monitor = monitor;
071        }
072    
073        public Base64InputStream(InputStream in) {
074            this(in, false);
075        }
076    
077        public Base64InputStream(InputStream in, boolean strict) {
078            this(ENCODED_BUFFER_SIZE, in, strict ? DecodeMonitor.STRICT : DecodeMonitor.SILENT);
079        }
080    
081        @Override
082        public int read() throws IOException {
083            if (closed)
084                throw new IOException("Stream has been closed");
085    
086            while (true) {
087                int bytes = read0(singleByte, 0, 1);
088                if (bytes == EOF)
089                    return EOF;
090    
091                if (bytes == 1)
092                    return singleByte[0] & 0xff;
093            }
094        }
095    
096        @Override
097        public int read(byte[] buffer) throws IOException {
098            if (closed)
099                throw new IOException("Stream has been closed");
100    
101            if (buffer == null)
102                throw new NullPointerException();
103    
104            if (buffer.length == 0)
105                return 0;
106    
107            return read0(buffer, 0, buffer.length);
108        }
109    
110        @Override
111        public int read(byte[] buffer, int offset, int length) throws IOException {
112            if (closed)
113                throw new IOException("Stream has been closed");
114    
115            if (buffer == null)
116                throw new NullPointerException();
117    
118            if (offset < 0 || length < 0 || offset + length > buffer.length)
119                throw new IndexOutOfBoundsException();
120    
121            if (length == 0)
122                return 0;
123    
124            return read0(buffer, offset, length);
125        }
126    
127        @Override
128        public void close() throws IOException {
129            if (closed)
130                return;
131    
132            closed = true;
133        }
134    
135        private int read0(final byte[] buffer, final int off, final int len) throws IOException {
136            int from = off;
137            int to = off + len;
138            int index = off;
139    
140            // check if a previous invocation left decoded content
141            if (decodedBuf.length() > 0) {
142                int chunk = Math.min(decodedBuf.length(), len);
143                System.arraycopy(decodedBuf.buffer(), 0, buffer, index, chunk);
144                decodedBuf.remove(0, chunk);
145                index += chunk;
146            }
147    
148            // eof or pad reached?
149    
150            if (eof)
151                return index == from ? EOF : index - from;
152    
153            // decode into given buffer
154    
155            int data = 0; // holds decoded data; up to four sextets
156            int sextets = 0; // number of sextets
157    
158            while (index < to) {
159                // make sure buffer not empty
160    
161                while (position == size) {
162                    int n = in.read(encoded, 0, encoded.length);
163                    if (n == EOF) {
164                        eof = true;
165    
166                        if (sextets != 0) {
167                            // error in encoded data
168                            handleUnexpectedEof(sextets);
169                        }
170    
171                        return index == from ? EOF : index - from;
172                    } else if (n > 0) {
173                        position = 0;
174                        size = n;
175                    } else {
176                        assert n == 0;
177                    }
178                }
179    
180                // decode buffer
181    
182                while (position < size && index < to) {
183                    int value = encoded[position++] & 0xff;
184    
185                    if (value == BASE64_PAD) {
186                        index = decodePad(data, sextets, buffer, index, to);
187                        return index - from;
188                    }
189    
190                    int decoded = BASE64_DECODE[value];
191                    if (decoded < 0) { // -1: not a base64 char
192                        if (value != 0x0D && value != 0x0A && value != 0x20) {
193                            if (monitor.warn("Unexpected base64 byte: "+(byte) value, "ignoring."))
194                                throw new IOException("Unexpected base64 byte");
195                        }
196                        continue;
197                    }
198    
199                    data = (data << 6) | decoded;
200                    sextets++;
201    
202                    if (sextets == 4) {
203                        sextets = 0;
204    
205                        byte b1 = (byte) (data >>> 16);
206                        byte b2 = (byte) (data >>> 8);
207                        byte b3 = (byte) data;
208    
209                        if (index < to - 2) {
210                            buffer[index++] = b1;
211                            buffer[index++] = b2;
212                            buffer[index++] = b3;
213                        } else {
214                            if (index < to - 1) {
215                                buffer[index++] = b1;
216                                buffer[index++] = b2;
217                                decodedBuf.append(b3);
218                            } else if (index < to) {
219                                buffer[index++] = b1;
220                                decodedBuf.append(b2);
221                                decodedBuf.append(b3);
222                            } else {
223                                decodedBuf.append(b1);
224                                decodedBuf.append(b2);
225                                decodedBuf.append(b3);
226                            }
227    
228                            assert index == to;
229                            return to - from;
230                        }
231                    }
232                }
233            }
234    
235            assert sextets == 0;
236            assert index == to;
237            return to - from;
238        }
239    
240        private int decodePad(int data, int sextets, final byte[] buffer,
241                int index, final int end) throws IOException {
242            eof = true;
243    
244            if (sextets == 2) {
245                // one byte encoded as "XY=="
246    
247                byte b = (byte) (data >>> 4);
248                if (index < end) {
249                    buffer[index++] = b;
250                } else {
251                    decodedBuf.append(b);
252                }
253            } else if (sextets == 3) {
254                // two bytes encoded as "XYZ="
255    
256                byte b1 = (byte) (data >>> 10);
257                byte b2 = (byte) ((data >>> 2) & 0xFF);
258    
259                if (index < end - 1) {
260                    buffer[index++] = b1;
261                    buffer[index++] = b2;
262                } else if (index < end) {
263                    buffer[index++] = b1;
264                    decodedBuf.append(b2);
265                } else {
266                    decodedBuf.append(b1);
267                    decodedBuf.append(b2);
268                }
269            } else {
270                // error in encoded data
271                handleUnexpecedPad(sextets);
272            }
273    
274            return index;
275        }
276    
277        private void handleUnexpectedEof(int sextets) throws IOException {
278            if (monitor.warn("Unexpected end of BASE64 stream", "dropping " + sextets + " sextet(s)"))
279                throw new IOException("Unexpected end of BASE64 stream");
280        }
281    
282        private void handleUnexpecedPad(int sextets) throws IOException {
283            if (monitor.warn("Unexpected padding character", "dropping " + sextets + " sextet(s)"))
284                throw new IOException("Unexpected padding character");
285        }
286    }