Skip to content

Commit

Permalink
Lift restriction of having uniqueness of subscription registration (#…
Browse files Browse the repository at this point in the history
…1720)

* Lift restriction of having uniqueness of subscription registration

Remove checking if the subscription with the same arguments was already registered.
Address some synchronization concerns.

Closes  #1706

* Fix tests
  • Loading branch information
sav007 authored Nov 1, 2019
1 parent fc37c15 commit 5927e4c
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import com.apollographql.apollo.subscription.SubscriptionTransport;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
Expand All @@ -37,7 +37,7 @@ public final class RealSubscriptionManager implements SubscriptionManager {
static final long CONNECTION_ACKNOWLEDGE_TIMEOUT = TimeUnit.SECONDS.toMillis(5);
static final long INACTIVITY_TIMEOUT = TimeUnit.SECONDS.toMillis(10);

Map<String, SubscriptionRecord> subscriptions = new LinkedHashMap<>();
Map<UUID, SubscriptionRecord> subscriptions = new LinkedHashMap<>();
volatile SubscriptionManagerState state = SubscriptionManagerState.DISCONNECTED;
final AutoReleaseTimer timer = new AutoReleaseTimer();

Expand Down Expand Up @@ -82,8 +82,7 @@ public RealSubscriptionManager(@NotNull ScalarTypeAdapters scalarTypeAdapters,
}

@Override
public <T> void subscribe(@NotNull final Subscription<?, T, ?> subscription,
@NotNull final SubscriptionManager.Callback<T> callback) {
public <T> void subscribe(@NotNull final Subscription<?, T, ?> subscription, @NotNull final SubscriptionManager.Callback<T> callback) {
checkNotNull(subscription, "subscription == null");
checkNotNull(callback, "callback == null");
dispatcher.execute(new Runnable() {
Expand Down Expand Up @@ -111,11 +110,15 @@ public void run() {
*/
@Override
public void start() {
final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;
if (state == SubscriptionManagerState.STOPPED) {
setStateAndNotify(SubscriptionManagerState.DISCONNECTED);
state = SubscriptionManagerState.DISCONNECTED;
}
}

notifyStateChanged(oldState, state);
}

/**
Expand All @@ -127,14 +130,11 @@ public void start() {
*/
@Override
public void stop() {
synchronized (this) {
setStateAndNotify(SubscriptionManagerState.STOPPING);
ArrayList<SubscriptionRecord> values = new ArrayList<>(subscriptions.values());
for (SubscriptionRecord subscription : values) {
doUnsubscribe(subscription.subscription);
dispatcher.execute(new Runnable() {
@Override public void run() {
doStop();
}
disconnect(true);
}
});
}

@Override public SubscriptionManagerState getState() {
Expand All @@ -150,67 +150,112 @@ public void stop() {
}

void doSubscribe(Subscription subscription, SubscriptionManager.Callback callback) {
if (state == SubscriptionManagerState.STOPPING || state == SubscriptionManagerState.STOPPED) {
final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;

if (state != SubscriptionManagerState.STOPPING && state != SubscriptionManagerState.STOPPED) {
timer.cancelTask(INACTIVITY_TIMEOUT_TIMER_TASK_ID);

final UUID subscriptionId = UUID.randomUUID();

subscriptions.put(subscriptionId, new SubscriptionRecord(subscriptionId, subscription, callback));
if (state == SubscriptionManagerState.DISCONNECTED) {
state = SubscriptionManagerState.CONNECTING;
transport.connect();
} else if (state == SubscriptionManagerState.ACTIVE) {
transport.send(new OperationClientMessage.Start(subscriptionId.toString(), subscription, scalarTypeAdapters));
}
}
}

if (oldState == SubscriptionManagerState.STOPPING || oldState == SubscriptionManagerState.STOPPED) {
callback.onError(new ApolloSubscriptionException(
"Illegal state: " + state.name() + " for subscriptions to be created."
+ " SubscriptionManager.start() must be called to re-enable subscriptions."));
return;
} else if (oldState == SubscriptionManagerState.CONNECTED) {
callback.onConnected();
}
timer.cancelTask(INACTIVITY_TIMEOUT_TIMER_TASK_ID);

String subscriptionId = idForSubscription(subscription);
notifyStateChanged(oldState, state);
}

void doUnsubscribe(Subscription subscription) {
synchronized (this) {
if (subscriptions.containsKey(subscriptionId)) {
callback.onError(new ApolloSubscriptionException("Already subscribed"));
return;
SubscriptionRecord subscriptionRecord = null;
for (SubscriptionRecord record : subscriptions.values()) {
if (record.subscription == subscription) {
subscriptionRecord = record;
}
}

subscriptions.put(subscriptionId, new SubscriptionRecord(subscription, callback));
if (state == SubscriptionManagerState.DISCONNECTED) {
setStateAndNotify(SubscriptionManagerState.CONNECTING);
transport.connect();
} else if (state == SubscriptionManagerState.ACTIVE) {
transport.send(new OperationClientMessage.Start(subscriptionId, subscription, scalarTypeAdapters));
if (subscriptionRecord != null) {
subscriptions.remove(subscriptionRecord.id);
if (state == SubscriptionManagerState.ACTIVE || state == SubscriptionManagerState.STOPPING) {
transport.send(new OperationClientMessage.Stop(subscriptionRecord.id.toString()));
}
}

if (subscriptions.isEmpty() && state != SubscriptionManagerState.STOPPING) {
startInactivityTimer();
}
}
}

void doUnsubscribe(Subscription subscription) {
String subscriptionId = idForSubscription(subscription);

SubscriptionRecord subscriptionRecord;
void doStop() {
final Collection<SubscriptionRecord> subscriptionRecords;
final SubscriptionManagerState oldState;
synchronized (this) {
subscriptionRecord = subscriptions.remove(subscriptionId);
if ((subscriptionRecord != null) && (state == SubscriptionManagerState.ACTIVE || state == SubscriptionManagerState.STOPPING)) {
transport.send(new OperationClientMessage.Stop(subscriptionId));
}
oldState = state;
state = SubscriptionManagerState.STOPPING;

if (subscriptions.isEmpty() && state != SubscriptionManagerState.STOPPING) {
startInactivityTimer();
subscriptionRecords = subscriptions.values();

if (oldState == SubscriptionManagerState.ACTIVE) {
for (SubscriptionRecord subscriptionRecord : subscriptionRecords) {
transport.send(new OperationClientMessage.Stop(subscriptionRecord.id.toString()));
}
}

state = SubscriptionManagerState.STOPPED;

transport.disconnect(new OperationClientMessage.Terminate());
subscriptions = new LinkedHashMap<>();
}

for (SubscriptionRecord record : subscriptionRecords) {
record.notifyOnCompleted();
}

notifyStateChanged(oldState, SubscriptionManagerState.STOPPING);
notifyStateChanged(SubscriptionManagerState.STOPPING, state);
}

void onTransportConnected() {
final Collection<SubscriptionRecord> subscriptionRecords;

final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;

if (state == SubscriptionManagerState.CONNECTING) {
subscriptionRecords = subscriptions.values();
setStateAndNotify(SubscriptionManagerState.CONNECTED);
state = SubscriptionManagerState.CONNECTED;
transport.send(new OperationClientMessage.Init(connectionParams.provide()));
} else {
subscriptionRecords = Collections.emptyList();
}

if (state == SubscriptionManagerState.CONNECTED) {
timer.schedule(CONNECTION_ACKNOWLEDGE_TIMEOUT_TIMER_TASK_ID, connectionAcknowledgeTimeoutTimerTask, CONNECTION_ACKNOWLEDGE_TIMEOUT);
}
}

for (SubscriptionRecord record : subscriptionRecords) {
record.callback.onConnected();
}

if (state == SubscriptionManagerState.CONNECTED) {
timer.schedule(CONNECTION_ACKNOWLEDGE_TIMEOUT_TIMER_TASK_ID, connectionAcknowledgeTimeoutTimerTask,
CONNECTION_ACKNOWLEDGE_TIMEOUT);
}
notifyStateChanged(oldState, state);
}

void onConnectionAcknowledgeTimeout() {
Expand All @@ -234,12 +279,7 @@ public void run() {
}

void onTransportFailure(Throwable t) {
Collection<SubscriptionRecord> subscriptionRecords;
synchronized (this) {
subscriptionRecords = subscriptions.values();
disconnect(true);
}

Collection<SubscriptionRecord> subscriptionRecords = disconnect(true);
for (SubscriptionRecord record : subscriptionRecords) {
record.notifyOnNetworkError(t);
}
Expand Down Expand Up @@ -269,48 +309,62 @@ void onOperationServerMessage(OperationServerMessage message) {
*
* @param force if true, always disconnect web socket, regardless of the status of {@link #subscriptions}
*/
void disconnect(boolean force) {
Collection<SubscriptionRecord> disconnect(boolean force) {
final SubscriptionManagerState oldState;
final Collection<SubscriptionRecord> subscriptionRecords;
synchronized (this) {
oldState = state;
subscriptionRecords = subscriptions.values();
if (force || subscriptions.isEmpty()) {
transport.disconnect(new OperationClientMessage.Terminate());
SubscriptionManagerState disconnectionState = (state == SubscriptionManagerState.STOPPING) ? SubscriptionManagerState.STOPPED
: SubscriptionManagerState.DISCONNECTED;
setStateAndNotify(disconnectionState);
state = (state == SubscriptionManagerState.STOPPING) ? SubscriptionManagerState.STOPPED : SubscriptionManagerState.DISCONNECTED;
subscriptions = new LinkedHashMap<>();
}
}

notifyStateChanged(oldState, state);

return subscriptionRecords;
}

void onConnectionHeartbeatTimeout() {
final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;
state = SubscriptionManagerState.DISCONNECTED;
transport.disconnect(new OperationClientMessage.Terminate());
setStateAndNotify(SubscriptionManagerState.DISCONNECTED);

setStateAndNotify(SubscriptionManagerState.CONNECTING);
state = SubscriptionManagerState.CONNECTING;
transport.connect();
}

notifyStateChanged(oldState, SubscriptionManagerState.DISCONNECTED);
notifyStateChanged(SubscriptionManagerState.DISCONNECTED, SubscriptionManagerState.CONNECTING);
}

void onConnectionClosed() {
Collection<SubscriptionRecord> subscriptionRecords;
final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;

subscriptionRecords = subscriptions.values();
setStateAndNotify(SubscriptionManagerState.DISCONNECTED);
state = SubscriptionManagerState.DISCONNECTED;
subscriptions = new LinkedHashMap<>();
}

for (SubscriptionRecord record : subscriptionRecords) {
record.callback.onTerminated();
}

notifyStateChanged(oldState, state);
}

private void resetConnectionKeepAliveTimerTask() {
if (connectionHeartbeatTimeoutMs <= 0) {
return;
}
synchronized (this) {
timer.schedule(CONNECTION_KEEP_ALIVE_TIMEOUT_TIMER_TASK_ID, connectionHeartbeatTimeoutTimerTask,
connectionHeartbeatTimeoutMs);
timer.schedule(CONNECTION_KEEP_ALIVE_TIMEOUT_TIMER_TASK_ID, connectionHeartbeatTimeoutTimerTask, connectionHeartbeatTimeoutMs);
}
}

Expand All @@ -323,7 +377,11 @@ private void onOperationDataServerMessage(OperationServerMessage.Data message) {
String subscriptionId = message.id != null ? message.id : "";
SubscriptionRecord subscriptionRecord;
synchronized (this) {
subscriptionRecord = subscriptions.get(subscriptionId);
try {
subscriptionRecord = subscriptions.get(UUID.fromString(subscriptionId));
} catch (IllegalArgumentException e) {
subscriptionRecord = null;
}
}

if (subscriptionRecord != null) {
Expand All @@ -347,17 +405,22 @@ private void onOperationDataServerMessage(OperationServerMessage.Data message) {
}

private void onConnectionAcknowledgeServerMessage() {
timer.cancelTask(CONNECTION_ACKNOWLEDGE_TIMEOUT_TIMER_TASK_ID);
final SubscriptionManagerState oldState;
synchronized (this) {
oldState = state;

timer.cancelTask(CONNECTION_ACKNOWLEDGE_TIMEOUT_TIMER_TASK_ID);

if (state == SubscriptionManagerState.CONNECTED) {
setStateAndNotify(SubscriptionManagerState.ACTIVE);
for (Map.Entry<String, SubscriptionRecord> entry : subscriptions.entrySet()) {
String subscriptionId = entry.getKey();
Subscription<?, ?, ?> subscription = entry.getValue().subscription;
transport.send(new OperationClientMessage.Start(subscriptionId, subscription, scalarTypeAdapters));
state = SubscriptionManagerState.ACTIVE;
for (SubscriptionRecord subscriptionRecord : subscriptions.values()) {
transport.send(new OperationClientMessage.Start(subscriptionRecord.id.toString(), subscriptionRecord.subscription,
scalarTypeAdapters));
}
}
}

notifyStateChanged(oldState, state);
}

private void onErrorServerMessage(OperationServerMessage.Error message) {
Expand All @@ -379,31 +442,36 @@ private void onCompleteServerMessage(OperationServerMessage.Complete message) {
private SubscriptionRecord removeSubscriptionById(String subscriptionId) {
SubscriptionRecord subscriptionRecord;
synchronized (this) {
subscriptionRecord = subscriptions.remove(subscriptionId);
try {
subscriptionRecord = subscriptions.remove(UUID.fromString(subscriptionId));
} catch (IllegalArgumentException e) {
subscriptionRecord = null;
}

if (subscriptions.isEmpty()) {
startInactivityTimer();
}
}
return subscriptionRecord;
}

private void setStateAndNotify(SubscriptionManagerState newState) {
SubscriptionManagerState oldState = state;
state = newState;
private void notifyStateChanged(SubscriptionManagerState oldState, SubscriptionManagerState newState) {
if (oldState == newState) {
return;
}

for (OnSubscriptionManagerStateChangeListener onStateChangeListener : onStateChangeListeners) {
onStateChangeListener.onStateChange(oldState, newState);
}
}

static String idForSubscription(Subscription<?, ?, ?> subscription) {
return subscription.operationId() + "$" + subscription.variables().valueMap().hashCode();
}

private static class SubscriptionRecord {
final UUID id;
final Subscription<?, ?, ?> subscription;
final SubscriptionManager.Callback<?> callback;

SubscriptionRecord(Subscription<?, ?, ?> subscription, SubscriptionManager.Callback<?> callback) {
SubscriptionRecord(UUID id, Subscription<?, ?, ?> subscription, SubscriptionManager.Callback<?> callback) {
this.id = id;
this.subscription = subscription;
this.callback = callback;
}
Expand Down Expand Up @@ -504,7 +572,6 @@ public void run() {

timer.schedule(timerTask, delay);
}

}

void cancelTask(int taskId) {
Expand Down
Loading

0 comments on commit 5927e4c

Please sign in to comment.