001package co.codewizards.cloudstore.core.util;
002
003import static java.nio.charset.StandardCharsets.*;
004import static java.util.Objects.*;
005
/**
 * Standard
 * <a href="http://en.wikipedia.org/wiki/Base64">base64url</a>
 * (<a href="http://en.wikipedia.org/wiki/Base64#RFC_4648">RFC 4648</a>) encoding and decoding.
 * <p>
 * The difference to normal base64 is that base64url replaces '+' by '-' and '/' by '_' in order to make the
 * encoded string usable in URLs and file names without any escaping.
 * <p>
 * Encoded output never contains padding. Decoding accepts unpadded input and ignores every byte that is
 * not part of the base64url alphabet (for example '=', whitespace or NUL bytes).
 * <p>
 * This class is an adjusted version of {@code org.apache.commons.codec.binary.Base64}. Thanks to the Apache
 * Software Foundation!
 */
public class Base64Url {

    /**
     * Size of the reverse lookup table, which is indexed by the (non-negative) byte value of an
     * encoded character.
     */
    static final int BASELENGTH = 255;

    /**
     * Number of characters in the base64url alphabet.
     */
    static final int LOOKUPLENGTH = 64;

    /**
     * Number of bits in a byte. Not referenced inside this class; kept in case other classes in this
     * package use it.
     */
    static final int EIGHTBIT = 8;

    /**
     * Number of bits in two bytes. Not referenced inside this class; kept in case other classes in this
     * package use it.
     */
    static final int SIXTEENBIT = 16;

    /**
     * Number of bits in one group of three bytes. Not referenced inside this class; kept in case other
     * classes in this package use it.
     */
    static final int TWENTYFOURBITGROUP = 24;

    /**
     * Number of encoded characters produced by one complete group of three bytes.
     */
    static final int FOURBYTE = 4;

    /**
     * Mask for the sign bit of a byte. Not referenced inside this class; kept in case other classes in
     * this package use it.
     */
    static final int SIGN = -128;

    /**
     * Internal padding marker (NUL). It never appears in encoded output. {@link #isArrayByteBase64(byte[])}
     * and {@link #discardNonBase64(byte[])} tolerate it; {@link #decodeBase64(byte[])} ignores it.
     */
    static final byte PAD = 0;

    /** Maps the byte value of an encoded character to its 6-bit value; -1 marks "not in the alphabet". */
    private static final byte[] base64Alphabet = new byte[BASELENGTH];

    /** Maps a 6-bit value to the byte value of its encoded character. */
    private static final byte[] lookUpBase64Alphabet = new byte[LOOKUPLENGTH];

    // Populate the forward table first and derive the reverse table from it, so that both
    // tables can never disagree.
    static {
        int value = 0;
        for (char c = 'A'; c <= 'Z'; c++) {
            lookUpBase64Alphabet[value++] = (byte) c;
        }
        for (char c = 'a'; c <= 'z'; c++) {
            lookUpBase64Alphabet[value++] = (byte) c;
        }
        for (char c = '0'; c <= '9'; c++) {
            lookUpBase64Alphabet[value++] = (byte) c;
        }
        lookUpBase64Alphabet[value++] = (byte) '-';
        lookUpBase64Alphabet[value] = (byte) '_';

        for (int i = 0; i < BASELENGTH; i++) {
            base64Alphabet[i] = (byte) -1;
        }
        for (int i = 0; i < LOOKUPLENGTH; i++) {
            base64Alphabet[lookUpBase64Alphabet[i]] = (byte) i;
        }
    }

    /**
     * Tells whether the given byte is one of the 64 characters of the base64url alphabet.
     * Negative bytes (values of 0x80 and above) are never part of the alphabet; checking the sign
     * first also keeps the table access within bounds.
     */
    private static boolean isInAlphabet(final byte octect) {
        return octect >= 0 && base64Alphabet[octect] != -1;
    }

    /**
     * Tells whether the given byte is either the internal padding marker or part of the alphabet.
     */
    private static boolean isBase64(final byte octect) {
        return octect == PAD || isInAlphabet(octect);
    }

    /**
     * Tells whether the given byte is one of the whitespace characters removed by
     * {@link #discardWhitespace(byte[])}: space, line feed, carriage return or tab.
     */
    private static boolean isWhitespace(final byte octect) {
        return octect == (byte) ' '
                || octect == (byte) '\n'
                || octect == (byte) '\r'
                || octect == (byte) '\t';
    }

