Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions src/windows/wslcsession/WSLCSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Module Name:
using namespace wsl::windows::common;
using relay::MultiHandleWait;
using wsl::shared::Localization;
using wsl::windows::service::wslc::UserCOMCallback;
using wsl::windows::service::wslc::UserHandle;
using wsl::windows::service::wslc::WSLCSession;
using wsl::windows::service::wslc::WSLCVirtualMachine;
Expand Down Expand Up @@ -220,6 +221,48 @@ HANDLE UserHandle::Get() const noexcept
return m_handle;
}

UserCOMCallback::UserCOMCallback(WSLCSession& Session) noexcept : m_session(&Session), m_threadId(GetCurrentThreadId())
{
LOG_IF_FAILED(CoEnableCallCancellation(nullptr));
Comment thread
OneBlue marked this conversation as resolved.
Outdated
}

UserCOMCallback::UserCOMCallback(UserCOMCallback&& Other) noexcept
{
*this = std::move(Other);
}

UserCOMCallback& UserCOMCallback::operator=(UserCOMCallback&& Other) noexcept
{
if (this != &Other)
{
Reset();
m_session = Other.m_session;
m_threadId = Other.m_threadId;

Other.m_threadId = 0;
Other.m_session = nullptr;
}
return *this;
}

void UserCOMCallback::Reset() noexcept
{
if (m_threadId != 0)
{
WI_ASSERT(m_session != nullptr);

m_session->UnregisterUserCOMCallback(m_threadId);
m_threadId = 0;

LOG_IF_FAILED(CoDisableCallCancellation(nullptr));
}
}

UserCOMCallback::~UserCOMCallback() noexcept
{
Reset();
Comment thread
OneBlue marked this conversation as resolved.
}

HRESULT WSLCSession::GetProcessHandle(_Out_ HANDLE* ProcessHandle)
try
{
Expand Down Expand Up @@ -443,6 +486,12 @@ void WSLCSession::StreamImageOperation(DockerHTTPClient::HTTPRequestContext& req
bool isJson = false;
};

std::optional<UserCOMCallback> comCall;
if (ProgressCallback != nullptr)
{
comCall = RegisterUserCOMCallback();
}

Comment thread
OneBlue marked this conversation as resolved.
std::optional<Response> httpResponse;

auto onHttpResponse = [&](const boost::beast::http::message<false, boost::beast::http::buffer_body>& response) {
Expand Down Expand Up @@ -589,6 +638,12 @@ try

auto buildFileHandle = OpenUserHandle(Options->DockerfileHandle);

std::optional<UserCOMCallback> comCall;
if (ProgressCallback != nullptr)
{
comCall = RegisterUserCOMCallback();
}

Comment thread
OneBlue marked this conversation as resolved.
auto lock = m_lock.lock_shared();

THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
Expand Down Expand Up @@ -1918,6 +1973,14 @@ try
CancelUserHandleIO();
}

{
std::lock_guard comLock(m_userCOMCallbacksLock);

// Cancel any pending outgoing COM callback calls (e.g. IProgressCallback::OnProgress)
// to unblock operations waiting for cross-process COM responses.
CancelUserCOMCallbacks();
Comment thread
OneBlue marked this conversation as resolved.
}
Comment thread
OneBlue marked this conversation as resolved.

// Acquire an exclusive lock to ensure that no operation is running.
auto lock = m_lock.lock_exclusive();
std::lock_guard containersLock(m_containersLock);
Expand Down Expand Up @@ -2154,6 +2217,39 @@ void WSLCSession::CancelUserHandleIO()
}
}

UserCOMCallback WSLCSession::RegisterUserCOMCallback()
{
std::lock_guard lock(m_userCOMCallbacksLock);

// Don't allow new COM calls if the session is terminating.
// N.B. This check must happen under m_userCOMCallbacksLock to synchronize with Terminate().
THROW_HR_IF_MSG(
E_ABORT, m_sessionTerminatingEvent.is_signaled(), "Refusing to make a COM callback while the session is terminating.");

auto [_, inserted] = m_userCOMCallbackThreads.insert(GetCurrentThreadId());
WI_ASSERT(inserted);

return UserCOMCallback{*this};
Comment thread
OneBlue marked this conversation as resolved.
}

void WSLCSession::UnregisterUserCOMCallback(DWORD ThreadId)
{
std::lock_guard lock(m_userCOMCallbacksLock);

auto it = std::ranges::find(m_userCOMCallbackThreads, ThreadId);
WI_ASSERT(it != m_userCOMCallbackThreads.end());
Comment thread
OneBlue marked this conversation as resolved.
Outdated

m_userCOMCallbackThreads.erase(it);
Comment thread
OneBlue marked this conversation as resolved.
}
Comment thread
OneBlue marked this conversation as resolved.

void WSLCSession::CancelUserCOMCallbacks()
{
for (auto threadId : m_userCOMCallbackThreads)
{
Comment thread
OneBlue marked this conversation as resolved.
LOG_IF_FAILED(CoCancelCall(threadId, 0));
}
}

void WSLCSession::OnContainerDeleted(const WSLCContainerImpl* Container)
{
auto lock = m_lock.lock_shared();
Expand Down
27 changes: 27 additions & 0 deletions src/windows/wslcsession/WSLCSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ class UserHandle
HANDLE m_handle{};
};

