Skip to content

Commit

Permalink
Adding resiliency to subscription lookup failures due to subscription…
Browse files Browse the repository at this point in the history
…s being mapped to different tenant (#25)
  • Loading branch information
RamjotSingh authored May 13, 2020
1 parent 77abed1 commit 83d2fa1
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 50 deletions.
22 changes: 11 additions & 11 deletions UI/TunnelRelay.UI.Shared/Logger/FileLoggerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ internal class FileLoggerProvider : ILoggerProvider
private readonly object lockObject = new object();

/// <summary>
/// The stream writer.
/// The text writer.
/// </summary>
private StreamWriter streamWriter = null;
private TextWriter textWriter = null;

/// <summary>
/// The external scope provider to allow setting scope data in messages.
Expand Down Expand Up @@ -57,23 +57,23 @@ public FileLoggerProvider(IOptions<FileLoggerProviderOptions> fileLoggerProvider
/// <returns>An <see cref="ILogger"/> instance to be used for logging.</returns>
public ILogger CreateLogger(string categoryName)
{
if (this.streamWriter == null)
if (this.textWriter == null)
{
lock (this.lockObject)
{
if (this.streamWriter == null)
if (this.textWriter == null)
{
// Opening in shared mode because other logger instances are also using the same file. So all of them can write to the same file.
this.streamWriter = new StreamWriter(new FileStream(this.logFileName, FileMode.Append, FileAccess.Write, FileShare.ReadWrite))
this.textWriter = TextWriter.Synchronized(new StreamWriter(new FileStream(this.logFileName, FileMode.Append, FileAccess.Write, FileShare.ReadWrite))
{
AutoFlush = true,
NewLine = Environment.NewLine,
};
});
}
}
}

return new StreamLogger(this.streamWriter, categoryName)
return new StreamLogger(this.textWriter, categoryName)
{
ExternalScopeProvider = this.externalScopeProvider,
};
Expand Down Expand Up @@ -103,11 +103,11 @@ public void SetScopeProvider(IExternalScopeProvider externalScopeProvider)
/// <param name="releasedManagedResources">Release managed resources.</param>
protected virtual void Dispose(bool releasedManagedResources)
{
if (this.streamWriter != null)
if (this.textWriter != null)
{
this.streamWriter.Flush();
this.streamWriter.Close();
this.streamWriter.Dispose();
this.textWriter.Flush();
this.textWriter.Close();
this.textWriter.Dispose();
}
}
}
Expand Down
23 changes: 18 additions & 5 deletions UI/TunnelRelay.UI.Shared/Logger/StreamLogger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace TunnelRelay.UI.Logger
{
using System;
using System.Diagnostics;
using System.IO;
using Microsoft.Extensions.Logging;

Expand All @@ -15,18 +16,18 @@ namespace TunnelRelay.UI.Logger
/// <seealso cref="ILogger" />
internal class StreamLogger : ILogger
{
private readonly StreamWriter streamWriter;
private readonly TextWriter textWriter;
private readonly string categoryName;

/// <summary>
/// Initializes a new instance of the <see cref="StreamLogger"/> class.
/// </summary>
/// <param name="streamWriter">The stream writer.</param>
/// <param name="textWriter">The text writer.</param>
/// <param name="categoryName">Name of the category.</param>
/// <exception cref="ArgumentNullException">streamWriter.</exception>
public StreamLogger(StreamWriter streamWriter, string categoryName)
public StreamLogger(TextWriter textWriter, string categoryName)
{
this.streamWriter = streamWriter ?? throw new ArgumentNullException(nameof(streamWriter));
this.textWriter = textWriter ?? throw new ArgumentNullException(nameof(textWriter));
this.categoryName = categoryName;
}

Expand Down Expand Up @@ -79,7 +80,19 @@ public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Except
message = message + "Exception - " + exception;
}

this.streamWriter.WriteLine(message);
try
{
this.textWriter.WriteLine(message);
}
#pragma warning disable CS0168 // Variable is declared but never used - Kept here for debugging purposes.
catch (Exception ex)
#pragma warning restore CS0168 // Variable is declared but never used - Kept here for debugging purposes.
{
if (Debugger.IsAttached)
{
Debugger.Break();
}
}
}
}
}
217 changes: 183 additions & 34 deletions UI/TunnelRelay.UI.Shared/ResourceManagement/UserAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,7 @@ public async Task<List<SubscriptionInner>> GetUserSubscriptionsAsync()
tokenAcquireTasks.Add(Task.Run(async () =>
{
// Optimization to skip refetching tokens. AAD tokens live for 1 hour.
if (!this.tenantBasedTokenMap.ContainsKey(tenant.TenantId))
{
try
{
this.logger.LogInformation("Get token with '{0}' tenant info", tenant.TenantId);
AuthenticationResult tenantizedToken = await this.AcquireAzureManagementTokenAsync(tenant.TenantId, Prompt.NoPrompt, this.userIdentifier).ConfigureAwait(false);
this.tenantBasedTokenMap[this.GetTenantOnToken(tenantizedToken)] = tenantizedToken;
}
catch (Exception ex)
{
this.logger.LogWarning(ex, $"Failed to acquire token for tenant with Id '{tenant.TenantId}'");
}
}
await this.TryUpdateTokenCacheForTenantAsync(tenant.TenantId).ConfigureAwait(false);
}));
});