    /**
     * Returns a new array holding the first {@code length} bytes of {@code buffer}.
     */
    private static byte[] copyPrefix(final byte[] buffer, final int length) {
        final byte[] packedData = new byte[length];
        System.arraycopy(buffer, 0, packedData, 0, length);
        return packedData;
    }

    /**
     * Tests a given byte array to see if it contains
     * only valid characters within the Base64 alphabet.
     * Whitespace (space, tab, CR, LF) is skipped and the internal padding marker (NUL) is accepted.
     *
     * @param arrayOctect byte array to test
     * @return true if all bytes are valid characters in the Base64
     *         alphabet or if the byte array is empty; false, otherwise
     * @throws NullPointerException if {@code arrayOctect} is null
     */
    public static boolean isArrayByteBase64(final byte[] arrayOctect) {
        requireNonNull(arrayOctect, "arrayOctect");
        for (final byte octect : arrayOctect) {
            if (!isWhitespace(octect) && !isBase64(octect)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Encodes binary data using the base64url algorithm and returns the result as a string.
     *
     * @param binaryData Array containing binary data to encode.
     * @return Base64url-encoded data without padding; empty if {@code binaryData} is empty.
     * @throws NullPointerException if {@code binaryData} is null
     */
    public static String encodeBase64ToString(final byte[] binaryData) {
        requireNonNull(binaryData, "binaryData");
        return new String(encodeBase64(binaryData), US_ASCII);
    }

    /**
     * Encodes binary data using the base64url algorithm.
     *
     * @param binaryData Array containing binary data to encode.
     * @return Base64url-encoded data without padding; empty if {@code binaryData} is empty.
     * @throws NullPointerException if {@code binaryData} is null
     */
    public static byte[] encodeBase64(final byte[] binaryData) {
        requireNonNull(binaryData, "binaryData");

        final int numberTriplets = binaryData.length / 3;
        final int remainingBytes = binaryData.length % 3;

        // Every complete triplet yields 4 characters. A trailing single byte yields 2 and a trailing
        // pair yields 3, because no padding is emitted.
        final int encodedLength =
                numberTriplets * FOURBYTE + (remainingBytes == 0 ? 0 : remainingBytes + 1);
        final byte[] encodedData = new byte[encodedLength];

        int dataIndex = 0;
        int encodedIndex = 0;

        for (int i = 0; i < numberTriplets; i++) {
            // Masking with 0xff turns the signed byte into its unsigned value (0..255).
            final int b1 = binaryData[dataIndex++] & 0xff;
            final int b2 = binaryData[dataIndex++] & 0xff;
            final int b3 = binaryData[dataIndex++] & 0xff;

            encodedData[encodedIndex++] = lookUpBase64Alphabet[b1 >> 2];
            encodedData[encodedIndex++] = lookUpBase64Alphabet[((b1 & 0x03) << 4) | (b2 >> 4)];
            encodedData[encodedIndex++] = lookUpBase64Alphabet[((b2 & 0x0f) << 2) | (b3 >> 6)];
            encodedData[encodedIndex++] = lookUpBase64Alphabet[b3 & 0x3f];
        }

        if (remainingBytes == 1) {
            final int b1 = binaryData[dataIndex] & 0xff;

            encodedData[encodedIndex++] = lookUpBase64Alphabet[b1 >> 2];
            encodedData[encodedIndex] = lookUpBase64Alphabet[(b1 & 0x03) << 4];
        } else if (remainingBytes == 2) {
            final int b1 = binaryData[dataIndex] & 0xff;
            final int b2 = binaryData[dataIndex + 1] & 0xff;

            encodedData[encodedIndex++] = lookUpBase64Alphabet[b1 >> 2];
            encodedData[encodedIndex++] = lookUpBase64Alphabet[((b1 & 0x03) << 4) | (b2 >> 4)];
            encodedData[encodedIndex] = lookUpBase64Alphabet[(b2 & 0x0f) << 2];
        }

        return encodedData;
    }

    /**
     * Decodes a base64url string into octects.
     *
     * @param base64String String containing base64url data
     * @return Array containing decoded data.
     * @throws NullPointerException if {@code base64String} is null
     * @throws IllegalArgumentException if the number of alphabet characters cannot result from any encoding
     */
    public static byte[] decodeBase64FromString(final String base64String) {
        // Characters outside of US-ASCII are replaced by '?' here, which is not part of the
        // alphabet and therefore ignored by decodeBase64.
        final byte[] base64Data = requireNonNull(base64String, "base64String").getBytes(US_ASCII);
        return decodeBase64(base64Data);
    }

    /**
     * Decodes Base64 data into octects.
     * <p>
     * All bytes that are not part of the base64url alphabet are ignored, so padding is optional.
     *
     * @param base64Data Byte array containing Base64 data
     * @return Array containing decoded data; empty if the input contains no alphabet characters.
     * @throws NullPointerException if {@code base64Data} is null
     * @throws IllegalArgumentException if the number of alphabet characters modulo 4 is 1, which no
     *         encoding can produce
     */
    public static byte[] decodeBase64(final byte[] base64Data) {
        requireNonNull(base64Data, "base64Data");

        // Keep only real alphabet characters. The internal padding marker is dropped as well, so
        // that NUL bytes in the input can neither be decoded as data nor corrupt the size calculation.
        final byte[] groomedData = new byte[base64Data.length];
        int symbolCount = 0;
        for (final byte octect : base64Data) {
            if (isInAlphabet(octect)) {
                groomedData[symbolCount++] = octect;
            }
        }

        final int numberQuadruple = symbolCount / FOURBYTE;
        final int remainingSymbols = symbolCount % FOURBYTE;
        if (remainingSymbols == 1) {
            // A single trailing character carries only 6 bits, which is less than one byte.
            throw new IllegalArgumentException(
                    "Invalid base64url data: " + symbolCount + " alphabet characters cannot form a valid encoding");
        }

        // Every complete quadruple yields 3 bytes. Two trailing characters yield 1 byte and three
        // trailing characters yield 2 bytes.
        final byte[] decodedData =
                new byte[numberQuadruple * 3 + (remainingSymbols == 0 ? 0 : remainingSymbols - 1)];

        int encodedIndex = 0;
        int dataIndex = 0;

        for (int i = 0; i < numberQuadruple; i++) {
            final int b1 = base64Alphabet[groomedData[encodedIndex++]];
            final int b2 = base64Alphabet[groomedData[encodedIndex++]];
            final int b3 = base64Alphabet[groomedData[encodedIndex++]];
            final int b4 = base64Alphabet[groomedData[encodedIndex++]];

            decodedData[dataIndex++] = (byte) (b1 << 2 | b2 >> 4);
            decodedData[dataIndex++] = (byte) ((b2 & 0xf) << 4 | b3 >> 2);
            decodedData[dataIndex++] = (byte) (b3 << 6 | b4);
        }

        if (remainingSymbols >= 2) {
            final int b1 = base64Alphabet[groomedData[encodedIndex]];
            final int b2 = base64Alphabet[groomedData[encodedIndex + 1]];

            decodedData[dataIndex++] = (byte) (b1 << 2 | b2 >> 4);

            if (remainingSymbols == 3) {
                final int b3 = base64Alphabet[groomedData[encodedIndex + 2]];

                decodedData[dataIndex] = (byte) ((b2 & 0xf) << 4 | b3 >> 2);
            }
        }

        return decodedData;
    }

    /**
     * Discards any whitespace from a base-64 encoded block.
     *
     * @param data The base-64 encoded data to discard the whitespace
     * from.
     * @return A new array with the data, less whitespace (see RFC 2045).
     */
    static byte[] discardWhitespace(final byte[] data) {
        final byte[] groomedData = new byte[data.length];
        int bytesCopied = 0;

        for (final byte octect : data) {
            if (!isWhitespace(octect)) {
                groomedData[bytesCopied++] = octect;
            }
        }

        return copyPrefix(groomedData, bytesCopied);
    }

    /**
     * Discards any characters outside of the base64 alphabet, per
     * the requirements on page 25 of RFC 2045 - "Any characters
     * outside of the base64 alphabet are to be ignored in base64
     * encoded data."
     * <p>
     * The internal padding marker (NUL) is kept; bytes with a value of 0x80 or above are discarded.
     *
     * @param data The base-64 encoded data to groom
     * @return A new array with the data, less non-base64 characters (see RFC 2045).
     */
    static byte[] discardNonBase64(final byte[] data) {
        final byte[] groomedData = new byte[data.length];
        int bytesCopied = 0;

        for (final byte octect : data) {
            if (isBase64(octect)) {
                groomedData[bytesCopied++] = octect;
            }
        }

        return copyPrefix(groomedData, bytesCopied);
    }

}