001/** 002 * 003 * Copyright 2014-2017 Florian Schmaus 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.jivesoftware.smack.sasl.core; 018 019import java.io.UnsupportedEncodingException; 020import java.security.InvalidKeyException; 021import java.security.SecureRandom; 022import java.util.Collections; 023import java.util.HashMap; 024import java.util.Map; 025import java.util.Random; 026 027import javax.security.auth.callback.CallbackHandler; 028 029import org.jivesoftware.smack.SmackException; 030import org.jivesoftware.smack.sasl.SASLMechanism; 031import org.jivesoftware.smack.util.ByteUtils; 032import org.jivesoftware.smack.util.SHA1; 033import org.jivesoftware.smack.util.StringUtils; 034import org.jivesoftware.smack.util.stringencoder.Base64; 035import org.jxmpp.util.cache.Cache; 036import org.jxmpp.util.cache.LruCache; 037 038public abstract class ScramMechanism extends SASLMechanism { 039 040 private static final int RANDOM_ASCII_BYTE_COUNT = 32; 041 private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key"); 042 private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key"); 043 private static final byte[] ONE = new byte[] { 0, 0, 0, 1 }; 044 045 private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() { 046 @Override 047 protected SecureRandom initialValue() { 048 return new SecureRandom(); 049 } 050 }; 051 052 private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10); 053 054 private final ScramHmac scramHmac; 055 056 protected ScramMechanism(ScramHmac scramHmac) { 057 this.scramHmac = scramHmac; 058 } 059 060 private enum State { 061 INITIAL, 062 AUTH_TEXT_SENT, 063 RESPONSE_SENT, 064 VALID_SERVER_RESPONSE, 065 } 066 067 /** 068 * The state of the this instance of SASL SCRAM-SHA1 authentication. 069 */ 070 private State state = State.INITIAL; 071 072 /** 073 * The client's random ASCII which is used as nonce 074 */ 075 private String clientRandomAscii; 076 077 private String clientFirstMessageBare; 078 private byte[] serverSignature; 079 080 @Override 081 protected void authenticateInternal(CallbackHandler cbh) throws SmackException { 082 throw new UnsupportedOperationException("CallbackHandler not (yet) supported"); 083 } 084 085 @Override 086 protected byte[] getAuthenticationText() throws SmackException { 087 clientRandomAscii = getRandomAscii(); 088 String saslPrepedAuthcId = saslPrep(authenticationId); 089 clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii; 090 String clientFirstMessage = getGS2Header() + clientFirstMessageBare; 091 state = State.AUTH_TEXT_SENT; 092 return toBytes(clientFirstMessage); 093 } 094 095 @Override 096 public String getName() { 097 String name = "SCRAM-" + scramHmac.getHmacName(); 098 return name; 099 } 100 101 @Override 102 public void checkIfSuccessfulOrThrow() throws SmackException { 103 if (state != State.VALID_SERVER_RESPONSE) { 104 throw new SmackException("SCRAM-SHA1 is missing valid server response"); 105 } 106 } 107 108 @Override 109 public boolean authzidSupported() { 110 return true; 111 } 112 113 @Override 114 protected byte[] evaluateChallenge(byte[] challenge) throws SmackException { 115 String challengeString; 116 try { 117 // TODO: Where is it specified that this is an UTF-8 encoded string? 118 challengeString = new String(challenge, StringUtils.UTF8); 119 } 120 catch (UnsupportedEncodingException e) { 121 throw new AssertionError(e); 122 } 123 124 switch (state) { 125 case AUTH_TEXT_SENT: 126 final String serverFirstMessage = challengeString; 127 Map<Character, String> attributes = parseAttributes(challengeString); 128 129 // Handle server random ASCII (nonce) 130 String rvalue = attributes.get('r'); 131 if (rvalue == null) { 132 throw new SmackException("Server random ASCII is null"); 133 } 134 if (rvalue.length() <= clientRandomAscii.length()) { 135 throw new SmackException("Server random ASCII is shorter then client random ASCII"); 136 } 137 String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length()); 138 if (!receivedClientRandomAscii.equals(clientRandomAscii)) { 139 throw new SmackException("Received client random ASCII does not match client random ASCII"); 140 } 141 142 // Handle iterations 143 int iterations; 144 String iterationsString = attributes.get('i'); 145 if (iterationsString == null) { 146 throw new SmackException("Iterations attribute not set"); 147 } 148 try { 149 iterations = Integer.parseInt(iterationsString); 150 } 151 catch (NumberFormatException e) { 152 throw new SmackException("Exception parsing iterations", e); 153 } 154 155 // Handle salt 156 String salt = attributes.get('s'); 157 if (salt == null) { 158 throw new SmackException("SALT not send"); 159 } 160 161 // Parsing and error checking is done, we can now begin to calculate the values 162 163 // First the client-final-message-without-proof 164 String channelBinding = "c=" + Base64.encodeToString(getCBindInput()); 165 String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue; 166 167 // AuthMessage := client-first-message-bare + "," + server-first-message + "," + 168 // client-final-message-without-proof 169 byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ',' 170 + clientFinalMessageWithoutProof); 171 172 // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication … 173 // as it is likely that the server is going to advertise the same salt value upon reauthentication." 174 // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple 175 // mechanisms. 176 final String cacheKey = password + ',' + salt + ',' + getName(); 177 byte[] serverKey, clientKey; 178 Keys keys = CACHE.lookup(cacheKey); 179 if (keys == null) { 180 // SaltedPassword := Hi(Normalize(password), salt, i) 181 byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations); 182 183 // ServerKey := HMAC(SaltedPassword, "Server Key") 184 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES); 185 186 // ClientKey := HMAC(SaltedPassword, "Client Key") 187 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES); 188 189 keys = new Keys(clientKey, serverKey); 190 CACHE.put(cacheKey, keys); 191 } 192 else { 193 serverKey = keys.serverKey; 194 clientKey = keys.clientKey; 195 } 196 197 // ServerSignature := HMAC(ServerKey, AuthMessage) 198 serverSignature = hmac(serverKey, authMessage); 199 200 // StoredKey := H(ClientKey) 201 byte[] storedKey = SHA1.bytes(clientKey); 202 203 // ClientSignature := HMAC(StoredKey, AuthMessage) 204 byte[] clientSignature = hmac(storedKey, authMessage); 205 206 // ClientProof := ClientKey XOR ClientSignature 207 byte[] clientProof = new byte[clientKey.length]; 208 for (int i = 0; i < clientProof.length; i++) { 209 clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]); 210 } 211 212 String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof); 213 state = State.RESPONSE_SENT; 214 return toBytes(clientFinalMessage); 215 case RESPONSE_SENT: 216 String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature); 217 if (!clientCalculatedServerFinalMessage.equals(challengeString)) { 218 throw new SmackException("Server final message does not match calculated one"); 219 } 220 state = State.VALID_SERVER_RESPONSE; 221 break; 222 default: 223 throw new SmackException("Invalid state"); 224 } 225 return null; 226 } 227 228 private final String getGS2Header() { 229 String authzidPortion = ""; 230 if (authorizationId != null) { 231 authzidPortion = "a=" + authorizationId; 232 } 233 234 String cbName = getChannelBindingName(); 235 assert(StringUtils.isNotEmpty(cbName)); 236 237 return cbName + ',' + authzidPortion + ","; 238 } 239 240 private final byte[] getCBindInput() throws SmackException { 241 byte[] cbindData = getChannelBindingData(); 242 byte[] gs2Header = toBytes(getGS2Header()); 243 244 if (cbindData == null) { 245 return gs2Header; 246 } 247 248 return ByteUtils.concact(gs2Header, cbindData); 249 } 250 251 protected String getChannelBindingName() { 252 if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) { 253 // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we 254 // believe the server does not. 255 return "y"; 256 } 257 return "n"; 258 } 259 260 /** 261 * 262 * @return the Channel Binding data. 263 * @throws SmackException 264 */ 265 protected byte[] getChannelBindingData() throws SmackException { 266 return null; 267 } 268 269 private static Map<Character, String> parseAttributes(String string) throws SmackException { 270 if (string.length() == 0) { 271 return Collections.emptyMap(); 272 } 273 274 String[] keyValuePairs = string.split(","); 275 Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1); 276 for (String keyValuePair : keyValuePairs) { 277 if (keyValuePair.length() < 3) { 278 throw new SmackException("Invalid Key-Value pair: " + keyValuePair); 279 } 280 char key = keyValuePair.charAt(0); 281 if (keyValuePair.charAt(1) != '=') { 282 throw new SmackException("Invalid Key-Value pair: " + keyValuePair); 283 } 284 String value = keyValuePair.substring(2); 285 res.put(key, value); 286 } 287 288 return res; 289 } 290 291 /** 292 * Generate random ASCII. 293 * <p> 294 * This method is non-static and package-private for unit testing purposes. 295 * </p> 296 * @return A String of 32 random printable ASCII characters. 297 */ 298 String getRandomAscii() { 299 int count = 0; 300 char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT]; 301 final Random random = SECURE_RANDOM.get(); 302 while (count < RANDOM_ASCII_BYTE_COUNT) { 303 int r = random.nextInt(128); 304 char c = (char) r; 305 // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters 306 if (!isPrintableNonCommaAsciiChar(c)) { 307 continue; 308 } 309 randomAscii[count++] = c; 310 } 311 return new String(randomAscii); 312 } 313 314 private static boolean isPrintableNonCommaAsciiChar(char c) { 315 if (c == ',') { 316 return false; 317 } 318 // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126) 319 // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to 320 // ensure that c is within [33, 126]. 321 return c > 32 && c < 127; 322 } 323 324 /** 325 * Escapes usernames or passwords for SASL SCRAM-SHA1. 326 * <p> 327 * According to RFC 5802 § 5.1 'n:' 328 * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively." 329 * </p> 330 * 331 * @param string 332 * @return the escaped string 333 */ 334 private static String escape(String string) { 335 StringBuilder sb = new StringBuilder((int) (string.length() * 1.1)); 336 for (int i = 0; i < string.length(); i++) { 337 char c = string.charAt(i); 338 switch (c) { 339 case ',': 340 sb.append("=2C"); 341 break; 342 case '=': 343 sb.append("=3D"); 344 break; 345 default: 346 sb.append(c); 347 break; 348 } 349 } 350 return sb.toString(); 351 } 352 353 /** 354 * RFC 5802 § 2.2 HMAC(key, str) 355 * 356 * @param key 357 * @param str 358 * @return the HMAC-SHA1 value of the input. 359 * @throws SmackException 360 */ 361 private byte[] hmac(byte[] key, byte[] str) throws SmackException { 362 try { 363 return scramHmac.hmac(key, str); 364 } 365 catch (InvalidKeyException e) { 366 throw new SmackException(getName() + " Exception", e); 367 } 368 } 369 370 /** 371 * RFC 5802 § 2.2 Hi(str, salt, i) 372 * <p> 373 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function 374 * (PRF) and with dkLen == output length of HMAC() == output length of H(). 375 * </p> 376 * 377 * @param normalizedPassword the normalized password. 378 * @param salt 379 * @param iterations 380 * @return the result of the Hi function. 381 * @throws SmackException 382 */ 383 private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackException { 384 byte[] key; 385 try { 386 // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8. 387 key = normalizedPassword.getBytes(StringUtils.UTF8); 388 } 389 catch (UnsupportedEncodingException e) { 390 throw new AssertionError(); 391 } 392 // U1 := HMAC(str, salt + INT(1)) 393 byte[] u = hmac(key, ByteUtils.concact(salt, ONE)); 394 byte[] res = u.clone(); 395 for (int i = 1; i < iterations; i++) { 396 u = hmac(key, u); 397 for (int j = 0; j < u.length; j++) { 398 res[j] ^= u[j]; 399 } 400 } 401 return res; 402 } 403 404 private static class Keys { 405 private final byte[] clientKey; 406 private final byte[] serverKey; 407 408 public Keys(byte[] clientKey, byte[] serverKey) { 409 this.clientKey = clientKey; 410 this.serverKey = serverKey; 411 } 412 } 413}