Expand All @@ -164,21 +152,7 @@ public async Task<List<SubscriptionInner>> GetUserSubscriptionsAsync()
{
subscriptionTasks.Add(Task.Run(async () =>
{
List<SubscriptionInner> subscriptionList = new List<SubscriptionInner>();
this.logger.LogTrace("Getting subscriptions for '{0}' tenant.", tenant.TenantId);
TokenCredentials subsCreds = new TokenCredentials(this.tenantBasedTokenMap[tenant.TenantId].AccessToken);
SubscriptionClient subscriptionClient = new SubscriptionClient(subsCreds);
IPage<SubscriptionInner> resp = await subscriptionClient.Subscriptions.ListAsync().ConfigureAwait(false);
subscriptionList.AddRange(resp);
while (!string.IsNullOrEmpty(resp.NextPageLink))
{
resp = await subscriptionClient.Subscriptions.ListNextAsync(resp.NextPageLink).ConfigureAwait(false);
subscriptionList.AddRange(resp);
}
this.logger.LogTrace("Fetched total of '{0}' subscriptions for tenant '{1}'", subscriptionList.Count, tenant.TenantId);
List<SubscriptionInner> subscriptionList = await this.GetSubcriptionsForATenantAsync(tenant.TenantId).ConfigureAwait(false);
subscriptionList.ForEach(subscription => this.subscriptionToTenantMap[subscription] = tenant);
}));
Expand All @@ -187,20 +161,33 @@ public async Task<List<SubscriptionInner>> GetUserSubscriptionsAsync()
await Task.WhenAll(subscriptionTasks).ConfigureAwait(false);

List<Task> locationTasks = new List<Task>();

List<SubscriptionInner> markedForRemovalSubscriptions = new List<SubscriptionInner>();
foreach (KeyValuePair<SubscriptionInner, TenantIdDescription> subscription in this.subscriptionToTenantMap)
{
locationTasks.Add(Task.Run(async () =>
{
TokenCredentials subsCreds = new TokenCredentials(this.tenantBasedTokenMap[subscription.Value.TenantId].AccessToken);
SubscriptionClient subscriptionClient = new SubscriptionClient(subsCreds);
IEnumerable<Location> locations = await subscriptionClient.Subscriptions.ListLocationsAsync(subscription.Key.SubscriptionId).ConfigureAwait(false);
this.subscriptionToLocationMap[subscription.Key] = locations;
try
{
IEnumerable<Location> locations = await this.GetLocationsForSubscriptionAsync(subscription.Key, subscription.Value.TenantId).ConfigureAwait(false);
this.subscriptionToLocationMap[subscription.Key] = locations;
}
catch (Exception ex)
{
this.logger.LogError($"Hit exception while getting locations for subscription Id '{subscription.Key.Id}'. Error '{ex}'");
markedForRemovalSubscriptions.Add(subscription.Key);
}
}));
}

await Task.WhenAll(locationTasks).ConfigureAwait(false);

markedForRemovalSubscriptions.ForEach(subscriptionToRemove =>
{
this.logger.LogWarning($"Removing subscription '{subscriptionToRemove.Id}' from the list since location lookup failed!");
this.subscriptionToTenantMap.TryRemove(subscriptionToRemove, out TenantIdDescription _);
});
}

