diff --git a/installer/PowerToysSetupCustomActionsVNext/CustomAction.cpp b/installer/PowerToysSetupCustomActionsVNext/CustomAction.cpp index 43919ecaf1..d6a27d9fb1 100644 --- a/installer/PowerToysSetupCustomActionsVNext/CustomAction.cpp +++ b/installer/PowerToysSetupCustomActionsVNext/CustomAction.cpp @@ -10,11 +10,10 @@ #include "../../src/common/utils/gpo.h" #include "../../src/common/utils/MsiUtils.h" #include "../../src/common/utils/modulesRegistry.h" -#include "../../src/common/updating/installer.h" -#include "../../src/common/version/version.h" +#include "../../src/common/utils/version.h" #include "../../src/common/Telemetry/EtwTrace/EtwTrace.h" -#include "../../src/common/utils/package.h" #include "../../src/common/utils/clean_video_conference.h" +#include "../../src/common/utils/package.h" #include #include @@ -26,6 +25,7 @@ #include #include + using namespace std; HINSTANCE DLL_HANDLE = nullptr; diff --git a/src/common/utils/package.h b/src/common/utils/package.h new file mode 100644 index 0000000000..27676101c9 --- /dev/null +++ b/src/common/utils/package.h @@ -0,0 +1,504 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../logger/logger.h" +#include "../utils/version.h" + +namespace package +{ + using winrt::Windows::ApplicationModel::Package; + using winrt::Windows::Foundation::IAsyncOperationWithProgress; + using winrt::Windows::Foundation::AsyncStatus; + using winrt::Windows::Foundation::Uri; + using winrt::Windows::Foundation::Collections::IVector; + using winrt::Windows::Management::Deployment::AddPackageOptions; + using winrt::Windows::Management::Deployment::DeploymentOptions; + using winrt::Windows::Management::Deployment::DeploymentProgress; + using winrt::Windows::Management::Deployment::DeploymentResult; + using winrt::Windows::Management::Deployment::PackageManager; + using Microsoft::WRL::ComPtr; + + inline BOOL IsWin11OrGreater() + { + OSVERSIONINFOEX osvi{}; + DWORDLONG dwlConditionMask = 0; + byte op = VER_GREATER_EQUAL; + + // Initialize the OSVERSIONINFOEX structure. + osvi.dwOSVersionInfoSize = sizeof(OSVERSIONINFOEX); + osvi.dwMajorVersion = HIBYTE(_WIN32_WINNT_WINTHRESHOLD); + osvi.dwMinorVersion = LOBYTE(_WIN32_WINNT_WINTHRESHOLD); + // Windows 11 build number + osvi.dwBuildNumber = 22000; + + // Initialize the condition mask. + VER_SET_CONDITION(dwlConditionMask, VER_MAJORVERSION, op); + VER_SET_CONDITION(dwlConditionMask, VER_MINORVERSION, op); + VER_SET_CONDITION(dwlConditionMask, VER_BUILDNUMBER, op); + + // Perform the test. + return VerifyVersionInfo( + &osvi, + VER_MAJORVERSION | VER_MINORVERSION | VER_BUILDNUMBER, + dwlConditionMask); + } + + struct PACKAGE_VERSION + { + UINT16 Major; + UINT16 Minor; + UINT16 Build; + UINT16 Revision; + }; + + class ComInitializer + { + public: + explicit ComInitializer(DWORD coInitFlags = COINIT_MULTITHREADED) : + _initialized(false) + { + const HRESULT hr = CoInitializeEx(nullptr, coInitFlags); + _initialized = SUCCEEDED(hr); + } + + ~ComInitializer() + { + if (_initialized) + { + CoUninitialize(); + } + } + + bool Succeeded() const { return _initialized; } + + private: + bool _initialized; + }; + + inline bool GetPackageNameAndVersionFromAppx( + const std::wstring& appxPath, + std::wstring& outName, + PACKAGE_VERSION& outVersion) + { + try + { + ComInitializer comInit; + if (!comInit.Succeeded()) + { + Logger::error(L"COM initialization failed."); + return false; + } + + ComPtr factory; + ComPtr stream; + ComPtr reader; + ComPtr manifest; + ComPtr packageId; + + HRESULT hr = CoCreateInstance(__uuidof(AppxFactory), nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&factory)); + if (FAILED(hr)) + return false; + + hr = SHCreateStreamOnFileEx(appxPath.c_str(), STGM_READ | STGM_SHARE_DENY_WRITE, FILE_ATTRIBUTE_NORMAL, FALSE, nullptr, &stream); + if (FAILED(hr)) + return false; + + hr = factory->CreatePackageReader(stream.Get(), &reader); + if (FAILED(hr)) + return false; + + hr = reader->GetManifest(&manifest); + if (FAILED(hr)) + return false; + + hr = manifest->GetPackageId(&packageId); + if (FAILED(hr)) + return false; + + LPWSTR name = nullptr; + hr = packageId->GetName(&name); + if (FAILED(hr)) + return false; + + UINT64 version = 0; + hr = packageId->GetVersion(&version); + if (FAILED(hr)) + return false; + + outName = std::wstring(name); + CoTaskMemFree(name); + + outVersion.Major = static_cast((version >> 48) & 0xFFFF); + outVersion.Minor = static_cast((version >> 32) & 0xFFFF); + outVersion.Build = static_cast((version >> 16) & 0xFFFF); + outVersion.Revision = static_cast(version & 0xFFFF); + + Logger::info(L"Package name: {}, version: {}.{}.{}.{}, appxPath: {}", + outName, + outVersion.Major, + outVersion.Minor, + outVersion.Build, + outVersion.Revision, + appxPath); + + return true; + } + catch (const std::exception& ex) + { + Logger::error(L"Standard exception: {}", winrt::to_hstring(ex.what())); + return false; + } + catch (...) + { + Logger::error(L"Unknown or non-standard exception occurred."); + return false; + } + } + + inline std::optional GetRegisteredPackage(std::wstring packageDisplayName, bool checkVersion) + { + PackageManager packageManager; + + for (const auto& package : packageManager.FindPackagesForUser({})) + { + const auto& packageFullName = std::wstring{ package.Id().FullName() }; + const auto& packageVersion = package.Id().Version(); + + if (packageFullName.contains(packageDisplayName)) + { + // If checkVersion is true, verify if the package has the same version as PowerToys. + if ((!checkVersion) || (packageVersion.Major == VERSION_MAJOR && packageVersion.Minor == VERSION_MINOR && packageVersion.Revision == VERSION_REVISION)) + { + return { package }; + } + } + } + + return {}; + } + + inline bool IsPackageRegisteredWithPowerToysVersion(std::wstring packageDisplayName) + { + return GetRegisteredPackage(packageDisplayName, true).has_value(); + } + + inline bool RegisterSparsePackage(const std::wstring& externalLocation, const std::wstring& sparsePkgPath) + { + try + { + Uri externalUri{ externalLocation }; + Uri packageUri{ sparsePkgPath }; + + PackageManager packageManager; + + // Declare use of an external location + AddPackageOptions options; + options.ExternalLocationUri(externalUri); + options.ForceUpdateFromAnyVersion(true); + + IAsyncOperationWithProgress deploymentOperation = packageManager.AddPackageByUriAsync(packageUri, options); + deploymentOperation.get(); + + // Check the status of the operation + if (deploymentOperation.Status() == AsyncStatus::Error) + { + auto deploymentResult{ deploymentOperation.GetResults() }; + auto errorCode = deploymentOperation.ErrorCode(); + auto errorText = deploymentResult.ErrorText(); + + Logger::error(L"Register {} package failed. ErrorCode: {}, ErrorText: {}", sparsePkgPath, std::to_wstring(errorCode), errorText); + return false; + } + else if (deploymentOperation.Status() == AsyncStatus::Canceled) + { + Logger::error(L"Register {} package canceled.", sparsePkgPath); + return false; + } + else if (deploymentOperation.Status() == AsyncStatus::Completed) + { + Logger::info(L"Register {} package completed.", sparsePkgPath); + } + else + { + Logger::debug(L"Register {} package started.", sparsePkgPath); + } + + return true; + } + catch (std::exception& e) + { + Logger::error("Exception thrown while trying to register package: {}", e.what()); + + return false; + } + } + + inline bool UnRegisterPackage(const std::wstring& pkgDisplayName) + { + try + { + PackageManager packageManager; + const static auto packages = packageManager.FindPackagesForUser({}); + + for (auto const& package : packages) + { + const auto& packageFullName = std::wstring{ package.Id().FullName() }; + + if (packageFullName.contains(pkgDisplayName)) + { + auto deploymentOperation{ packageManager.RemovePackageAsync(packageFullName) }; + deploymentOperation.get(); + + // Check the status of the operation + if (deploymentOperation.Status() == AsyncStatus::Error) + { + auto deploymentResult{ deploymentOperation.GetResults() }; + auto errorCode = deploymentOperation.ErrorCode(); + auto errorText = deploymentResult.ErrorText(); + + Logger::error(L"Unregister {} package failed. ErrorCode: {}, ErrorText: {}", packageFullName, std::to_wstring(errorCode), errorText); + } + else if (deploymentOperation.Status() == AsyncStatus::Canceled) + { + Logger::error(L"Unregister {} package canceled.", packageFullName); + } + else if (deploymentOperation.Status() == AsyncStatus::Completed) + { + Logger::info(L"Unregister {} package completed.", packageFullName); + } + else + { + Logger::debug(L"Unregister {} package started.", packageFullName); + } + + break; + } + } + } + catch (std::exception& e) + { + Logger::error("Exception thrown while trying to unregister package: {}", e.what()); + return false; + } + + return true; + } + + inline std::vector FindMsixFile(const std::wstring& directoryPath, bool recursive) + { + if (directoryPath.empty()) + { + return {}; + } + + if (!std::filesystem::exists(directoryPath)) + { + Logger::error(L"The directory '" + directoryPath + L"' does not exist."); + return {}; + } + + const std::regex pattern(R"(^.+\.(appx|msix|msixbundle)$)", std::regex_constants::icase); + std::vector matchedFiles; + + try + { + if (recursive) + { + for (const auto& entry : std::filesystem::recursive_directory_iterator(directoryPath)) + { + if (entry.is_regular_file()) + { + const auto& fileName = entry.path().filename().string(); + if (std::regex_match(fileName, pattern)) + { + matchedFiles.push_back(entry.path()); + } + } + } + } + else + { + for (const auto& entry : std::filesystem::directory_iterator(directoryPath)) + { + if (entry.is_regular_file()) + { + const auto& fileName = entry.path().filename().string(); + if (std::regex_match(fileName, pattern)) + { + matchedFiles.push_back(entry.path()); + } + } + } + } + + // Sort by package version in descending order (newest first) + std::sort(matchedFiles.begin(), matchedFiles.end(), [](const std::wstring& a, const std::wstring& b) { + std::wstring nameA, nameB; + PACKAGE_VERSION versionA{}, versionB{}; + + bool gotA = GetPackageNameAndVersionFromAppx(a, nameA, versionA); + bool gotB = GetPackageNameAndVersionFromAppx(b, nameB, versionB); + + // Files that failed to parse go to the end + if (!gotA) + return false; + if (!gotB) + return true; + + // Compare versions: Major, Minor, Build, Revision (descending) + if (versionA.Major != versionB.Major) + return versionA.Major > versionB.Major; + if (versionA.Minor != versionB.Minor) + return versionA.Minor > versionB.Minor; + if (versionA.Build != versionB.Build) + return versionA.Build > versionB.Build; + return versionA.Revision > versionB.Revision; + }); + } + catch (const std::exception& ex) + { + Logger::error("An error occurred while searching for MSIX files: " + std::string(ex.what())); + } + + return matchedFiles; + } + + inline bool IsPackageSatisfied(const std::wstring& appxPath) + { + std::wstring targetName; + PACKAGE_VERSION targetVersion{}; + + if (!GetPackageNameAndVersionFromAppx(appxPath, targetName, targetVersion)) + { + Logger::error(L"Failed to get package name and version from appx: " + appxPath); + return false; + } + + PackageManager pm; + + for (const auto& package : pm.FindPackagesForUser({})) + { + const auto& id = package.Id(); + if (std::wstring(id.Name()) == targetName) + { + const auto& version = id.Version(); + + if (version.Major > targetVersion.Major || + (version.Major == targetVersion.Major && version.Minor > targetVersion.Minor) || + (version.Major == targetVersion.Major && version.Minor == targetVersion.Minor && version.Build > targetVersion.Build) || + (version.Major == targetVersion.Major && version.Minor == targetVersion.Minor && version.Build == targetVersion.Build && version.Revision >= targetVersion.Revision)) + { + Logger::info( + L"Package {} is already satisfied with version {}.{}.{}.{}; target version {}.{}.{}.{}; appxPath: {}", + id.Name(), + version.Major, + version.Minor, + version.Build, + version.Revision, + targetVersion.Major, + targetVersion.Minor, + targetVersion.Build, + targetVersion.Revision, + appxPath); + return true; + } + } + } + + Logger::info( + L"Package {} is not satisfied. Target version: {}.{}.{}.{}; appxPath: {}", + targetName, + targetVersion.Major, + targetVersion.Minor, + targetVersion.Build, + targetVersion.Revision, + appxPath); + return false; + } + + inline bool RegisterPackage(std::wstring pkgPath, std::vector dependencies) + { + try + { + Uri packageUri{ pkgPath }; + + PackageManager packageManager; + + // Declare use of an external location + DeploymentOptions options = DeploymentOptions::ForceTargetApplicationShutdown; + + IVector uris = winrt::single_threaded_vector(); + if (!dependencies.empty()) + { + for (const auto& dependency : dependencies) + { + try + { + if (IsPackageSatisfied(dependency)) + { + Logger::info(L"Dependency already satisfied: {}", dependency); + } + else + { + uris.Append(Uri(dependency)); + } + } + catch (const winrt::hresult_error& ex) + { + Logger::error(L"Error creating Uri for dependency: %s", ex.message().c_str()); + } + } + } + + IAsyncOperationWithProgress deploymentOperation = packageManager.AddPackageAsync(packageUri, uris, options); + deploymentOperation.get(); + + // Check the status of the operation + if (deploymentOperation.Status() == AsyncStatus::Error) + { + auto deploymentResult{ deploymentOperation.GetResults() }; + auto errorCode = deploymentOperation.ErrorCode(); + auto errorText = deploymentResult.ErrorText(); + + Logger::error(L"Register {} package failed. ErrorCode: {}, ErrorText: {}", pkgPath, std::to_wstring(errorCode), errorText); + return false; + } + else if (deploymentOperation.Status() == AsyncStatus::Canceled) + { + Logger::error(L"Register {} package canceled.", pkgPath); + return false; + } + else if (deploymentOperation.Status() == AsyncStatus::Completed) + { + Logger::info(L"Register {} package completed.", pkgPath); + } + else + { + Logger::debug(L"Register {} package started.", pkgPath); + } + } + catch (std::exception& e) + { + Logger::error("Exception thrown while trying to register package: {}", e.what()); + + return false; + } + + return true; + } +} \ No newline at end of file