-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathWSLCSessionManager.cpp
More file actions
317 lines (244 loc) · 12.2 KB
/
WSLCSessionManager.cpp
File metadata and controls
317 lines (244 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLCSessionManager.cpp
Abstract:
Implementation for WSLCSessionManager.
Sessions run in a per-user COM server process for security isolation.
The SYSTEM service creates sessions via IWSLCSessionFactory which returns
both the session interface (for clients) and an IWSLCSessionReference
(for the service to track sessions via weak references).
Session lifetime:
- Non-persistent sessions: tracked via IWSLCSessionReference which holds
weak references. Sessions are cleaned up when all client refs are released.
- Persistent sessions: the service holds an additional strong IWSLCSession
reference to keep them alive until explicitly terminated.
A job object with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE ensures that all
per-user COM server processes are automatically terminated if wslservice
crashes or exits unexpectedly.
--*/
#include "WSLCSessionManager.h"
#include "HcsVirtualMachine.h"
#include "wslutil.h"
using wsl::windows::service::wslc::CallingProcessTokenInfo;
using wsl::windows::service::wslc::HcsVirtualMachine;
using wsl::windows::service::wslc::WSLCSessionManagerImpl;
namespace wslutil = wsl::windows::common::wslutil;
WSLCSessionManagerImpl::~WSLCSessionManagerImpl()
{
    // Shutdown: ask every tracked session to terminate.
    // Each entry's service reference is used directly instead of going through ForEachSession(),
    // which would resolve the weak reference and query GetState() for nothing; per the original
    // author, Terminate() already copes with a session that has gone away.
    // A failure for one session is logged and does not stop the loop.
    std::lock_guard sessionsGuard(m_wslcSessionsLock);

    for (auto& trackedSession : m_sessions)
    {
        LOG_IF_FAILED(trackedSession.Ref->Terminate());
    }
}
// Creates a new session, or returns an existing one with the same display name.
//
// Settings    - session configuration. Must be non-null. Settings->DisplayName must be non-null
//               and shorter than WSLCSessionInformation::DisplayName (room is kept for the terminator).
// Flags       - WSLCSessionFlagsOpenExisting: return a matching existing session instead of failing
//               with ERROR_ALREADY_EXISTS.
//               WSLCSessionFlagsPersistent: the service also keeps a strong reference to the new
//               session in m_persistentSessions.
// WslcSession - receives the session interface on success.
//
// Throws:
// - E_INVALIDARG for missing settings or a missing/oversized display name.
// - ERROR_ALREADY_EXISTS for a name clash without OpenExisting.
// - The CheckTokenAccess() failure when the caller may not open the existing session.
// - Otherwise, whatever failure occurred while creating the session.
void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, IWSLCSession** WslcSession)
{
    // Ensure that the settings are present and that the session display name is non-null and not too long.
    THROW_HR_IF(E_INVALIDARG, Settings == nullptr);
    THROW_HR_IF(E_INVALIDARG, Settings->DisplayName == nullptr);
    THROW_HR_IF(E_INVALIDARG, wcslen(Settings->DisplayName) >= std::size(WSLCSessionInformation{}.DisplayName));

    auto tokenInfo = GetCallingProcessTokenInfo();

    // tokenInfo is moved into m_sessions when a new session is created. Capture the elevation
    // status used by the telemetry event now, so it is not read from a moved-from object later.
    const auto callerElevated = tokenInfo.Elevated;

    std::lock_guard lock(m_wslcSessionsLock);

    // Check for an existing session first.
    auto result = ForEachSession<HRESULT>([&](auto& entry, const wil::com_ptr<IWSLCSession>& session) noexcept -> std::optional<HRESULT> {
        if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), Settings->DisplayName))
        {
            return {};
        }

        // A session with this name exists. Reuse it only if the caller asked for that and the
        // caller's token is allowed to access it.
        RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), WI_IsFlagClear(Flags, WSLCSessionFlagsOpenExisting));
        RETURN_IF_FAILED(CheckTokenAccess(entry, tokenInfo));
        RETURN_IF_FAILED(wil::com_copy_to_nothrow(session, WslcSession));
        return S_OK;
    });

    if (result.has_value())
    {
        THROW_IF_FAILED(result.value());
        return; // Existing session was opened.
    }

    wslutil::StopWatch stopWatch;
    HRESULT creationResult = wil::ResultFromException([&]() {
        // Get caller info.
        const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION);
        const ULONG sessionId = m_nextSessionId++;
        const DWORD creatorPid = GetProcessId(callerProcess.get());
        const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);

        // Create the VM in the SYSTEM service (privileged).
        auto vm = Microsoft::WRL::Make<HcsVirtualMachine>(Settings);

        // Launch per-user COM server factory and add it to our job object for crash cleanup.
        auto factory = wslutil::CreateComServerAsUser<IWSLCSessionFactory>(__uuidof(WSLCSessionFactory), userToken.get());
        AddSessionProcessToJobObject(factory.get());

        // Create the session via the factory.
        const auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings);
        wil::com_ptr<IWSLCSession> session;
        wil::com_ptr<IWSLCSessionReference> serviceRef;
        THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), &session, &serviceRef));

        // Track the session via its service ref, along with metadata and security info.
        m_sessions.push_back({std::move(serviceRef), sessionId, creatorPid, Settings->DisplayName, std::move(tokenInfo)});

        // For persistent sessions, also hold a strong reference to keep them alive.
        const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent);
        if (persistent)
        {
            m_persistentSessions.emplace_back(sessionId, session);
        }

        *WslcSession = session.detach();
    });

    // This telemetry event is used to keep track of session creation performance (via CreationTimeMs) and failure reasons (via Result).
    WSL_LOG_TELEMETRY(
        "WSLCCreateSession",
        PDT_ProductAndServiceUsage,
        TraceLoggingValue(Settings->DisplayName, "Name"),
        TraceLoggingValue(stopWatch.ElapsedMilliseconds(), "CreationTimeMs"),
        TraceLoggingValue(creationResult, "Result"),
        TraceLoggingValue(callerElevated, "Elevated"),
        TraceLoggingValue(static_cast<uint32_t>(Flags), "Flags"));

    THROW_IF_FAILED_MSG(creationResult, "Failed to create session: %ls", Settings->DisplayName);
}
// Opens the tracked session whose SessionId equals Id.
// The calling process's token must pass CheckTokenAccess() for that session; on success the
// session interface is copied into *Session.
// Throws the access-check / copy failure, or ERROR_NOT_FOUND when no session callback produced a result.
void WSLCSessionManagerImpl::OpenSession(ULONG Id, IWSLCSession** Session)
{
    const auto callerToken = GetCallingProcessTokenInfo();

    const auto lookup = ForEachSession<HRESULT>([&](auto& entry, const wil::com_ptr<IWSLCSession>& session) noexcept -> std::optional<HRESULT> {
        if (entry.SessionId == Id)
        {
            RETURN_IF_FAILED(CheckTokenAccess(entry, callerToken));
            RETURN_IF_FAILED(wil::com_copy_to_nothrow(session, Session));
            return S_OK;
        }

        // Different session: produce no result so the search continues.
        return std::nullopt;
    });

    THROW_IF_FAILED_MSG(lookup.value_or(HRESULT_FROM_WIN32(ERROR_NOT_FOUND)), "Session '%lu' not found", Id);
}
// Opens the tracked session whose display name matches DisplayName
// (compared with wsl::shared::string::IsEqual).
// The calling process's token must pass CheckTokenAccess() for that session; on success the
// session interface is copied into *Session.
// Throws:
// - E_INVALIDARG if DisplayName is null.
// - The access-check / copy failure.
// - ERROR_NOT_FOUND when no session matched.
void WSLCSessionManagerImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLCSession** Session)
{
    // CreateSession() rejects a null display name. Apply the same rule here so a null pointer
    // never reaches the name comparison or the "%ls" argument of the not-found message below.
    THROW_HR_IF(E_INVALIDARG, DisplayName == nullptr);

    auto tokenInfo = GetCallingProcessTokenInfo();
    auto result = ForEachSession<HRESULT>([&](auto& entry, const wil::com_ptr<IWSLCSession>& session) noexcept -> std::optional<HRESULT> {
        if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), DisplayName))
        {
            return {};
        }

        RETURN_IF_FAILED(CheckTokenAccess(entry, tokenInfo));
        RETURN_IF_FAILED(wil::com_copy_to_nothrow(session, Session));
        return S_OK;
    });

    THROW_IF_FAILED_MSG(result.value_or(HRESULT_FROM_WIN32(ERROR_NOT_FOUND)), "Session '%ls' not found", DisplayName);
}
// Returns a snapshot of the sessions visited by ForEachSession().
//
// Sessions      - receives an array allocated with wil::make_unique_cotaskmem (CoTaskMem allocator).
//                 Ownership passes to the caller.
// SessionsCount - receives the number of entries in *Sessions (may be 0).
//
// A session whose owner SID cannot be converted to a string is skipped; the failure is logged
// via CATCH_LOG() instead of failing the whole listing.
void WSLCSessionManagerImpl::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount)
{
    std::vector<WSLCSessionInformation> sessionInfo;
    ForEachSession<void>([&](auto& entry, const auto&) noexcept {
        try
        {
            wil::unique_hlocal_string sidString;
            THROW_IF_WIN32_BOOL_FALSE(ConvertSidToStringSidW(entry.Owner.TokenInfo->User.Sid, &sidString));

            auto& it = sessionInfo.emplace_back(WSLCSessionInformation{.SessionId = entry.SessionId, .CreatorPid = entry.CreatorPid});
            wcscpy_s(it.Sid, _countof(it.Sid), sidString.get());
            wcscpy_s(it.DisplayName, _countof(it.DisplayName), entry.DisplayName.c_str());
        }
        CATCH_LOG()
    });

    auto output = wil::make_unique_cotaskmem<WSLCSessionInformation[]>(sessionInfo.size());

    // memcpy requires valid pointers even for a zero-byte copy, and data() of an empty vector
    // may be null, so only copy when there is something to copy.
    if (!sessionInfo.empty())
    {
        memcpy(output.get(), sessionInfo.data(), sessionInfo.size() * sizeof(WSLCSessionInformation));
    }

    *Sessions = output.release();
    *SessionsCount = static_cast<ULONG>(sessionInfo.size());
}
// Builds the WSLCSessionInitSettings handed to the per-user session factory.
// The service-assigned SessionId and the CreatorPid come first, followed by the client-supplied
// values copied from *Settings. Any field not assigned here stays value-initialized.
// NOTE(review): string fields are assigned straight from *Settings (DisplayName is a raw pointer
// there). If the init-settings fields are pointers too, the result does not own them and must not
// outlive Settings — confirm against the struct definition.
WSLCSessionInitSettings WSLCSessionManagerImpl::CreateSessionSettings(_In_ ULONG SessionId, _In_ DWORD CreatorPid, _In_ const WSLCSessionSettings* Settings)
{
    const WSLCSessionSettings& requested = *Settings;

    WSLCSessionInitSettings initSettings{};

    // Identity supplied by the service.
    initSettings.SessionId = SessionId;
    initSettings.CreatorPid = CreatorPid;

    // Values forwarded unchanged from the client request.
    initSettings.DisplayName = requested.DisplayName;
    initSettings.StoragePath = requested.StoragePath;
    initSettings.MaximumStorageSizeMb = requested.MaximumStorageSizeMb;
    initSettings.BootTimeoutMs = requested.BootTimeoutMs;
    initSettings.NetworkingMode = requested.NetworkingMode;
    initSettings.FeatureFlags = requested.FeatureFlags;
    initSettings.RootVhdTypeOverride = requested.RootVhdTypeOverride;

    return initSettings;
}
// Assigns the per-user COM server process behind Factory to m_sessionJobObject.
// The job object is created first if needed.
// Throws if the factory cannot supply its process handle or if AssignProcessToJobObject fails.
void WSLCSessionManagerImpl::AddSessionProcessToJobObject(_In_ IWSLCSessionFactory* Factory)
{
    // The job object must exist before anything can be assigned to it.
    EnsureJobObjectCreated();

    // The factory hands back a handle to its own process; serverProcess closes it on scope exit.
    wil::unique_handle serverProcess;
    THROW_IF_FAILED(Factory->GetProcessHandle(serverProcess.put()));

    THROW_IF_WIN32_BOOL_FALSE(AssignProcessToJobObject(m_sessionJobObject.get(), serverProcess.get()));
}
void WSLCSessionManagerImpl::EnsureJobObjectCreated()
{
// Create a job object that will automatically terminate all child processes
// when the job handle is closed (i.e., when wslservice exits or crashes).
std::call_once(m_jobObjectInitFlag, [this] {
m_sessionJobObject.reset(CreateJobObjectW(nullptr, nullptr));
THROW_LAST_ERROR_IF(!m_sessionJobObject);
JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{};
jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
THROW_IF_WIN32_BOOL_FALSE(
SetInformationJobObject(m_sessionJobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo)));
WSL_LOG("SessionManagerJobObjectCreated", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
});
}
// Captures the identity of the calling user from the user token returned by
// GetUserToken(TokenImpersonation):
// - its TOKEN_USER information (the user SID);
// - whether the token is a member of BUILTIN\Administrators, which this file treats as "Elevated".
CallingProcessTokenInfo WSLCSessionManagerImpl::GetCallingProcessTokenInfo()
{
    const wil::unique_handle callerToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);

    auto userInfo = wil::get_token_information<TOKEN_USER>(callerToken.get());
    const auto isAdministrator =
        wil::test_token_membership(callerToken.get(), SECURITY_NT_AUTHORITY, SECURITY_BUILTIN_DOMAIN_RID, DOMAIN_ALIAS_RID_ADMINS);

    return {std::move(userInfo), isAdministrator};
}
// Decides whether the caller described by TokenInfo may access the session in Entry.
// - Elevated callers may access every session.
// - Non-elevated callers are refused with E_ACCESSDENIED for sessions owned by another SID.
// - Non-elevated callers are refused with ERROR_ELEVATION_REQUIRED for sessions created by an elevated caller.
// Returns S_OK when access is allowed.
// TODO: Offer proper ACL checks.
HRESULT WSLCSessionManagerImpl::CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo)
{
    if (!TokenInfo.Elevated)
    {
        // Different account: deny.
        RETURN_HR_IF(E_ACCESSDENIED, !EqualSid(Entry.Owner.TokenInfo->User.Sid, TokenInfo.TokenInfo->User.Sid));

        // Same account, but the session belongs to an elevated caller: deny.
        RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ELEVATION_REQUIRED), Entry.Owner.Elevated);
    }

    return S_OK;
}
// COM object constructor: hands Impl to the COMImplClass base.
// The methods below forward to Impl through CallImpl().
WSLCSessionManager::WSLCSessionManager(WSLCSessionManagerImpl* Impl)
    : COMImplClass<WSLCSessionManagerImpl>(Impl)
{
    // No additional state to initialize.
}
// Reports the package version this service was built from (the WSL_PACKAGE_VERSION_* macros).
// Always succeeds.
HRESULT WSLCSessionManager::GetVersion(_Out_ WSLCVersion* Version)
{
    WSLCVersion& reported = *Version;

    reported.Major = WSL_PACKAGE_VERSION_MAJOR;
    reported.Minor = WSL_PACKAGE_VERSION_MINOR;
    reported.Revision = WSL_PACKAGE_VERSION_REVISION;

    return S_OK;
}
// Reports the oldest client version this service accepts, normally 2.8.0.
// If this package's own version (wsl::shared::PackageVersion) is below that floor, the package's
// own version is reported instead, so the minimum never exceeds the service version.
// NOTE(review): presumably this exists for development builds with low version numbers; confirm.
// Always succeeds.
HRESULT WSLCSessionManager::GetMinimumSupportedClientVersion(_Out_ WSLCVersion* Version)
{
    constexpr std::tuple<uint32_t, uint32_t, uint32_t> c_minClientVersion{2, 8, 0};

    if constexpr (!(wsl::shared::PackageVersion < c_minClientVersion))
    {
        // The package is at or above the floor: report the floor itself.
        std::tie(Version->Major, Version->Minor, Version->Revision) = c_minClientVersion;
    }
    else
    {
        // The package is older than the floor: fall back to the package's own version.
        Version->Major = WSL_PACKAGE_VERSION_MAJOR;
        Version->Minor = WSL_PACKAGE_VERSION_MINOR;
        Version->Revision = WSL_PACKAGE_VERSION_REVISION;
    }

    return S_OK;
}
// COM entry point: forwards to WSLCSessionManagerImpl::CreateSession via CallImpl().
// NOTE(review): CallImpl presumably converts exceptions thrown by the implementation into the
// returned HRESULT — confirm in COMImplClass.
HRESULT WSLCSessionManager::CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession)
{
    return CallImpl(
        &WSLCSessionManagerImpl::CreateSession,
        WslcSessionSettings,
        Flags,
        WslcSession);
}
// COM entry point: forwards to WSLCSessionManagerImpl::ListSessions via CallImpl().
HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount)
{
    return CallImpl(
        &WSLCSessionManagerImpl::ListSessions,
        Sessions,
        SessionsCount);
}
// COM entry point: forwards to WSLCSessionManagerImpl::OpenSession via CallImpl().
HRESULT WSLCSessionManager::OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session)
{
    return CallImpl(
        &WSLCSessionManagerImpl::OpenSession,
        Id,
        Session);
}
// COM entry point: forwards to WSLCSessionManagerImpl::OpenSessionByName via CallImpl().
HRESULT WSLCSessionManager::OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session)
{
    return CallImpl(
        &WSLCSessionManagerImpl::OpenSessionByName,
        DisplayName,
        Session);
}