class UserCOMCallback
{
NON_COPYABLE(UserCOMCallback);

public:
UserCOMCallback(WSLCSession& Session) noexcept;
UserCOMCallback(UserCOMCallback&& Other) noexcept;

~UserCOMCallback() noexcept;

UserCOMCallback& operator=(UserCOMCallback&& Other) noexcept;
void Reset() noexcept;

private:
WSLCSession* m_session{};
DWORD m_threadId{};
};
Comment on lines +51 to +67
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UserCOMCallback references WSLCSession before WSLCSession is declared in this header. This is a compile error unless there is an earlier forward declaration not shown; add class WSLCSession; before UserCOMCallback (or move UserCOMCallback below the WSLCSession declaration).

Copilot uses AI. Check for mistakes.

//
// WSLCSession - Implements IWSLCSession for container management.
// Runs in a per-user COM server process for security isolation.
Expand Down Expand Up @@ -125,10 +143,14 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession
UserHandle OpenUserHandle(WSLCHandle Handle);
void ReleaseUserHandle(HANDLE Handle);

UserCOMCallback RegisterUserCOMCallback();
void UnregisterUserCOMCallback(DWORD ThreadId);

private:
ULONG m_id = 0;

__requires_lock_held(m_userHandlesLock) void CancelUserHandleIO();
__requires_lock_held(m_userCOMCallbacksLock) void CancelUserCOMCallbacks();
void ConfigureStorage(const WSLCSessionInitSettings& Settings, PSID UserSid);
void Ext4Format(const std::string& Device);
void OnContainerDeleted(const WSLCContainerImpl* Container);
Expand Down Expand Up @@ -169,6 +191,11 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLCSession
std::mutex m_userHandlesLock;
__guarded_by(m_userHandlesLock) std::vector<HANDLE> m_userHandles;

// Threads currently inside an outgoing COM callback (e.g. IProgressCallback::OnProgress).
// Used by Terminate() to cancel stuck cross-process calls via CoCancelCall().
std::mutex m_userCOMCallbacksLock;
__guarded_by(m_userCOMCallbacksLock) std::set<DWORD> m_userCOMCallbackThreads;

// Used for testing only.
std::mutex m_allocatedPortsLock;
__guarded_by(m_allocatedPortsLock) std::map<uint16_t, std::pair<std::shared_ptr<VmPortAllocation>, size_t>> m_allocatedPorts;
Expand Down
82 changes: 82 additions & 0 deletions test/windows/WSLCTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2497,6 +2497,88 @@ class WSLCTests
VERIFY_ARE_NOT_EQUAL(details, L"");
}

WSLC_TEST_METHOD(BuildImageStuckCallbackCancellation)
{
class StuckBuildProgressCallback
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IProgressCallback>
{
public:
StuckBuildProgressCallback(std::promise<void>& reachedPromise, wil::unique_event& exitEvent) :
m_reachedPromise(reachedPromise), m_exitEvent(exitEvent)
{
}

HRESULT OnProgress(LPCSTR, LPCSTR, ULONGLONG, ULONGLONG) override
{
if (!m_signaled)
{
m_signaled = true;
m_reachedPromise.set_value();
m_exitEvent.wait(); // Block until this test case is complete.
}

return S_OK;
}

private:
std::promise<void>& m_reachedPromise;
wil::unique_event& m_exitEvent;
bool m_signaled{};
};

auto contextDir = std::filesystem::current_path() / "build-context-stuck-callback";
std::filesystem::create_directories(contextDir);
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
std::error_code ec;
std::filesystem::remove_all(contextDir, ec);
});

{
std::ofstream dockerfile(contextDir / "Dockerfile");
dockerfile << "FROM debian:latest\n";
dockerfile << "RUN echo hello\n";
Comment thread
OneBlue marked this conversation as resolved.
}

auto contextPathStr = contextDir.wstring();
auto dockerfileHandle = wil::open_file((contextDir / "Dockerfile").c_str());

WSLCBuildImageOptions options{
.ContextPath = contextPathStr.c_str(),
.DockerfileHandle = ToCOMInputHandle(dockerfileHandle.get()),
.Flags = WSLCBuildImageFlagsVerbose,
};

std::promise<void> callbackReached;
wil::unique_event exitEvent{wil::EventOptions::ManualReset};
auto callback = Microsoft::WRL::Make<StuckBuildProgressCallback>(callbackReached, exitEvent);

std::promise<HRESULT> buildResult;
std::thread buildThread(
[&]() { buildResult.set_value(m_defaultSession->BuildImage(&options, callback.Get(), exitEvent.get())); });
Comment thread
OneBlue marked this conversation as resolved.

auto joinThread = wil::scope_exit([&]() {
exitEvent.SetEvent();
buildThread.join();
});

// Wait for the progress callback to be called, proving the COM call is in flight.
auto reachedFuture = callbackReached.get_future();
auto reachedStatus = reachedFuture.wait_for(std::chrono::seconds(60));
VERIFY_ARE_EQUAL(reachedStatus, std::future_status::ready);
Comment thread
OneBlue marked this conversation as resolved.

// Terminate the session while the callback is stuck.
// This should cancel the pending COM call and unblock BuildImage.
VERIFY_SUCCEEDED(m_defaultSession->Terminate());
ResetTestSession();

auto buildFuture = buildResult.get_future();
auto buildStatus = buildFuture.wait_for(std::chrono::seconds(60));
VERIFY_ARE_EQUAL(buildStatus, std::future_status::ready);
Comment thread
OneBlue marked this conversation as resolved.

// BuildImage should have failed due to COM call cancellation.
VERIFY_FAILED(buildFuture.get());
}

WSLC_TEST_METHOD(InteractiveShell)
{
WSLCProcessLauncher launcher("/bin/sh", {"/bin/sh"}, {"TERM=xterm-256color"}, WSLCProcessFlagsTty | WSLCProcessFlagsStdin);
Expand Down
Loading