001/*
002 *  Licensed under the Apache License, Version 2.0 (the "License");
003 *  you may not use this file except in compliance with the License.
004 *  You may obtain a copy of the License at
005 *
006 *       http://www.apache.org/licenses/LICENSE-2.0
007 *
008 *  Unless required by applicable law or agreed to in writing, software
009 *  distributed under the License is distributed on an "AS IS" BASIS,
010 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
011 *  See the License for the specific language governing permissions and
012 *  limitations under the License.
014 */
015
016package org.apache.commons.imaging.formats.jpeg.decoder;
017
018import static org.apache.commons.imaging.common.BinaryFunctions.read2Bytes;
019import static org.apache.commons.imaging.common.BinaryFunctions.readBytes;
020
021import java.awt.image.BufferedImage;
022import java.awt.image.ColorModel;
023import java.awt.image.DataBuffer;
024import java.awt.image.DirectColorModel;
025import java.awt.image.Raster;
026import java.awt.image.WritableRaster;
027import java.io.ByteArrayInputStream;
028import java.io.IOException;
029import java.util.Arrays;
030import java.util.Properties;
031
032import org.apache.commons.imaging.ImageReadException;
033import org.apache.commons.imaging.common.BinaryFileParser;
034import org.apache.commons.imaging.common.bytesource.ByteSource;
035import org.apache.commons.imaging.formats.jpeg.JpegConstants;
036import org.apache.commons.imaging.formats.jpeg.JpegUtils;
037import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment;
038import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment;
039import org.apache.commons.imaging.formats.jpeg.segments.SofnSegment;
040import org.apache.commons.imaging.formats.jpeg.segments.SosSegment;
041
042public class JpegDecoder extends BinaryFileParser implements JpegUtils.Visitor {
043    /*
044     * JPEG is an advanced image format that takes significant computation to
045     * decode. Keep decoding fast: - Don't allocate memory inside loops,
046     * allocate it once and reuse. - Minimize calculations per pixel and per
047     * block (using lookup tables for YCbCr->RGB conversion doubled
048     * performance). - Math.round() is slow, use (int)(x+0.5f) instead for
049     * positive numbers.
050     */
051
052    private final DqtSegment.QuantizationTable[] quantizationTables = new DqtSegment.QuantizationTable[4];
053    private final DhtSegment.HuffmanTable[] huffmanDCTables = new DhtSegment.HuffmanTable[4];
054    private final DhtSegment.HuffmanTable[] huffmanACTables = new DhtSegment.HuffmanTable[4];
055    private SofnSegment sofnSegment;
056    private SosSegment sosSegment;
057    private final float[][] scaledQuantizationTables = new float[4][];
058    private BufferedImage image;
059    private ImageReadException imageReadException;
060    private IOException ioException;
061    private final int[] zz = new int[64];
062    private final int[] blockInt = new int[64];
063    private final float[] block = new float[64];
064
065    @Override
066    public boolean beginSOS() {
067        return true;
068    }
069
070    @Override
071    public void visitSOS(final int marker, final byte[] markerBytes, final byte[] imageData) {
072        final ByteArrayInputStream is = new ByteArrayInputStream(imageData);
073        try {
074            final int segmentLength = read2Bytes("segmentLength", is, "Not a Valid JPEG File", getByteOrder());
075            final byte[] sosSegmentBytes = readBytes("SosSegment",
076                    is, segmentLength - 2, "Not a Valid JPEG File");
077            sosSegment = new SosSegment(marker, sosSegmentBytes);
078
079            int hMax = 0;
080            int vMax = 0;
081            for (int i = 0; i < sofnSegment.numberOfComponents; i++) {
082                hMax = Math.max(hMax,
083                        sofnSegment.getComponents(i).horizontalSamplingFactor);
084                vMax = Math.max(vMax,
085                        sofnSegment.getComponents(i).verticalSamplingFactor);
086            }
087            final int hSize = 8 * hMax;
088            final int vSize = 8 * vMax;
089
090            final JpegInputStream bitInputStream = new JpegInputStream(is);
091            final int xMCUs = (sofnSegment.width + hSize - 1) / hSize;
092            final int yMCUs = (sofnSegment.height + vSize - 1) / vSize;
093            final Block[] mcu = allocateMCUMemory();
094            final Block[] scaledMCU = new Block[mcu.length];
095            for (int i = 0; i < scaledMCU.length; i++) {
096                scaledMCU[i] = new Block(hSize, vSize);
097            }
098            final int[] preds = new int[sofnSegment.numberOfComponents];
099            ColorModel colorModel;
100            WritableRaster raster;
101            if (sofnSegment.numberOfComponents == 3) {
102                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
103                        0x000000ff);
104                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
105                        sofnSegment.width, sofnSegment.height, new int[] {
106                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
107            } else if (sofnSegment.numberOfComponents == 1) {
108                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
109                        0x000000ff);
110                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
111                        sofnSegment.width, sofnSegment.height, new int[] {
112                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
113                // FIXME: why do images come out too bright with CS_GRAY?
114                // colorModel = new ComponentColorModel(
115                // ColorSpace.getInstance(ColorSpace.CS_GRAY), false, true,
116                // Transparency.OPAQUE, DataBuffer.TYPE_BYTE);
117                // raster = colorModel.createCompatibleWritableRaster(
118                // sofnSegment.width, sofnSegment.height);
119            } else {
120                throw new ImageReadException(sofnSegment.numberOfComponents
121                        + " components are invalid or unsupported");
122            }
123            final DataBuffer dataBuffer = raster.getDataBuffer();
124
125            for (int y1 = 0; y1 < vSize * yMCUs; y1 += vSize) {
126                for (int x1 = 0; x1 < hSize * xMCUs; x1 += hSize) {
127                    readMCU(bitInputStream, preds, mcu);
128                    rescaleMCU(mcu, hSize, vSize, scaledMCU);
129                    int srcRowOffset = 0;
130                    int dstRowOffset = y1 * sofnSegment.width + x1;
131                    for (int y2 = 0; y2 < vSize && y1 + y2 < sofnSegment.height; y2++) {
132                        for (int x2 = 0; x2 < hSize
133                                && x1 + x2 < sofnSegment.width; x2++) {
134                            if (scaledMCU.length == 3) {
135                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
136                                final int Cb = scaledMCU[1].samples[srcRowOffset + x2];
137                                final int Cr = scaledMCU[2].samples[srcRowOffset + x2];
138                                final int rgb = YCbCrConverter.convertYCbCrToRGB(Y,
139                                        Cb, Cr);
140                                dataBuffer.setElem(dstRowOffset + x2, rgb);
141                            } else if (mcu.length == 1) {
142                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
143                                dataBuffer.setElem(dstRowOffset + x2, (Y << 16)
144                                        | (Y << 8) | Y);
145                            } else {
146                                throw new ImageReadException(
147                                        "Unsupported JPEG with " + mcu.length
148                                                + " components");
149                            }
150                        }
151                        srcRowOffset += hSize;
152                        dstRowOffset += sofnSegment.width;
153                    }
154                }
155            }
156            image = new BufferedImage(colorModel, raster,
157                    colorModel.isAlphaPremultiplied(), new Properties());
158            // byte[] remainder = super.getStreamBytes(is);
159            // for (int i = 0; i < remainder.length; i++)
160            // {
161            // System.out.println("" + i + " = " +
162            // Integer.toHexString(remainder[i]));
163            // }
164        } catch (final ImageReadException imageReadEx) {
165            imageReadException = imageReadEx;
166        } catch (final IOException ioEx) {
167            ioException = ioEx;
168        } catch (final RuntimeException ex) {
169            // Corrupt images can throw NPE and IOOBE
170            imageReadException = new ImageReadException("Error parsing JPEG",
171                    ex);
172        }
173    }
174
175    @Override
176    public boolean visitSegment(final int marker, final byte[] markerBytes,
177            final int segmentLength, final byte[] segmentLengthBytes, final byte[] segmentData)
178            throws ImageReadException, IOException {
179        final int[] sofnSegments = {
180                JpegConstants.SOF0_MARKER,
181                JpegConstants.SOF1_MARKER,
182                JpegConstants.SOF2_MARKER,
183                JpegConstants.SOF3_MARKER,
184                JpegConstants.SOF5_MARKER,
185                JpegConstants.SOF6_MARKER,
186                JpegConstants.SOF7_MARKER,
187                JpegConstants.SOF9_MARKER,
188                JpegConstants.SOF10_MARKER,
189                JpegConstants.SOF11_MARKER,
190                JpegConstants.SOF13_MARKER,
191                JpegConstants.SOF14_MARKER,
192                JpegConstants.SOF15_MARKER,
193        };
194
195        if (Arrays.binarySearch(sofnSegments, marker) >= 0) {
196            if (marker != JpegConstants.SOF0_MARKER) {
197                throw new ImageReadException("Only sequential, baseline JPEGs "
198                        + "are supported at the moment");
199            }
200            sofnSegment = new SofnSegment(marker, segmentData);
201        } else if (marker == JpegConstants.DQT_MARKER) {
202            final DqtSegment dqtSegment = new DqtSegment(marker, segmentData);
203            for (int i = 0; i < dqtSegment.quantizationTables.size(); i++) {
204                final DqtSegment.QuantizationTable table = dqtSegment.quantizationTables.get(i);
205                if (0 > table.destinationIdentifier
206                        || table.destinationIdentifier >= quantizationTables.length) {
207                    throw new ImageReadException(
208                            "Invalid quantization table identifier "
209                                    + table.destinationIdentifier);
210                }
211                quantizationTables[table.destinationIdentifier] = table;
212                final int[] quantizationMatrixInt = new int[64];
213                ZigZag.zigZagToBlock(table.getElements(), quantizationMatrixInt);
214                final float[] quantizationMatrixFloat = new float[64];
215                for (int j = 0; j < 64; j++) {
216                    quantizationMatrixFloat[j] = quantizationMatrixInt[j];
217                }
218                Dct.scaleDequantizationMatrix(quantizationMatrixFloat);
219                scaledQuantizationTables[table.destinationIdentifier] = quantizationMatrixFloat;
220            }
221        } else if (marker == JpegConstants.DHT_MARKER) {
222            final DhtSegment dhtSegment = new DhtSegment(marker, segmentData);
223            for (int i = 0; i < dhtSegment.huffmanTables.size(); i++) {
224                final DhtSegment.HuffmanTable table = dhtSegment.huffmanTables.get(i);
225                DhtSegment.HuffmanTable[] tables;
226                if (table.tableClass == 0) {
227                    tables = huffmanDCTables;
228                } else if (table.tableClass == 1) {
229                    tables = huffmanACTables;
230                } else {
231                    throw new ImageReadException("Invalid huffman table class "
232                            + table.tableClass);
233                }
234                if (0 > table.destinationIdentifier
235                        || table.destinationIdentifier >= tables.length) {
236                    throw new ImageReadException(
237                            "Invalid huffman table identifier "
238                                    + table.destinationIdentifier);
239                }
240                tables[table.destinationIdentifier] = table;
241            }
242        }
243        return true;
244    }
245
246    private void rescaleMCU(final Block[] dataUnits, final int hSize, final int vSize, final Block[] ret) {
247        for (int i = 0; i < dataUnits.length; i++) {
248            final Block dataUnit = dataUnits[i];
249            if (dataUnit.width == hSize && dataUnit.height == vSize) {
250                System.arraycopy(dataUnit.samples, 0, ret[i].samples, 0, hSize
251                        * vSize);
252            } else {
253                final int hScale = hSize / dataUnit.width;
254                final int vScale = vSize / dataUnit.height;
255                if (hScale == 2 && vScale == 2) {
256                    int srcRowOffset = 0;
257                    int dstRowOffset = 0;
258                    for (int y = 0; y < dataUnit.height; y++) {
259                        for (int x = 0; x < hSize; x++) {
260                            final int sample = dataUnit.samples[srcRowOffset + (x >> 1)];
261                            ret[i].samples[dstRowOffset + x] = sample;
262                            ret[i].samples[dstRowOffset + hSize + x] = sample;
263                        }
264                        srcRowOffset += dataUnit.width;
265                        dstRowOffset += 2 * hSize;
266                    }
267                } else {
268                    // FIXME: optimize
269                    int dstRowOffset = 0;
270                    for (int y = 0; y < vSize; y++) {
271                        for (int x = 0; x < hSize; x++) {
272                            ret[i].samples[dstRowOffset + x] = dataUnit.samples[(y / vScale)
273                                    * dataUnit.width + (x / hScale)];
274                        }
275                        dstRowOffset += hSize;
276                    }
277                }
278            }
279        }
280    }
281
282    private Block[] allocateMCUMemory() throws ImageReadException {
283        final Block[] mcu = new Block[sosSegment.numberOfComponents];
284        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
285            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
286            SofnSegment.Component frameComponent = null;
287            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
288                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
289                    frameComponent = sofnSegment.getComponents(j);
290                    break;
291                }
292            }
293            if (frameComponent == null) {
294                throw new ImageReadException("Invalid component");
295            }
296            final Block fullBlock = new Block(
297                    8 * frameComponent.horizontalSamplingFactor,
298                    8 * frameComponent.verticalSamplingFactor);
299            mcu[i] = fullBlock;
300        }
301        return mcu;
302    }
303
304    private void readMCU(final JpegInputStream is, final int[] preds, final Block[] mcu)
305            throws IOException, ImageReadException {
306        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
307            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
308            SofnSegment.Component frameComponent = null;
309            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
310                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
311                    frameComponent = sofnSegment.getComponents(j);
312                    break;
313                }
314            }
315            if (frameComponent == null) {
316                throw new ImageReadException("Invalid component");
317            }
318            final Block fullBlock = mcu[i];
319            for (int y = 0; y < frameComponent.verticalSamplingFactor; y++) {
320                for (int x = 0; x < frameComponent.horizontalSamplingFactor; x++) {
321                    Arrays.fill(zz, 0);
322                    // page 104 of T.81
323                    final int t = decode(
324                            is,
325                            huffmanDCTables[scanComponent.dcCodingTableSelector]);
326                    int diff = receive(t, is);
327                    diff = extend(diff, t);
328                    zz[0] = preds[i] + diff;
329                    preds[i] = zz[0];
330
331                    // "Decode_AC_coefficients", figure F.13, page 106 of T.81
332                    int k = 1;
333                    while (true) {
334                        final int rs = decode(
335                                is,
336                                huffmanACTables[scanComponent.acCodingTableSelector]);
337                        final int ssss = rs & 0xf;
338                        final int rrrr = rs >> 4;
339                        final int r = rrrr;
340
341                        if (ssss == 0) {
342                            if (r == 15) {
343                                k += 16;
344                            } else {
345                                break;
346                            }
347                        } else {
348                            k += r;
349
350                            // "Decode_ZZ(k)", figure F.14, page 107 of T.81
351                            zz[k] = receive(ssss, is);
352                            zz[k] = extend(zz[k], ssss);
353
354                            if (k == 63) {
355                                break;
356                            } else {
357                                k++;
358                            }
359                        }
360                    }
361
362                    final int shift = (1 << (sofnSegment.precision - 1));
363                    final int max = (1 << sofnSegment.precision) - 1;
364
365                    final float[] scaledQuantizationTable = scaledQuantizationTables[frameComponent.quantTabDestSelector];
366                    ZigZag.zigZagToBlock(zz, blockInt);
367                    for (int j = 0; j < 64; j++) {
368                        block[j] = blockInt[j] * scaledQuantizationTable[j];
369                    }
370                    Dct.inverseDCT8x8(block);
371
372                    int dstRowOffset = 8 * y * 8
373                            * frameComponent.horizontalSamplingFactor + 8 * x;
374                    int srcNext = 0;
375                    for (int yy = 0; yy < 8; yy++) {
376                        for (int xx = 0; xx < 8; xx++) {
377                            float sample = block[srcNext++];
378                            sample += shift;
379                            int result;
380                            if (sample < 0) {
381                                result = 0;
382                            } else if (sample > max) {
383                                result = max;
384                            } else {
385                                result = fastRound(sample);
386                            }
387                            fullBlock.samples[dstRowOffset + xx] = result;
388                        }
389                        dstRowOffset += 8 * frameComponent.horizontalSamplingFactor;
390                    }
391                }
392            }
393        }
394    }
395
396    private static int fastRound(final float x) {
397        return (int) (x + 0.5f);
398    }
399
400    private int extend(int v, final int t) {
401        // "EXTEND", section F.2.2.1, figure F.12, page 105 of T.81
402        int vt = (1 << (t - 1));
403        if (v < vt) {
404            vt = (-1 << t) + 1;
405            v += vt;
406        }
407        return v;
408    }
409
410    private int receive(final int ssss, final JpegInputStream is) throws IOException,
411            ImageReadException {
412        // "RECEIVE", section F.2.2.4, figure F.17, page 110 of T.81
413        int i = 0;
414        int v = 0;
415        while (i != ssss) {
416            i++;
417            v = (v << 1) + is.nextBit();
418        }
419        return v;
420    }
421
422    private int decode(final JpegInputStream is, final DhtSegment.HuffmanTable huffmanTable)
423            throws IOException, ImageReadException {
424        // "DECODE", section F.2.2.3, figure F.16, page 109 of T.81
425        int i = 1;
426        int code = is.nextBit();
427        while (code > huffmanTable.getMaxCode(i)) {
428            i++;
429            code = (code << 1) | is.nextBit();
430        }
431        int j = huffmanTable.getValPtr(i);
432        j += code - huffmanTable.getMinCode(i);
433        return huffmanTable.getHuffVal(j);
434    }
435
436    public BufferedImage decode(final ByteSource byteSource) throws IOException,
437            ImageReadException {
438        final JpegUtils jpegUtils = new JpegUtils();
439        jpegUtils.traverseJFIF(byteSource, this);
440        if (imageReadException != null) {
441            throw imageReadException;
442        }
443        if (ioException != null) {
444            throw ioException;
445        }
446        return image;
447    }
448}