return this.subscriptionToTenantMap.Keys.OrderBy(subs => subs.DisplayName).ToList();
Expand All @@ -226,6 +213,145 @@ public string GetSubscriptionSpecificUserToken(SubscriptionInner subscription)
return this.tenantBasedTokenMap[this.subscriptionToTenantMap[subscription].TenantId].AccessToken;
}

/// <summary>
/// Gets subscriptions associated with a tenantId.
/// </summary>
/// <param name="tenantId">Tenant Id.</param>
/// <returns>List of subscriptions under that tenant.</returns>
private async Task<List<SubscriptionInner>> GetSubcriptionsForATenantAsync(string tenantId)
{
List<SubscriptionInner> subscriptionList = new List<SubscriptionInner>();
this.logger.LogTrace("Getting subscriptions for '{0}' tenant.", tenantId);

await this.TryUpdateTokenCacheForTenantAsync(tenantId).ConfigureAwait(false);

if (this.tenantBasedTokenMap.ContainsKey(tenantId))
{
TokenCredentials subsCreds = new TokenCredentials(this.tenantBasedTokenMap[tenantId].AccessToken);
SubscriptionClient subscriptionClient = new SubscriptionClient(subsCreds);

IPage<SubscriptionInner> resp = await subscriptionClient.Subscriptions.ListAsync().ConfigureAwait(false);
subscriptionList.AddRange(resp);

while (!string.IsNullOrEmpty(resp.NextPageLink))
{
resp = await subscriptionClient.Subscriptions.ListNextAsync(resp.NextPageLink).ConfigureAwait(false);
subscriptionList.AddRange(resp);
}

this.logger.LogTrace("Fetched total of '{0}' subscriptions for tenant '{1}'", subscriptionList.Count, tenantId);
}
else
{
this.logger.LogWarning($"Could not get token for tenant with Id '{tenantId}'. Returning 0 subscriptions");
}

return subscriptionList;
}

/// <summary>
/// Gets the list of locations enabled in a subscription.
/// </summary>
/// <param name="subscription">Subscription to look for in details.</param>
/// <param name="preferredTenantId">Preferred tenant Id.</param>
/// <returns>Collection of locations.</returns>
private async Task<IEnumerable<Location>> GetLocationsForSubscriptionAsync(SubscriptionInner subscription, string preferredTenantId)
{
await this.TryUpdateTokenCacheForTenantAsync(preferredTenantId).ConfigureAwait(false);

HashSet<string> secondaryTenantLookups;

if (this.tenantBasedTokenMap.ContainsKey(preferredTenantId))
{
TokenCredentials subsCreds = new TokenCredentials(this.tenantBasedTokenMap[preferredTenantId].AccessToken);

SubscriptionClient subscriptionClient = new SubscriptionClient(subsCreds);

try
{
return await subscriptionClient.Subscriptions.ListLocationsAsync(subscription.SubscriptionId).ConfigureAwait(false);
}
catch (Exception ex)
{
this.logger.LogWarning(
ex,
$"Preferred tenant based locations lookup for subscription with Id '{subscription.SubscriptionId}' used tenant '{preferredTenantId}'. Looping thru tenants.");

secondaryTenantLookups = this.GetTheTenantIdLookupForSubscriptionOnInvalidAuthenticationTokenError(ex as CloudException);
}
}
else
{
secondaryTenantLookups = this.GetTheTenantIdLookupForSubscriptionOnInvalidAuthenticationTokenError();
}

// If we have reached this point. Either token acquire for preferred tenant failed or we hit an exception.
foreach (string tenantId in secondaryTenantLookups)
{
await this.TryUpdateTokenCacheForTenantAsync(tenantId).ConfigureAwait(false);

if (this.tenantBasedTokenMap.ContainsKey(tenantId))
{
TokenCredentials subsCreds = new TokenCredentials(this.tenantBasedTokenMap[tenantId].AccessToken);

SubscriptionClient subscriptionClient = new SubscriptionClient(subsCreds);

try
{
return await subscriptionClient.Subscriptions.ListLocationsAsync(subscription.SubscriptionId).ConfigureAwait(false);
}
catch (Exception ex)
{
this.logger.LogWarning(
ex,
$"Lookup with TenantId '{tenantId}' failed for subscription with Id '{subscription.SubscriptionId}'. Moving on to next tenant.");
}
}
}

// If we reach this point. We just all up failed to get locations in any possible way. Throwing exception so the subscription can be removed from the list.
throw new UnauthorizedAccessException("Failed to get subscription location");
}

