diff --git a/.github/actions/spell-check/expect.txt b/.github/actions/spell-check/expect.txt index d4b925f503..b8fd30e638 100644 --- a/.github/actions/spell-check/expect.txt +++ b/.github/actions/spell-check/expect.txt @@ -976,6 +976,7 @@ NTAPI ntdll NTSTATUS NTSYSAPI +nullability NULLCURSOR nullonfailure numberbox diff --git a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/CommandProviderWrapper.cs b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/CommandProviderWrapper.cs index fbe419e34e..48f4174c5d 100644 --- a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/CommandProviderWrapper.cs +++ b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/CommandProviderWrapper.cs @@ -215,9 +215,7 @@ public sealed class CommandProviderWrapper : ICommandProviderContext } catch (Exception e) { - Logger.LogError("Failed to load commands from extension"); - Logger.LogError($"Extension was {Extension!.PackageFamilyName}"); - Logger.LogError(e.ToString()); + Logger.LogError($"Failed to load commands from extension {Extension!.PackageFamilyName}", e); if (!displayInfoInitialized) { diff --git a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs index f30f960665..907a4f1d86 100644 --- a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs +++ b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs @@ -5,6 +5,7 @@ using System.Collections.Immutable; using System.Collections.ObjectModel; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Messaging; @@ -19,13 +20,18 @@ using Microsoft.Extensions.DependencyInjection; namespace Microsoft.CmdPal.UI.ViewModels; -public partial class TopLevelCommandManager : ObservableObject, +public sealed partial class TopLevelCommandManager : ObservableObject, IRecipient, IRecipient, IRecipient, IRecipient, IDisposable { + private static readonly TimeSpan ExtensionStartTimeout = TimeSpan.FromSeconds(10); + private static readonly TimeSpan CommandLoadTimeout = TimeSpan.FromSeconds(10); + private static readonly TimeSpan BackgroundStartTimeout = TimeSpan.FromSeconds(60); + private static readonly TimeSpan BackgroundCommandLoadTimeout = TimeSpan.FromSeconds(60); + private readonly IServiceProvider _serviceProvider; private readonly ICommandProviderCache _commandProviderCache; private readonly TaskScheduler _taskScheduler; @@ -39,11 +45,14 @@ public partial class TopLevelCommandManager : ObservableObject, // deadlock. private readonly Lock _dockBandsLock = new(); private readonly SupersedingAsyncGate _reloadCommandsGate; + private CancellationTokenSource _extensionLoadCts = new(); + private CancellationToken _currentExtensionLoadCancellationToken; public TopLevelCommandManager(IServiceProvider serviceProvider, ICommandProviderCache commandProviderCache) { _serviceProvider = serviceProvider; _commandProviderCache = commandProviderCache; + _currentExtensionLoadCancellationToken = _extensionLoadCts.Token; _taskScheduler = _serviceProvider.GetService()!; WeakReferenceMessenger.Default.Register(this); WeakReferenceMessenger.Default.Register(this); @@ -260,8 +269,15 @@ public partial class TopLevelCommandManager : ObservableObject, private async Task ReloadAllCommandsAsyncCore(CancellationToken cancellationToken) { IsLoading = true; + + // Invalidate any background continuations from the previous load cycle + await _extensionLoadCts.CancelAsync().ConfigureAwait(false); + _extensionLoadCts.Dispose(); + _extensionLoadCts = new(); + _currentExtensionLoadCancellationToken = _extensionLoadCts.Token; + var extensionService = _serviceProvider.GetService()!; - await extensionService.SignalStopExtensionsAsync(); + await extensionService.SignalStopExtensionsAsync().ConfigureAwait(false); lock (TopLevelCommands) { @@ -273,8 +289,8 @@ public partial class TopLevelCommandManager : ObservableObject, DockBands.Clear(); } - await LoadBuiltinsAsync(); - _ = Task.Run(LoadExtensionsAsync); + await LoadBuiltinsAsync().ConfigureAwait(false); + _ = Task.Run(LoadExtensionsAsync, cancellationToken); } // Load commands from our extensions. Called on a background thread. @@ -292,16 +308,15 @@ public partial class TopLevelCommandManager : ObservableObject, extensionService.OnExtensionAdded -= ExtensionService_OnExtensionAdded; extensionService.OnExtensionRemoved -= ExtensionService_OnExtensionRemoved; - var extensions = (await extensionService.GetInstalledExtensionsAsync()).ToImmutableList(); + var ct = _currentExtensionLoadCancellationToken; + + var extensions = (await extensionService.GetInstalledExtensionsAsync().ConfigureAwait(false)).ToImmutableList(); lock (_commandProvidersLock) { _extensionCommandProviders.Clear(); } - if (extensions is not null) - { - await StartExtensionsAndGetCommands(extensions); - } + await StartExtensionsAndGetCommands(extensions, ct).ConfigureAwait(false); extensionService.OnExtensionAdded += ExtensionService_OnExtensionAdded; extensionService.OnExtensionRemoved += ExtensionService_OnExtensionRemoved; @@ -316,46 +331,219 @@ public partial class TopLevelCommandManager : ObservableObject, private void ExtensionService_OnExtensionAdded(IExtensionService sender, IEnumerable extensions) { + var ct = _currentExtensionLoadCancellationToken; + // When we get an extension install event, hop off to a BG thread - _ = Task.Run(async () => - { - // for each newly installed extension, start it and get commands - // from it. One single package might have more than one - // IExtensionWrapper in it. - await StartExtensionsAndGetCommands(extensions); - }); + _ = Task.Run( + async () => + { + // for each newly installed extension, start it and get commands + // from it. One single package might have more than one + // IExtensionWrapper in it. + await StartExtensionsAndGetCommands(extensions, ct).ConfigureAwait(false); + }, + ct); } - private async Task StartExtensionsAndGetCommands(IEnumerable extensions) + private async Task StartExtensionsAndGetCommands(IEnumerable extensions, CancellationToken ct) { - var timer = new Stopwatch(); - timer.Start(); + var timer = Stopwatch.StartNew(); // Start all extensions in parallel - var startTasks = extensions.Select(StartExtensionWithTimeoutAsync); + var startResults = await Task.WhenAll(extensions.Select(TryStartExtensionAsync)).ConfigureAwait(false); - // Wait for all extensions to start - var wrappers = (await Task.WhenAll(startTasks)).Where(wrapper => wrapper is not null).Select(w => w!).ToList(); + var startedWrappers = new List(); + foreach (var r in startResults) + { + if (r.IsStarted) + { + startedWrappers.Add(r.Wrapper); + } + else if (r.IsTimedOut) + { + _ = StartExtensionWhenReadyAsync(r.Extension, r.PendingStartTask, r.Stopwatch, ct); + } + } + // Register started extensions and load their commands + var loadSummary = await RegisterAndLoadCommandsAsync(startedWrappers, ct).ConfigureAwait(false); + + timer.Stop(); + Logger.LogInfo($"Loaded {loadSummary.CommandCount} command(s) and {loadSummary.DockBandCount} band(s) from {startedWrappers.Count} extension(s) in {timer.ElapsedMilliseconds} ms"); + } + + private async Task RegisterAndLoadCommandsAsync(ICollection wrappers, CancellationToken ct) + { lock (_commandProvidersLock) { _extensionCommandProviders.AddRange(wrappers); } // Load the commands from the providers in parallel - var loadTasks = wrappers.Select(LoadCommandsWithTimeoutAsync); + var loadResults = await Task.WhenAll(wrappers.Select(w => TryLoadCommandsAsync(w, ct))).ConfigureAwait(false); - var commandSets = (await Task.WhenAll(loadTasks)).Where(results => results is not null).Select(r => r!).ToList(); + var totalCommands = 0; + var totalDockBands = 0; + var timedOut = new List(); + List commandsToAdd = []; + List dockBandsToAdd = []; - foreach (var providerObjects in commandSets) + foreach (var r in loadResults) { - var commandsCount = providerObjects.Commands?.Count() ?? 0; - var bandsCount = providerObjects.DockBands?.Count() ?? 0; - Logger.LogDebug($"(some provider) Loaded {commandsCount} commands and {bandsCount} bands"); - - lock (TopLevelCommands) + if (r.IsLoaded) { - if (providerObjects.Commands is IEnumerable commands) + var commands = r.TopLevelObjectSets.Commands; + if (commands is not null) + { + foreach (var c in commands) + { + commandsToAdd.Add(c); + totalCommands++; + } + } + + var bands = r.TopLevelObjectSets.DockBands; + if (bands is not null) + { + foreach (var b in bands) + { + dockBandsToAdd.Add(b); + totalDockBands++; + } + } + } + else if (r.IsTimedOut) + { + timedOut.Add(r); + } + } + + lock (TopLevelCommands) + { + foreach (var c in commandsToAdd) + { + TopLevelCommands.Add(c); + } + } + + lock (_dockBandsLock) + { + foreach (var b in dockBandsToAdd) + { + DockBands.Add(b); + } + } + + // Fire background continuations for timed-out loads outside the lock + foreach (var r in timedOut) + { + // It's weird to repeat the condition here, but it allows the compiler to track nullability of other properties + if (r.IsTimedOut) + { + _ = AppendCommandsWhenReadyAsync(r.Wrapper, r.PendingLoadTask, r.Stopwatch, ct); + } + } + + return new RegisterAndLoadSummary(totalCommands, totalDockBands); + } + + private async Task TryStartExtensionAsync(IExtensionWrapper extension) + { + Logger.LogDebug($"Starting {extension.PackageFullName}"); + var sw = Stopwatch.StartNew(); + var ct = _currentExtensionLoadCancellationToken; + var startTask = extension.StartExtensionAsync(); + try + { + await startTask.WaitAsync(ExtensionStartTimeout, ct).ConfigureAwait(false); + Logger.LogInfo($"Started extension {extension.PackageFullName} in {sw.ElapsedMilliseconds} ms"); + return ExtensionStartResult.Started(extension, new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache)); + } + catch (TimeoutException) + { + Logger.LogWarning($"Starting extension {extension.PackageFullName} timed out after {sw.ElapsedMilliseconds} ms, continuing in background"); + return ExtensionStartResult.TimedOut(extension, startTask, sw); + } + catch (OperationCanceledException) + { + Logger.LogDebug($"Starting extension {extension.PackageFullName} was cancelled after {sw.ElapsedMilliseconds} ms"); + return ExtensionStartResult.Failed(extension); + } + catch (Exception ex) + { + Logger.LogError($"Failed to start extension {extension.PackageFullName} after {sw.ElapsedMilliseconds} ms: {ex}"); + return ExtensionStartResult.Failed(extension); + } + } + + private async Task StartExtensionWhenReadyAsync( + IExtensionWrapper extension, + Task startTask, + Stopwatch sw, + CancellationToken ct) + { + try + { + await startTask.WaitAsync(BackgroundStartTimeout, ct).ConfigureAwait(false); + + var wrapper = new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache); + Logger.LogInfo($"Late-started extension {extension.PackageFullName} in {sw.ElapsedMilliseconds} ms, loading commands and bands"); + + await RegisterAndLoadCommandsAsync([wrapper], ct).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Reload happened -- discard stale results + } + catch (Exception ex) + { + Logger.LogError($"Background start/load of extension {extension.PackageFullName} failed after {sw.ElapsedMilliseconds} ms: {ex}"); + } + } + + private async Task TryLoadCommandsAsync(CommandProviderWrapper wrapper, CancellationToken ct) + { + var sw = Stopwatch.StartNew(); + var loadTask = LoadTopLevelCommandsFromProvider(wrapper); + try + { + var result = await loadTask.WaitAsync(CommandLoadTimeout, ct).ConfigureAwait(false); + var commandCount = result.Commands?.Count ?? 0; + var dockBandCount = result.DockBands?.Count ?? 0; + Logger.LogInfo($"Loaded {commandCount} command(s) and {dockBandCount} band(s) from {wrapper.ExtensionHost?.Extension?.PackageFullName} in {sw.ElapsedMilliseconds} ms"); + return CommandLoadResult.Loaded(wrapper, result); + } + catch (TimeoutException) + { + Logger.LogWarning($"Loading commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} timed out after {sw.ElapsedMilliseconds} ms, continuing in background"); + return CommandLoadResult.TimedOut(wrapper, loadTask, sw); + } + catch (OperationCanceledException) + { + Logger.LogDebug($"Loading commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} was cancelled after {sw.ElapsedMilliseconds} ms"); + return CommandLoadResult.Failed(wrapper); + } + catch (Exception ex) + { + Logger.LogError($"Failed to load commands and bands for extension {wrapper.ExtensionHost?.Extension?.PackageFullName} after {sw.ElapsedMilliseconds} ms: {ex}"); + return CommandLoadResult.Failed(wrapper); + } + } + + private async Task AppendCommandsWhenReadyAsync( + CommandProviderWrapper wrapper, + Task loadTask, + Stopwatch sw, + CancellationToken ct) + { + try + { + var topLevelObjectSets = await loadTask.WaitAsync(BackgroundCommandLoadTimeout, ct).ConfigureAwait(false); + + var commands = topLevelObjectSets.Commands; + if (commands is not null) + { + lock (TopLevelCommands) { foreach (var c in commands) { @@ -364,57 +552,30 @@ public partial class TopLevelCommandManager : ObservableObject, } } - lock (_dockBandsLock) + var dockBands = topLevelObjectSets.DockBands; + if (dockBands is not null) { - if (providerObjects.DockBands is IEnumerable bands) + lock (_dockBandsLock) { - foreach (var c in bands) + foreach (var band in dockBands) { - DockBands.Add(c); + DockBands.Add(band); } } } + + Logger.LogInfo($"Late-loaded {commands?.Count ?? 0} command(s) and {dockBands?.Count ?? 0} band(s) from {wrapper.ExtensionHost?.Extension?.PackageFullName} in {sw.ElapsedMilliseconds} ms"); } - - timer.Stop(); - Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms"); - } - - private async Task StartExtensionWithTimeoutAsync(IExtensionWrapper extension) - { - Logger.LogDebug($"Starting {extension.PackageFullName}"); - try + catch (OperationCanceledException) { - await extension.StartExtensionAsync().WaitAsync(TimeSpan.FromSeconds(10)); - return new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache); + // Reload happened - discard stale results } catch (Exception ex) { - Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}"); - return null; // Return null for failed extensions + Logger.LogError($"Background loading of commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} failed after {sw.ElapsedMilliseconds} ms: {ex}"); } } - private record TopLevelObjectSets(IEnumerable? Commands, IEnumerable? DockBands); - - private async Task LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper) - { - try - { - return await LoadTopLevelCommandsFromProvider(wrapper!).WaitAsync(TimeSpan.FromSeconds(10)); - } - catch (TimeoutException) - { - Logger.LogError($"Loading commands from {wrapper!.ExtensionHost?.Extension?.PackageFullName} timed out"); - } - catch (Exception ex) - { - Logger.LogError($"Failed to load commands for extension {wrapper!.ExtensionHost?.Extension?.PackageFullName}: {ex}"); - } - - return null; - } - private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable extensions) { // When we get an extension uninstall event, hop off to a BG thread @@ -515,7 +676,7 @@ public partial class TopLevelCommandManager : ObservableObject, } public void Receive(ReloadCommandsMessage message) => - ReloadAllCommandsAsync().ConfigureAwait(false); + _ = ReloadAllCommandsAsync(); public void Receive(PinCommandItemMessage message) { @@ -611,7 +772,87 @@ public partial class TopLevelCommandManager : ObservableObject, public void Dispose() { + _extensionLoadCts.Cancel(); + _extensionLoadCts.Dispose(); _reloadCommandsGate.Dispose(); GC.SuppressFinalize(this); } + + private sealed class ExtensionStartResult + { + public IExtensionWrapper Extension { get; } + + public CommandProviderWrapper? Wrapper { get; private init; } + + public Task? PendingStartTask { get; private init; } + + public Stopwatch? Stopwatch { get; private init; } + + [MemberNotNullWhen(true, nameof(Wrapper))] + public bool IsStarted => Wrapper is not null; + + [MemberNotNullWhen(true, nameof(PendingStartTask), nameof(Stopwatch))] + public bool IsTimedOut => PendingStartTask is not null; + + private ExtensionStartResult(IExtensionWrapper extension) + { + Extension = extension; + } + + public static ExtensionStartResult Started(IExtensionWrapper extension, CommandProviderWrapper wrapper) + { + return new ExtensionStartResult(extension) { Wrapper = wrapper }; + } + + public static ExtensionStartResult TimedOut(IExtensionWrapper extension, Task pendingStartTask, Stopwatch sw) + { + return new ExtensionStartResult(extension) { PendingStartTask = pendingStartTask, Stopwatch = sw }; + } + + public static ExtensionStartResult Failed(IExtensionWrapper extension) + { + return new ExtensionStartResult(extension); + } + } + + private sealed class CommandLoadResult + { + public TopLevelObjectSets? TopLevelObjectSets { get; private init; } + + public CommandProviderWrapper Wrapper { get; } + + public Task? PendingLoadTask { get; private init; } + + public Stopwatch? Stopwatch { get; private init; } + + [MemberNotNullWhen(true, nameof(TopLevelObjectSets))] + public bool IsLoaded => TopLevelObjectSets is not null; + + [MemberNotNullWhen(true, nameof(PendingLoadTask), nameof(Stopwatch))] + public bool IsTimedOut => PendingLoadTask is not null; + + private CommandLoadResult(CommandProviderWrapper wrapper) + { + Wrapper = wrapper; + } + + public static CommandLoadResult Loaded(CommandProviderWrapper wrapper, TopLevelObjectSets topLevelObjectSets) + { + return new CommandLoadResult(wrapper) { TopLevelObjectSets = topLevelObjectSets }; + } + + public static CommandLoadResult TimedOut(CommandProviderWrapper wrapper, Task pendingLoadTask, Stopwatch sw) + { + return new CommandLoadResult(wrapper) { PendingLoadTask = pendingLoadTask, Stopwatch = sw }; + } + + public static CommandLoadResult Failed(CommandProviderWrapper wrapper) + { + return new CommandLoadResult(wrapper); + } + } + + private readonly record struct RegisterAndLoadSummary(int CommandCount, int DockBandCount); + + private record TopLevelObjectSets(ICollection? Commands, ICollection? DockBands); }