001/**
002 *
003 * Copyright 2018 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.fsm;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.Iterator;
022import java.util.List;
023import java.util.ListIterator;
024import java.util.concurrent.CopyOnWriteArrayList;
025import java.util.logging.Logger;
026
027import javax.net.ssl.SSLSession;
028
029import org.jivesoftware.smack.AbstractXMPPConnection;
030import org.jivesoftware.smack.ConnectionConfiguration;
031import org.jivesoftware.smack.SmackException;
032import org.jivesoftware.smack.SmackException.NoResponseException;
033import org.jivesoftware.smack.SmackException.NotConnectedException;
034import org.jivesoftware.smack.XMPPException;
035import org.jivesoftware.smack.XMPPException.FailedNonzaException;
036import org.jivesoftware.smack.XMPPException.StreamErrorException;
037import org.jivesoftware.smack.XMPPException.XMPPErrorException;
038import org.jivesoftware.smack.XmppInputOutputFilter;
039import org.jivesoftware.smack.compress.packet.Compress;
040import org.jivesoftware.smack.compress.packet.Compressed;
041import org.jivesoftware.smack.compress.packet.Failure;
042import org.jivesoftware.smack.compression.XmppCompressionFactory;
043import org.jivesoftware.smack.compression.XmppCompressionManager;
044import org.jivesoftware.smack.fsm.StateDescriptorGraph.GraphVertex;
045import org.jivesoftware.smack.packet.IQ;
046import org.jivesoftware.smack.packet.Message;
047import org.jivesoftware.smack.packet.Presence;
048import org.jivesoftware.smack.packet.StreamError;
049import org.jivesoftware.smack.sasl.SASLErrorException;
050import org.jivesoftware.smack.sasl.packet.SaslStreamElements.Challenge;
051import org.jivesoftware.smack.sasl.packet.SaslStreamElements.Success;
052import org.jivesoftware.smack.util.Objects;
053import org.jivesoftware.smack.util.PacketParserUtils;
054
055import org.jxmpp.jid.parts.Resourcepart;
056import org.xmlpull.v1.XmlPullParser;
057
058public abstract class AbstractXmppStateMachineConnection extends AbstractXMPPConnection {
059
060    private static final Logger LOGGER = Logger.getLogger(AbstractXmppStateMachineConnection.class.getName());
061
062    private boolean featuresReceived;
063
064    protected boolean streamResumed;
065
066    private GraphVertex<State> currentStateVertex;
067
068    private final List<XmppInputOutputFilter> inputOutputFilters = new CopyOnWriteArrayList<>();
069
070    protected AbstractXmppStateMachineConnection(ConnectionConfiguration configuration, GraphVertex<StateDescriptor> initialStateDescriptorVertex) {
071        super(configuration);
072        currentStateVertex = StateDescriptorGraph.convertToStateGraph(initialStateDescriptorVertex, this);
073    }
074
075    @Override
076    protected void loginInternal(String username, String password, Resourcepart resource)
077                    throws XMPPException, SmackException, IOException, InterruptedException {
078        WalkStateGraphContext walkStateGraphContext = buildNewWalkTo(AuthenticatedAndResourceBoundStateDescriptor.class)
079                        .withLoginContext(username, password, resource)
080                        .build();
081        walkStateGraph(walkStateGraphContext);
082    }
083
084    protected static WalkStateGraphContextBuilder buildNewWalkTo(Class<? extends StateDescriptor> finalStateClass) {
085        return new WalkStateGraphContextBuilder(finalStateClass);
086    }
087
088    protected static final class WalkStateGraphContext {
089        private final Class<? extends StateDescriptor> finalStateClass;
090        private final Class<? extends StateDescriptor> mandatoryIntermedidateState;
091        private final LoginContext loginContext;
092
093        private final List<StateDescriptor> walkedStateGraphPath = new ArrayList<>();
094
095        private boolean mandatoryIntermediateStateHandled;
096
097        private WalkStateGraphContext(Class<? extends StateDescriptor> finalStateClass, Class<? extends StateDescriptor> mandatoryIntermedidateState, LoginContext loginContext) {
098            this.finalStateClass = Objects.requireNonNull(finalStateClass);
099            this.mandatoryIntermedidateState = mandatoryIntermedidateState;
100            this.loginContext = loginContext;
101        }
102
103
104    }
105
106    protected static final class WalkStateGraphContextBuilder {
107        private final Class<? extends StateDescriptor> finalStateClass;
108        private Class<? extends StateDescriptor> mandatoryIntermedidateState;
109        private LoginContext loginContext;
110
111        private WalkStateGraphContextBuilder(Class<? extends StateDescriptor> finalStateClass) {
112            this.finalStateClass = finalStateClass;
113        }
114
115        public WalkStateGraphContextBuilder withMandatoryIntermediateState(Class<? extends StateDescriptor> mandatoryIntermedidateState) {
116            this.mandatoryIntermedidateState = mandatoryIntermedidateState;
117            return this;
118        }
119
120        public WalkStateGraphContextBuilder withLoginContext(String username, String password, Resourcepart resource) {
121            LoginContext loginContext = new LoginContext(username, password, resource);
122            return withLoginContext(loginContext);
123        }
124
125        public WalkStateGraphContextBuilder withLoginContext(LoginContext loginContext) {
126            this.loginContext = loginContext;
127            return this;
128        }
129
130        public WalkStateGraphContext build() {
131            return new WalkStateGraphContext(finalStateClass, mandatoryIntermedidateState, loginContext);
132        }
133    }
134
135    private TransitionReason attemptEnterState(GraphVertex<State> successorStateVertex,
136                    WalkStateGraphContext walkStateGraphContext)
137                    throws SmackException, XMPPErrorException, SASLErrorException, IOException, InterruptedException, FailedNonzaException {
138        State successorState = successorStateVertex.getElement();
139        TransitionImpossibleReason transitionImpossibleReason = successorState.isTransitionToPossible();
140        if (transitionImpossibleReason != null) {
141            return transitionImpossibleReason;
142        }
143
144        // TODO: Emit a signal that we are about to transition into successorState from currentState
145        TransitionFailedReason transitionFailedReason = successorState.transitionInto(
146                        walkStateGraphContext.loginContext);
147        if (transitionFailedReason != null) {
148            return transitionFailedReason;
149        }
150
151        currentStateVertex = successorStateVertex;
152        // TODO: Emit a signal that the state was changed.
153        return null;
154    }
155
156    protected void walkStateGraph(WalkStateGraphContext walkStateGraphContext)
157                    throws XMPPErrorException, SASLErrorException, IOException, SmackException, InterruptedException, FailedNonzaException {
158        State currentState = currentStateVertex.getElement();
159        StateDescriptor currentStateDescriptor = currentState.getStateDescriptor();
160
161        walkStateGraphContext.walkedStateGraphPath.add(currentStateDescriptor);
162
163        if (currentStateDescriptor.getClass() == walkStateGraphContext.finalStateClass) {
164            // We reached the final state.
165            return;
166        }
167
168        List<GraphVertex<State>> outgoingStateEdges = currentStateVertex.getOutgoingEdges();
169        if (walkStateGraphContext.mandatoryIntermedidateState != null && !walkStateGraphContext.mandatoryIntermediateStateHandled) {
170            // Check if outgoingStateEdges contains the mandatory intermediate state.
171            GraphVertex<State> mandatoryIntermediateStateVertex = null;
172            for (GraphVertex<State> outgoingStateVertex : outgoingStateEdges) {
173                if (outgoingStateVertex.getElement().getStateDescriptor().getClass() == walkStateGraphContext.mandatoryIntermedidateState) {
174                    mandatoryIntermediateStateVertex = outgoingStateVertex;
175                    break;
176                }
177            }
178
179            if (mandatoryIntermediateStateVertex != null) {
180                walkStateGraphContext.mandatoryIntermediateStateHandled = true;
181                TransitionReason reason = attemptEnterState(mandatoryIntermediateStateVertex, walkStateGraphContext);
182                if (reason != null) {
183                    throw new IllegalStateException();
184                }
185
186                walkStateGraph(walkStateGraphContext);
187                return;
188            }
189        }
190
191        Iterator<GraphVertex<State>> iterator = outgoingStateEdges.iterator();
192        GraphVertex<State> successorStateVertex;
193        while (true) {
194            successorStateVertex = iterator.next();
195            State successorState = successorStateVertex.getElement();
196            TransitionReason reason = attemptEnterState(successorStateVertex, walkStateGraphContext);
197            if (reason == null) {
198                break;
199            } else if (reason instanceof TransitionImpossibleReason) {
200                // TODO: Handle transition impossible reason.
201                LOGGER.severe("Transition into " + successorState + " not possible " + reason);
202            } else if (reason instanceof TransitionFailedReason) {
203                // TODO: Handle transition failed reason.
204                LOGGER.severe("Transition into " + successorState + " failed " + reason);
205            } else {
206                throw new AssertionError();
207            }
208
209
210            if (!iterator.hasNext()) {
211                throw new IllegalStateException("We don't know where to go from here");
212            }
213        }
214
215        // Walk the state graph by recursion.
216        walkStateGraph(walkStateGraphContext);
217    }
218
219    protected abstract SSLSession getSSLSession();
220
221    @Override
222    protected void afterFeaturesReceived() {
223        featuresReceived = true;
224        synchronized (this) {
225            notifyAll();
226        }
227    }
228
229    protected final void parseAndProcessElement(String element) throws Exception {
230        XmlPullParser parser = PacketParserUtils.getParserFor(element);
231
232        // Skip the enclosing stream open what is guaranteed to be there.
233        parser.next();
234
235        int event = parser.getEventType();
236        outerloop: while (true) {
237            switch (event) {
238            case XmlPullParser.START_TAG:
239                final String name = parser.getName();
240                // Note that we don't handle "stream" here as it's done in the splitter.
241                switch (name) {
242                case Message.ELEMENT:
243                case IQ.IQ_ELEMENT:
244                case Presence.ELEMENT:
245                    try {
246                        parseAndProcessStanza(parser);
247                    } finally {
248                        // TODO stream management code
249                        // clientHandledStanzasCount = SMUtils.incrementHeight(clientHandledStanzasCount);
250                    }
251                    break;
252                case "error":
253                    StreamError streamError = PacketParserUtils.parseStreamError(parser);
254                    saslFeatureReceived.reportFailure(new StreamErrorException(streamError));
255                    throw new StreamErrorException(streamError);
256                case "features":
257                    parseFeatures(parser);
258                    afterFeaturesReceived();
259                    break;
260                // SASL related top level stream elements
261                case Challenge.ELEMENT:
262                    // The server is challenging the SASL authentication made by the client
263                    String challengeData = parser.nextText();
264                    getSASLAuthentication().challengeReceived(challengeData);
265                    break;
266                case Success.ELEMENT:
267                    Success success = new Success(parser.nextText());
268                    // The SASL authentication with the server was successful. The next step
269                    // will be to bind the resource
270                    getSASLAuthentication().authenticated(success);
271                    sendStreamOpen();
272                    break;
273                default:
274                    parseAndProcessNonza(parser);
275                    break;
276                }
277                break;
278            case XmlPullParser.END_DOCUMENT:
279                break outerloop;
280            }
281            event = parser.next();
282        }
283    }
284
285    protected synchronized void prepareToWaitForFeaturesReceived() {
286        featuresReceived = false;
287    }
288
289    protected void waitForFeaturesReceived(String waitFor) throws InterruptedException, NoResponseException {
290        long waitStartMs = System.currentTimeMillis();
291        long timeoutMs = getReplyTimeout();
292        synchronized (this) {
293            while (!featuresReceived) {
294                long remainingWaitMs = timeoutMs - (System.currentTimeMillis() - waitStartMs);
295                if (remainingWaitMs <= 0) {
296                    throw NoResponseException.newWith(this, waitFor);
297                }
298                wait(remainingWaitMs);
299            }
300        }
301    }
302
303    protected void newStreamOpenWaitForFeaturesSequence(String waitFor) throws NoResponseException, InterruptedException, NotConnectedException {
304        prepareToWaitForFeaturesReceived();
305        sendStreamOpen();
306        waitForFeaturesReceived(waitFor);
307    }
308
309    protected final void addXmppInputOutputFilter(XmppInputOutputFilter xmppInputOutputFilter) {
310        inputOutputFilters.add(0, xmppInputOutputFilter);
311    }
312
313    protected final ListIterator<XmppInputOutputFilter> getXmppInputOutputFilterBeginIterator() {
314        return inputOutputFilters.listIterator();
315    }
316
317    protected final ListIterator<XmppInputOutputFilter> getXmppInputOutputFilterEndIterator() {
318        return inputOutputFilters.listIterator(inputOutputFilters.size());
319    }
320
321    public abstract class State {
322        private final StateDescriptor stateDescriptor;
323
324        protected State(StateDescriptor stateDescriptor) {
325            this.stateDescriptor = stateDescriptor;
326        }
327
328        /**
329         * Check if the state should be activated.
330         *
331         * @return <code>null</code> if the state should be activated.
332         * @throws SmackException in case a Smack exception occurs.
333         */
334        protected TransitionImpossibleReason isTransitionToPossible() throws SmackException {
335            return null;
336        }
337
338        protected abstract TransitionFailedReason transitionInto(LoginContext loginContext)
339                        throws XMPPErrorException, SASLErrorException, IOException, SmackException, InterruptedException, FailedNonzaException;
340
341        StateDescriptor getStateDescriptor() {
342            return stateDescriptor;
343        }
344
345        // TODO: Call this if we reach disconnect.
346        protected void resetState() {
347        }
348
349        @Override
350        public String toString() {
351            return "State " + stateDescriptor + ' ' + AbstractXmppStateMachineConnection.this;
352        }
353    }
354
355    private abstract static class TransitionReason {
356        public final String reason;
357        private TransitionReason(String reason) {
358            this.reason = reason;
359        }
360
361        @Override
362        public final String toString() {
363            return reason;
364        }
365    }
366
367    public static class TransitionImpossibleReason extends TransitionReason {
368        public TransitionImpossibleReason(String reason) {
369            super(reason);
370        }
371    }
372
373    public static class TransitionFailedReason extends TransitionReason {
374        public TransitionFailedReason(String reason) {
375            super(reason);
376        }
377    }
378
379    protected final class NoOpState extends State {
380
381        private NoOpState(StateDescriptor stateDescriptor) {
382            super(stateDescriptor);
383        }
384
385        @Override
386        protected TransitionImpossibleReason isTransitionToPossible() {
387            // Transition into a NoOpState is always possible.
388            return null;
389        }
390
391        @Override
392        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
393                        SASLErrorException, IOException, SmackException, InterruptedException {
394            // Transition into a NoOpState always succeeds.
395            return null;
396        }
397    }
398
399    public static class DisconnectedStateDescriptor extends StateDescriptor {
400        protected DisconnectedStateDescriptor() {
401            super(DisconnectedState.class);
402        }
403    }
404
405    public class DisconnectedState extends State {
406
407        protected DisconnectedState(StateDescriptor stateDescriptor) {
408            super(stateDescriptor);
409        }
410
411        @Override
412        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
413                        SASLErrorException, IOException, SmackException, InterruptedException, FailedNonzaException {
414            inputOutputFilters.clear();
415            return null;
416        }
417    }
418
419    public static class ConnectedButUnauthenticatedStateDescriptor extends StateDescriptor {
420        public ConnectedButUnauthenticatedStateDescriptor() {
421            addSuccessor(SaslAuthenticationStateDescriptor.class);
422        }
423    }
424
425    public static class SaslAuthenticationStateDescriptor extends StateDescriptor {
426        public SaslAuthenticationStateDescriptor() {
427            super(SaslAuthenticationState.class, "RFC 6120 § 6");
428            addSuccessor(AuthenticatedButUnboundStateDescriptor.class);
429        }
430    }
431
432    private final class SaslAuthenticationState extends State {
433        protected SaslAuthenticationState(StateDescriptor stateDescriptor) {
434            super(stateDescriptor);
435        }
436
437        @Override
438        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
439                        SASLErrorException, IOException, SmackException, InterruptedException {
440            prepareToWaitForFeaturesReceived();
441
442            saslAuthentication.authenticate(loginContext.username, loginContext.password, config.getAuthzid(), getSSLSession());
443            // authenticate() will only return if the SASL authentication was successful, but we also need to wait for the next round of stream features.
444
445            waitForFeaturesReceived("server stream features after SASL authentication");
446            return null;
447        }
448    }
449
450    public static class AuthenticatedButUnboundStateDescriptor extends StateDescriptor {
451        public AuthenticatedButUnboundStateDescriptor() {
452            addSuccessor(ResourceBindingStateDescriptor.class);
453            addSuccessor(CompressionStateDescriptor.class);
454        }
455    }
456
457    public static class ResourceBindingStateDescriptor extends StateDescriptor {
458        public ResourceBindingStateDescriptor() {
459            super(ResourceBindingState.class, "RFC 6120 § 7");
460            addSuccessor(AuthenticatedAndResourceBoundStateDescriptor.class);
461        }
462    }
463
464    private class ResourceBindingState extends State {
465        protected ResourceBindingState(StateDescriptor stateDescriptor) {
466            super(stateDescriptor);
467        }
468
469        @Override
470        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
471                        SASLErrorException, IOException, SmackException, InterruptedException {
472            // TODO: The reportSuccess() is just a quick fix until there is a variant of the
473            // bindResourceAndEstablishSession() method which does not require this.
474            lastFeaturesReceived.reportSuccess();
475            bindResourceAndEstablishSession(loginContext.resource);
476            streamResumed = false;
477            return null;
478        }
479    }
480
481    public static class CompressionStateDescriptor extends StateDescriptor {
482        public CompressionStateDescriptor() {
483            super(CompressionState.class, 138);
484            addSuccessor(AuthenticatedButUnboundStateDescriptor.class);
485            declarePrecedenceOver(ResourceBindingStateDescriptor.class);
486        }
487    }
488
489    private class CompressionState extends State {
490        private XmppCompressionFactory selectedCompressionFactory;
491        private XmppInputOutputFilter usedXmppInputOutputCompressionFitler;
492
493        protected CompressionState(StateDescriptor stateDescriptor) {
494            super(stateDescriptor);
495        }
496
497        @Override
498        protected TransitionImpossibleReason isTransitionToPossible() throws SmackException {
499            if (!config.isCompressionEnabled()) {
500                return new TransitionImpossibleReason("Stream compression disabled");
501            }
502
503            Compress.Feature compressFeature = getFeature(Compress.Feature.ELEMENT, Compress.NAMESPACE);
504            if (compressFeature == null) {
505                return new TransitionImpossibleReason("Stream compression not supported");
506            }
507
508            selectedCompressionFactory = XmppCompressionManager.getBestFactory(compressFeature);
509            if (selectedCompressionFactory == null) {
510                return new TransitionImpossibleReason("No matching compression factory");
511            }
512
513            usedXmppInputOutputCompressionFitler = selectedCompressionFactory.fabricate(config);
514
515            return null;
516        }
517
518        @Override
519        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
520                        SASLErrorException, IOException, SmackException, InterruptedException, FailedNonzaException {
521            sendAndWaitForResponse(new Compress(selectedCompressionFactory.getCompressionMethod()), Compressed.class, Failure.class);
522
523            addXmppInputOutputFilter(usedXmppInputOutputCompressionFitler);
524
525            newStreamOpenWaitForFeaturesSequence("server stream features after compression enabled");
526
527            return null;
528        }
529
530        @Override
531        protected void resetState() {
532            selectedCompressionFactory = null;
533            usedXmppInputOutputCompressionFitler = null;
534        }
535    }
536
537    public static class AuthenticatedAndResourceBoundStateDescriptor extends StateDescriptor {
538        protected AuthenticatedAndResourceBoundStateDescriptor() {
539            super(AuthenticatedAndResourceBoundState.class);
540        }
541    }
542
543    private class AuthenticatedAndResourceBoundState extends State {
544        protected AuthenticatedAndResourceBoundState(StateDescriptor stateDescriptor) {
545            super(stateDescriptor);
546        }
547
548        @Override
549        protected TransitionFailedReason transitionInto(LoginContext loginContext) throws XMPPErrorException,
550                        SASLErrorException, IOException, SmackException, InterruptedException {
551            // TODO: This false indicates SM resumption.
552            afterSuccessfulLogin(streamResumed);
553            return null;
554        }
555    }
556}