/// <summary>
/// Processes an optional cloud exception and returns an optimistic list of tenants to try against to get subscription info.
/// </summary>
/// <param name="cloudException">Optional cloud exception.</param>
/// <returns>Collection of tenantIds to try.</returns>
private HashSet<string> GetTheTenantIdLookupForSubscriptionOnInvalidAuthenticationTokenError(CloudException cloudException = null)
{
List<string> tenantLookup = new List<string>();

// If Azure thinks we used the wrong tenant, go thru the tenant list and check if we can acquire the token for any other
// tenant and get subscription details. If we can't we drop the subscription and move on.
if (cloudException?.Body?.Code == "InvalidAuthenticationTokenTenant")
{
// Try to optimistically get tenantId from WWW-authenticate header.
// The header looks like
// Bearer authorization_uri="https://login.windows.net/<TenantId>", error="invalid_token", error_description="The access token is from the wrong issuer. It must match the tenant associated with this subscription. Please use correct authority to get the token."
if (cloudException.Response?.Headers != null && cloudException.Response.Headers.ContainsKey("WWW-Authenticate"))
{
string wwwAuthHeader = cloudException.Response.Headers["WWW-Authenticate"].First();

string authUrlPart = wwwAuthHeader.Split(",").FirstOrDefault(split => split.StartsWith("Bearer", StringComparison.OrdinalIgnoreCase));

string loginAuthority = authUrlPart.Replace("Bearer authorization_uri=\"", string.Empty, StringComparison.OrdinalIgnoreCase);
loginAuthority = loginAuthority.Substring(0, loginAuthority.Length - 1);

if (Uri.IsWellFormedUriString(loginAuthority, UriKind.Absolute))
{
string tenantId = new Uri(loginAuthority).AbsolutePath.TrimStart('/');

tenantLookup.Add(tenantId);
}
}
}

tenantLookup.AddRange(this.tenantBasedTokenMap.Keys);

return new HashSet<string>(tenantLookup);
}

/// <summary>
/// Gets the tenantId on the issued token.
/// </summary>
Expand Down Expand Up @@ -261,6 +387,29 @@ private string GetTenantOnToken(AuthenticationResult authenticationResult)
}
}

/// <summary>
/// For a given token tries to update the token cache.
/// </summary>
/// <param name="tenantId">Tenant Id to update cache for.</param>
/// <returns>Task tracking operation.</returns>
private async Task TryUpdateTokenCacheForTenantAsync(string tenantId)
{
// Optimization to skip refetching tokens. AAD tokens live for 1 hour.
if (!this.tenantBasedTokenMap.ContainsKey(tenantId))
{
try
{
this.logger.LogInformation("Get token with '{0}' tenant info", tenantId);
AuthenticationResult tenantizedToken = await this.AcquireAzureManagementTokenAsync(tenantId, Prompt.NoPrompt, this.userIdentifier).ConfigureAwait(false);
this.tenantBasedTokenMap[this.GetTenantOnToken(tenantizedToken)] = tenantizedToken;
}
catch (Exception ex)
{
this.logger.LogWarning(ex, $"Failed to acquire token for tenant with Id '{tenantId}'");
}
}
}

/// <summary>
/// Acquires the Azure management token asynchronously.
/// </summary>
Expand Down

0 comments on commit 83d2fa1

Please sign in to comment.