001/**
002 *
003 * Copyright 2014-2017 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.sasl.core;
018
019import java.io.UnsupportedEncodingException;
020import java.security.InvalidKeyException;
021import java.security.SecureRandom;
022import java.util.Collections;
023import java.util.HashMap;
024import java.util.Map;
025import java.util.Random;
026
027import javax.security.auth.callback.CallbackHandler;
028
029import org.jivesoftware.smack.SmackException;
030import org.jivesoftware.smack.sasl.SASLMechanism;
031import org.jivesoftware.smack.util.ByteUtils;
032import org.jivesoftware.smack.util.SHA1;
033import org.jivesoftware.smack.util.StringUtils;
034import org.jivesoftware.smack.util.stringencoder.Base64;
035import org.jxmpp.util.cache.Cache;
036import org.jxmpp.util.cache.LruCache;
037
038public abstract class ScramMechanism extends SASLMechanism {
039
040    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
041    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
042    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
043    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
044
045    private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
046        @Override
047        protected SecureRandom initialValue() {
048            return new SecureRandom();
049        }
050    };
051
052    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
053
054    private final ScramHmac scramHmac;
055
056    protected ScramMechanism(ScramHmac scramHmac) {
057        this.scramHmac = scramHmac;
058    }
059
060    private enum State {
061        INITIAL,
062        AUTH_TEXT_SENT,
063        RESPONSE_SENT,
064        VALID_SERVER_RESPONSE,
065    }
066
067    /**
068     * The state of the this instance of SASL SCRAM-SHA1 authentication.
069     */
070    private State state = State.INITIAL;
071
072    /**
073     * The client's random ASCII which is used as nonce
074     */
075    private String clientRandomAscii;
076
077    private String clientFirstMessageBare;
078    private byte[] serverSignature;
079
080    @Override
081    protected void authenticateInternal(CallbackHandler cbh) throws SmackException {
082        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
083    }
084
085    @Override
086    protected byte[] getAuthenticationText() throws SmackException {
087        clientRandomAscii = getRandomAscii();
088        String saslPrepedAuthcId = saslPrep(authenticationId);
089        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
090        String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
091        state = State.AUTH_TEXT_SENT;
092        return toBytes(clientFirstMessage);
093    }
094
095    @Override
096    public String getName() {
097        String name = "SCRAM-" + scramHmac.getHmacName();
098        return name;
099    }
100
101    @Override
102    public void checkIfSuccessfulOrThrow() throws SmackException {
103        if (state != State.VALID_SERVER_RESPONSE) {
104            throw new SmackException("SCRAM-SHA1 is missing valid server response");
105        }
106    }
107
108    @Override
109    public boolean authzidSupported() {
110        return true;
111    }
112
113    @Override
114    protected byte[] evaluateChallenge(byte[] challenge) throws SmackException {
115        String challengeString;
116        try {
117            // TODO: Where is it specified that this is an UTF-8 encoded string?
118            challengeString = new String(challenge, StringUtils.UTF8);
119        }
120        catch (UnsupportedEncodingException e) {
121            throw new AssertionError(e);
122        }
123
124        switch (state) {
125        case AUTH_TEXT_SENT:
126            final String serverFirstMessage = challengeString;
127            Map<Character, String> attributes = parseAttributes(challengeString);
128
129            // Handle server random ASCII (nonce)
130            String rvalue = attributes.get('r');
131            if (rvalue == null) {
132                throw new SmackException("Server random ASCII is null");
133            }
134            if (rvalue.length() <= clientRandomAscii.length()) {
135                throw new SmackException("Server random ASCII is shorter then client random ASCII");
136            }
137            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
138            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
139                throw new SmackException("Received client random ASCII does not match client random ASCII");
140            }
141
142            // Handle iterations
143            int iterations;
144            String iterationsString = attributes.get('i');
145            if (iterationsString == null) {
146                throw new SmackException("Iterations attribute not set");
147            }
148            try {
149                iterations = Integer.parseInt(iterationsString);
150            }
151            catch (NumberFormatException e) {
152                throw new SmackException("Exception parsing iterations", e);
153            }
154
155            // Handle salt
156            String salt = attributes.get('s');
157            if (salt == null) {
158                throw new SmackException("SALT not send");
159            }
160
161            // Parsing and error checking is done, we can now begin to calculate the values
162
163            // First the client-final-message-without-proof
164            String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
165            String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;
166
167            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
168            // client-final-message-without-proof
169            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
170                            + clientFinalMessageWithoutProof);
171
172            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
173            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
174            // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
175            // mechanisms.
176            final String cacheKey = password + ',' + salt + ',' + getName();
177            byte[] serverKey, clientKey;
178            Keys keys = CACHE.lookup(cacheKey);
179            if (keys == null) {
180                // SaltedPassword := Hi(Normalize(password), salt, i)
181                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
182
183                // ServerKey := HMAC(SaltedPassword, "Server Key")
184                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
185
186                // ClientKey := HMAC(SaltedPassword, "Client Key")
187                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
188
189                keys = new Keys(clientKey, serverKey);
190                CACHE.put(cacheKey, keys);
191            }
192            else {
193                serverKey = keys.serverKey;
194                clientKey = keys.clientKey;
195            }
196
197            // ServerSignature := HMAC(ServerKey, AuthMessage)
198            serverSignature = hmac(serverKey, authMessage);
199
200            // StoredKey := H(ClientKey)
201            byte[] storedKey = SHA1.bytes(clientKey);
202
203            // ClientSignature := HMAC(StoredKey, AuthMessage)
204            byte[] clientSignature = hmac(storedKey, authMessage);
205
206            // ClientProof := ClientKey XOR ClientSignature
207            byte[] clientProof = new byte[clientKey.length];
208            for (int i = 0; i < clientProof.length; i++) {
209                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
210            }
211
212            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
213            state = State.RESPONSE_SENT;
214            return toBytes(clientFinalMessage);
215        case RESPONSE_SENT:
216            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
217            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
218                throw new SmackException("Server final message does not match calculated one");
219            }
220            state = State.VALID_SERVER_RESPONSE;
221            break;
222        default:
223            throw new SmackException("Invalid state");
224        }
225        return null;
226    }
227
228    private final String getGS2Header() {
229        String authzidPortion = "";
230        if (authorizationId != null) {
231            authzidPortion = "a=" + authorizationId;
232        }
233
234        String cbName = getChannelBindingName();
235        assert(StringUtils.isNotEmpty(cbName));
236
237        return cbName + ',' + authzidPortion + ",";
238    }
239
240    private final byte[] getCBindInput() throws SmackException {
241        byte[] cbindData = getChannelBindingData();
242        byte[] gs2Header = toBytes(getGS2Header());
243
244        if (cbindData == null) {
245            return gs2Header;
246        }
247
248        return ByteUtils.concact(gs2Header, cbindData);
249    }
250
251    protected String getChannelBindingName() {
252        if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
253            // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
254            // believe the server does not.
255            return "y";
256        }
257        return "n";
258    }
259
260    /**
261     * 
262     * @return the Channel Binding data.
263     * @throws SmackException
264     */
265    protected byte[] getChannelBindingData() throws SmackException {
266        return null;
267    }
268
269    private static Map<Character, String> parseAttributes(String string) throws SmackException {
270        if (string.length() == 0) {
271            return Collections.emptyMap();
272        }
273
274        String[] keyValuePairs = string.split(",");
275        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
276        for (String keyValuePair : keyValuePairs) {
277            if (keyValuePair.length() < 3) {
278                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
279            }
280            char key = keyValuePair.charAt(0);
281            if (keyValuePair.charAt(1) != '=') {
282                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
283            }
284            String value = keyValuePair.substring(2);
285            res.put(key, value);
286        }
287
288        return res;
289    }
290
291    /**
292     * Generate random ASCII.
293     * <p>
294     * This method is non-static and package-private for unit testing purposes.
295     * </p>
296     * @return A String of 32 random printable ASCII characters.
297     */
298    String getRandomAscii() {
299        int count = 0;
300        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
301        final Random random = SECURE_RANDOM.get();
302        while (count < RANDOM_ASCII_BYTE_COUNT) {
303            int r = random.nextInt(128);
304            char c = (char) r;
305            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
306            if (!isPrintableNonCommaAsciiChar(c)) {
307                continue;
308            }
309            randomAscii[count++] = c;
310        }
311        return new String(randomAscii);
312    }
313
314    private static boolean isPrintableNonCommaAsciiChar(char c) {
315        if (c == ',') {
316            return false;
317        }
318        // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
319        // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
320        // ensure that c is within [33, 126].
321        return c > 32 && c < 127;
322    }
323
324    /**
325     * Escapes usernames or passwords for SASL SCRAM-SHA1.
326     * <p>
327     * According to RFC 5802 § 5.1 'n:'
328     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
329     * </p>
330     *
331     * @param string
332     * @return the escaped string
333     */
334    private static String escape(String string) {
335        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
336        for (int i = 0; i < string.length(); i++) {
337            char c = string.charAt(i);
338            switch (c) {
339            case ',':
340                sb.append("=2C");
341                break;
342            case '=':
343                sb.append("=3D");
344                break;
345            default:
346                sb.append(c);
347                break;
348            }
349        }
350        return sb.toString();
351    }
352
353    /**
354     * RFC 5802 § 2.2 HMAC(key, str)
355     * 
356     * @param key
357     * @param str
358     * @return the HMAC-SHA1 value of the input.
359     * @throws SmackException 
360     */
361    private byte[] hmac(byte[] key, byte[] str) throws SmackException {
362        try {
363            return scramHmac.hmac(key, str);
364        }
365        catch (InvalidKeyException e) {
366            throw new SmackException(getName() + " Exception", e);
367        }
368    }
369
370    /**
371     * RFC 5802 § 2.2 Hi(str, salt, i)
372     * <p>
373     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
374     * (PRF) and with dkLen == output length of HMAC() == output length of H().
375     * </p>
376     * 
377     * @param normalizedPassword the normalized password.
378     * @param salt
379     * @param iterations
380     * @return the result of the Hi function.
381     * @throws SmackException 
382     */
383    private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackException {
384        byte[] key;
385        try {
386            // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
387            key = normalizedPassword.getBytes(StringUtils.UTF8);
388        }
389        catch (UnsupportedEncodingException e) {
390            throw new AssertionError();
391        }
392        // U1 := HMAC(str, salt + INT(1))
393        byte[] u = hmac(key, ByteUtils.concact(salt, ONE));
394        byte[] res = u.clone();
395        for (int i = 1; i < iterations; i++) {
396            u = hmac(key, u);
397            for (int j = 0; j < u.length; j++) {
398                res[j] ^= u[j];
399            }
400        }
401        return res;
402    }
403
404    private static class Keys {
405        private final byte[] clientKey;
406        private final byte[] serverKey;
407
408        public Keys(byte[] clientKey, byte[] serverKey) {
409            this.clientKey = clientKey;
410            this.serverKey = serverKey;
411        }
412    }
413}