diff --git a/Common/Configuration/ConfigContainer.cs b/Common/Configuration/ConfigContainer.cs index 5472e47..2ed2718 100644 --- a/Common/Configuration/ConfigContainer.cs +++ b/Common/Configuration/ConfigContainer.cs @@ -6,6 +6,7 @@ public class ConfigContainer public GameServerConfig GameServer { get; set; } = new(); public PathConfig Path { get; set; } = new(); public ServerOption ServerOption { get; set; } = new(); + public ProxyOptions Proxy { get; set; } = new(); } public class HttpServerConfig @@ -73,4 +74,15 @@ public class ServerProfile { public string Name { get; set; } = "Miku-chan"; public int Uid { get; set; } = 80; +} + +public class ProxyOptions +{ + public bool Enabled { get; set; } = true; + public int Port { get; set; } = 8888; + public int ServerHttpPort { get; set; } = 8080; + public bool InstallRootCertificate { get; set; } = true; + public bool ManageSystemProxy { get; set; } = true; + public bool RestoreSystemProxyOnStop { get; set; } = true; + public string ProxyOverride { get; set; } = "localhost;127.*;10.*;192.168.*;"; } \ No newline at end of file diff --git a/MikuSB.sln b/MikuSB.sln index 43ce1f1..556d819 100644 --- a/MikuSB.sln +++ b/MikuSB.sln @@ -20,6 +20,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Proto", "Proto\Proto.csproj EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TcpSharp", "TcpSharp\TcpSharp.csproj", "{CD7EFAA3-C655-40EE-8F6A-A8E2DA3B0FCB}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Proxy", "Proxy\Proxy.csproj", "{B3C4D5E6-F7A8-9012-BCDE-F12345678901}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -50,6 +52,10 @@ Global {CD7EFAA3-C655-40EE-8F6A-A8E2DA3B0FCB}.Debug|Any CPU.Build.0 = Debug|Any CPU {CD7EFAA3-C655-40EE-8F6A-A8E2DA3B0FCB}.Release|Any CPU.ActiveCfg = Release|Any CPU {CD7EFAA3-C655-40EE-8F6A-A8E2DA3B0FCB}.Release|Any CPU.Build.0 = Release|Any CPU + {B3C4D5E6-F7A8-9012-BCDE-F12345678901}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B3C4D5E6-F7A8-9012-BCDE-F12345678901}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B3C4D5E6-F7A8-9012-BCDE-F12345678901}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B3C4D5E6-F7A8-9012-BCDE-F12345678901}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Proxy/Proxy.csproj b/Proxy/Proxy.csproj new file mode 100644 index 0000000..c33e878 --- /dev/null +++ b/Proxy/Proxy.csproj @@ -0,0 +1,18 @@ + + + net9.0 + enable + enable + MikuSB.Proxy + $(NoWarn);CA1416 + + + + + + + + + + + diff --git a/Proxy/ProxyCertificateAuthority.cs b/Proxy/ProxyCertificateAuthority.cs new file mode 100644 index 0000000..99cf144 --- /dev/null +++ b/Proxy/ProxyCertificateAuthority.cs @@ -0,0 +1,124 @@ +using System.Collections.Concurrent; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using MikuSB.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace MikuSB.Proxy; + +public sealed class ProxyCertificateAuthority +{ + private const string Password = "MikuSB.Proxy.LocalCA"; + private readonly ConcurrentDictionary _serverCertificates = new(StringComparer.OrdinalIgnoreCase); + private readonly ILogger _logger; + private readonly ProxyOptions _options; + private readonly X509Certificate2 _rootCertificate; + + public ProxyCertificateAuthority(IOptions options, ILogger logger) + { + _options = options.Value; + _logger = logger; + _rootCertificate = LoadOrCreateRootCertificate(); + + if (_options.InstallRootCertificate) + InstallRootCertificate(); + else + _logger.LogWarning( + "MikuSB proxy root certificate is not installed automatically. Import {CertificatePath} into CurrentUser Root to enable HTTPS interception.", + RootCerPath); + } + + public string RootCerPath => Path.Combine(CertificateDirectory, "MikuSB.Proxy.Root.cer"); + + private static string CertificateDirectory => Path.Combine(AppContext.BaseDirectory, "proxy-certs"); + + public X509Certificate2 GetServerCertificate(string host) + { + host = host.Trim().TrimEnd('.').ToLowerInvariant(); + return _serverCertificates.GetOrAdd(host, CreateServerCertificate); + } + + private X509Certificate2 LoadOrCreateRootCertificate() + { + Directory.CreateDirectory(CertificateDirectory); + var pfxPath = Path.Combine(CertificateDirectory, "MikuSB.Proxy.Root.pfx"); + + if (File.Exists(pfxPath)) + { + var existing = new X509Certificate2( + File.ReadAllBytes(pfxPath), + Password, + X509KeyStorageFlags.Exportable | X509KeyStorageFlags.UserKeySet); + + if (!File.Exists(RootCerPath)) + File.WriteAllBytes(RootCerPath, existing.Export(X509ContentType.Cert)); + + return existing; + } + + using var rsa = RSA.Create(4096); + var request = new CertificateRequest( + "CN=MikuSB Proxy Root CA", + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + request.CertificateExtensions.Add(new X509BasicConstraintsExtension(true, false, 0, true)); + request.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.KeyCertSign | X509KeyUsageFlags.CrlSign | X509KeyUsageFlags.DigitalSignature, true)); + request.CertificateExtensions.Add(new X509SubjectKeyIdentifierExtension(request.PublicKey, false)); + + var root = request.CreateSelfSigned(DateTimeOffset.UtcNow.AddDays(-1), DateTimeOffset.UtcNow.AddYears(10)); + var exportable = new X509Certificate2( + root.Export(X509ContentType.Pfx, Password), + Password, + X509KeyStorageFlags.Exportable | X509KeyStorageFlags.UserKeySet); + + File.WriteAllBytes(pfxPath, exportable.Export(X509ContentType.Pfx, Password)); + File.WriteAllBytes(RootCerPath, exportable.Export(X509ContentType.Cert)); + _logger.LogInformation("Created MikuSB proxy root certificate at {CertificatePath}", RootCerPath); + return exportable; + } + + private X509Certificate2 CreateServerCertificate(string host) + { + using var rsa = RSA.Create(2048); + var request = new CertificateRequest($"CN={host}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + + var san = new SubjectAlternativeNameBuilder(); + if (IPAddress.TryParse(host, out var address)) + san.AddIpAddress(address); + else + san.AddDnsName(host); + + request.CertificateExtensions.Add(san.Build()); + request.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, true)); + request.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment, true)); + request.CertificateExtensions.Add(new X509EnhancedKeyUsageExtension([new Oid("1.3.6.1.5.5.7.3.1")], false)); + + var serial = RandomNumberGenerator.GetBytes(16); + using var certificate = request.Create( + _rootCertificate, + DateTimeOffset.UtcNow.AddDays(-1), + DateTimeOffset.UtcNow.AddYears(2), + serial); + + return new X509Certificate2( + certificate.CopyWithPrivateKey(rsa).Export(X509ContentType.Pfx), + (string?)null, + X509KeyStorageFlags.Exportable | X509KeyStorageFlags.UserKeySet); + } + + private void InstallRootCertificate() + { + using var store = new X509Store(StoreName.Root, StoreLocation.CurrentUser); + store.Open(OpenFlags.ReadWrite); + var existing = store.Certificates.Find(X509FindType.FindByThumbprint, _rootCertificate.Thumbprint, false); + if (existing.Count > 0) + return; + + store.Add(_rootCertificate); + _logger.LogWarning("Installed MikuSB proxy root certificate into CurrentUser Root store. Thumbprint={Thumbprint}", _rootCertificate.Thumbprint); + } +} diff --git a/Proxy/ProxyServer.cs b/Proxy/ProxyServer.cs new file mode 100644 index 0000000..e27f995 --- /dev/null +++ b/Proxy/ProxyServer.cs @@ -0,0 +1,438 @@ +using System.Buffers; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; +using System.Text; +using MikuSB.Configuration; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace MikuSB.Proxy; + +public sealed class ProxyServer( + IOptions options, + ProxyCertificateAuthority certificateAuthority, + HttpClient httpClient, + ILogger logger) : BackgroundService +{ + private const string ListenAddress = "127.0.0.1"; + private const string ServerHost = "127.0.0.1"; + private static readonly string[] TargetDomains = + [ + "amazingseasuncdn.com", + "amazingseasun.com", + "seasungames.com", + "snowbreak-game.com", + "xoyo.games", + "yo.games", + "qcloud.com", + "xgsdk.xoyo.games", + "xqdata.xoyo.games", + "tencentcs.com" + ]; + + private static readonly HashSet HopByHopHeaders = new(StringComparer.OrdinalIgnoreCase) + { + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "TE", + "Trailer", + "Transfer-Encoding", + "Upgrade" + }; + + private readonly ProxyOptions _options = options.Value; + private TcpListener? _listener; + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + if (!_options.Enabled) + { + logger.LogInformation("MikuSB proxy is disabled"); + return; + } + + var address = IPAddress.Parse(ListenAddress); + _listener = new TcpListener(address, _options.Port); + _listener.Start(); + logger.LogInformation("MikuSB proxy listening on {Address}:{Port}", ListenAddress, _options.Port); + + try + { + while (!stoppingToken.IsCancellationRequested) + { + var client = await _listener.AcceptTcpClientAsync(stoppingToken); + _ = Task.Run(() => HandleClientAsync(client, stoppingToken), stoppingToken); + } + } + catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) + { + } + } + + public override Task StopAsync(CancellationToken cancellationToken) + { + _listener?.Stop(); + return base.StopAsync(cancellationToken); + } + + private async Task HandleClientAsync(TcpClient client, CancellationToken cancellationToken) + { + using (client) + { + try + { + await HandleClientCoreAsync(client, cancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (IOException) + { + } + catch (SocketException) + { + } + catch (AuthenticationException ex) + { + logger.LogWarning(ex, "Proxy TLS authentication failed"); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Proxy client failed"); + } + } + } + + private async Task HandleClientCoreAsync(TcpClient client, CancellationToken cancellationToken) + { + await using var stream = client.GetStream(); + var request = await ProxyHttpRequest.ReadAsync(stream, cancellationToken); + if (request is null) + return; + + if (request.Method.Equals("CONNECT", StringComparison.OrdinalIgnoreCase)) + { + var (host, port) = SplitHostPort(request.Target, 443); + if (ShouldRedirect(host)) + { + await WriteAsciiAsync(stream, "HTTP/1.1 200 Connection Established\r\nProxy-Agent: MikuSB.Proxy\r\n\r\n", cancellationToken); + using var tlsStream = new SslStream(stream, false); + await tlsStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions + { + ServerCertificate = certificateAuthority.GetServerCertificate(host), + EnabledSslProtocols = SslProtocols.Tls12 | SslProtocols.Tls13 + }, cancellationToken); + + await HandleRedirectedHttpLoopAsync(tlsStream, host, cancellationToken); + return; + } + + await TunnelAsync(stream, host, port, cancellationToken); + return; + } + + await HandlePlainHttpLoopAsync(stream, request, cancellationToken); + } + + private async Task HandlePlainHttpLoopAsync(Stream clientStream, ProxyHttpRequest request, CancellationToken cancellationToken) + { + while (true) + { + var host = request.Host; + if (string.IsNullOrWhiteSpace(host)) + { + await WriteSimpleResponseAsync(clientStream, HttpStatusCode.BadRequest, "Missing Host header", cancellationToken); + return; + } + + if (ShouldRedirect(SplitHostPort(host, 80).Host)) + await ForwardToServerAsync(clientStream, request, cancellationToken); + else + await ForwardToOriginAsync(clientStream, request, cancellationToken); + + if (request.ShouldClose) + return; + + var nextRequest = await ProxyHttpRequest.ReadAsync(clientStream, cancellationToken); + if (nextRequest is null) + return; + + request = nextRequest; + } + } + + private async Task HandleRedirectedHttpLoopAsync(Stream tlsStream, string originalHost, CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + var request = await ProxyHttpRequest.ReadAsync(tlsStream, cancellationToken); + if (request is null) + return; + + request.HostOverride = originalHost; + await ForwardToServerAsync(tlsStream, request, cancellationToken); + + if (request.ShouldClose) + return; + } + } + + private async Task ForwardToServerAsync(Stream clientStream, ProxyHttpRequest request, CancellationToken cancellationToken) + { + var pathAndQuery = request.GetPathAndQuery(); + var uri = new Uri($"http://{ServerHost}:{_options.ServerHttpPort}{pathAndQuery}"); + logger.LogInformation("[Proxy] Redirect: {Method} {Host}{Path} -> {Uri}", request.Method, request.HostOverride ?? request.Host, pathAndQuery, uri); + await SendHttpRequestAsync(clientStream, request, uri, true, cancellationToken); + } + + private async Task ForwardToOriginAsync(Stream clientStream, ProxyHttpRequest request, CancellationToken cancellationToken) + { + var uri = request.GetAbsoluteUri(); + if (uri is null) + { + await WriteSimpleResponseAsync(clientStream, HttpStatusCode.BadRequest, "Only absolute-form proxy requests are supported for non-target HTTP", cancellationToken); + return; + } + + if (IsSelfReference(uri)) + { + logger.LogWarning("[Proxy] Self-reference blocked: {Method} {Uri}", request.Method, uri); + await WriteSimpleResponseAsync(clientStream, HttpStatusCode.LoopDetected, "Proxy self-reference detected", cancellationToken); + return; + } + + await SendHttpRequestAsync(clientStream, request, uri, false, cancellationToken); + } + + private bool IsSelfReference(Uri uri) + { + if (uri.Port != _options.Port) + return false; + + return uri.Host is "127.0.0.1" or "localhost" or "::1" + || uri.Host.Equals(ListenAddress, StringComparison.OrdinalIgnoreCase); + } + + private async Task SendHttpRequestAsync(Stream clientStream, ProxyHttpRequest request, Uri uri, bool addCors, CancellationToken cancellationToken) + { + using var outgoing = new HttpRequestMessage(new HttpMethod(request.Method), uri); + if (request.Body.Length > 0) + outgoing.Content = new ByteArrayContent(request.Body); + + foreach (var (name, value) in request.Headers) + { + if (HopByHopHeaders.Contains(name) || name.Equals("Host", StringComparison.OrdinalIgnoreCase)) + continue; + + if (!outgoing.Headers.TryAddWithoutValidation(name, value)) + { + outgoing.Content ??= new ByteArrayContent(request.Body); + outgoing.Content.Headers.TryAddWithoutValidation(name, value); + } + } + + using var response = await httpClient.SendAsync(outgoing, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + var body = await response.Content.ReadAsByteArrayAsync(cancellationToken); + + var builder = new StringBuilder(); + builder.Append("HTTP/1.1 ") + .Append((int)response.StatusCode) + .Append(' ') + .Append(response.ReasonPhrase ?? response.StatusCode.ToString()) + .Append("\r\n"); + + foreach (var header in response.Headers) + AppendHeader(builder, header.Key, header.Value); + + foreach (var header in response.Content.Headers) + { + if (!header.Key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase)) + AppendHeader(builder, header.Key, header.Value); + } + + if (addCors) + builder.Append("Access-Control-Allow-Origin: *\r\n"); + + builder.Append("Content-Length: ").Append(body.Length).Append("\r\n"); + builder.Append("Connection: keep-alive\r\n\r\n"); + + await WriteAsciiAsync(clientStream, builder.ToString(), cancellationToken); + if (body.Length > 0) + await clientStream.WriteAsync(body, cancellationToken); + } + + private async Task TunnelAsync(Stream clientStream, string host, int port, CancellationToken cancellationToken) + { + using var upstream = new TcpClient(); + await upstream.ConnectAsync(host, port, cancellationToken); + await WriteAsciiAsync(clientStream, "HTTP/1.1 200 Connection Established\r\nProxy-Agent: MikuSB.Proxy\r\n\r\n", cancellationToken); + + await using var upstreamStream = upstream.GetStream(); + var clientToServer = clientStream.CopyToAsync(upstreamStream, cancellationToken); + var serverToClient = upstreamStream.CopyToAsync(clientStream, cancellationToken); + await Task.WhenAny(clientToServer, serverToClient); + } + + private bool ShouldRedirect(string host) + { + host = host.Trim().TrimEnd('.').ToLowerInvariant(); + foreach (var target in TargetDomains) + { + var normalized = target.Trim().TrimEnd('.').ToLowerInvariant(); + if (host == normalized || host.EndsWith($".{normalized}", StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + + private static void AppendHeader(StringBuilder builder, string name, IEnumerable values) + { + if (HopByHopHeaders.Contains(name)) + return; + + foreach (var value in values) + builder.Append(name).Append(": ").Append(value).Append("\r\n"); + } + + private static async Task WriteSimpleResponseAsync(Stream stream, HttpStatusCode statusCode, string message, CancellationToken cancellationToken) + { + var body = Encoding.UTF8.GetBytes(message); + await WriteAsciiAsync( + stream, + $"HTTP/1.1 {(int)statusCode} {statusCode}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {body.Length}\r\nConnection: close\r\n\r\n", + cancellationToken); + await stream.WriteAsync(body, cancellationToken); + } + + private static Task WriteAsciiAsync(Stream stream, string value, CancellationToken cancellationToken) => + stream.WriteAsync(Encoding.ASCII.GetBytes(value), cancellationToken).AsTask(); + + private static (string Host, int Port) SplitHostPort(string hostPort, int defaultPort) + { + if (hostPort.StartsWith('[')) + { + var end = hostPort.IndexOf(']'); + if (end > 0 && hostPort.Length > end + 2 && hostPort[end + 1] == ':' && int.TryParse(hostPort[(end + 2)..], out var ipv6Port)) + return (hostPort[1..end], ipv6Port); + + return (hostPort.Trim('[', ']'), defaultPort); + } + + var colon = hostPort.LastIndexOf(':'); + if (colon > 0 && int.TryParse(hostPort[(colon + 1)..], out var port)) + return (hostPort[..colon], port); + + return (hostPort, defaultPort); + } + + private sealed class ProxyHttpRequest + { + public required string Method { get; init; } + public required string Target { get; init; } + public required string Version { get; init; } + public required List> Headers { get; init; } + public required byte[] Body { get; init; } + public string? HostOverride { get; set; } + + public string? Host => HostOverride ?? Headers.FirstOrDefault(x => x.Key.Equals("Host", StringComparison.OrdinalIgnoreCase)).Value; + + public bool ShouldClose => + Headers.Any(x => x.Key.Equals("Connection", StringComparison.OrdinalIgnoreCase) + && x.Value.Contains("close", StringComparison.OrdinalIgnoreCase)); + + public Uri? GetAbsoluteUri() => Uri.TryCreate(Target, UriKind.Absolute, out var uri) ? uri : null; + + public string GetPathAndQuery() + { + if (Uri.TryCreate(Target, UriKind.Absolute, out var uri)) + return uri.PathAndQuery; + + if (string.IsNullOrEmpty(Target)) + return "/"; + + return Target[0] == '/' ? Target : $"/{Target}"; + } + + public static async Task ReadAsync(Stream stream, CancellationToken cancellationToken) + { + var rented = ArrayPool.Shared.Rent(64 * 1024); + try + { + var length = 0; + while (true) + { + var read = await stream.ReadAsync(rented.AsMemory(length, 1), cancellationToken); + if (read == 0) + return null; + + length += read; + if (length >= 4 + && rented[length - 4] == '\r' + && rented[length - 3] == '\n' + && rented[length - 2] == '\r' + && rented[length - 1] == '\n') + break; + + if (length == rented.Length) + throw new InvalidDataException("HTTP proxy request header is too large"); + } + + var headerText = Encoding.ASCII.GetString(rented, 0, length); + var lines = headerText.Split("\r\n", StringSplitOptions.None); + var requestLine = lines[0].Split(' ', 3, StringSplitOptions.RemoveEmptyEntries); + if (requestLine.Length != 3) + throw new InvalidDataException("Invalid HTTP proxy request line"); + + var headers = new List>(); + var contentLength = 0; + for (var i = 1; i < lines.Length; i++) + { + var line = lines[i]; + if (string.IsNullOrEmpty(line)) + break; + + var colon = line.IndexOf(':'); + if (colon <= 0) + continue; + + var name = line[..colon].Trim(); + var value = line[(colon + 1)..].Trim(); + headers.Add(new KeyValuePair(name, value)); + if (name.Equals("Content-Length", StringComparison.OrdinalIgnoreCase) && int.TryParse(value, out var parsedLength)) + contentLength = parsedLength; + } + + var body = new byte[contentLength]; + var offset = 0; + while (offset < body.Length) + { + var read = await stream.ReadAsync(body.AsMemory(offset), cancellationToken); + if (read == 0) + throw new EndOfStreamException("HTTP proxy request body ended early"); + + offset += read; + } + + return new ProxyHttpRequest + { + Method = requestLine[0], + Target = requestLine[1], + Version = requestLine[2], + Headers = headers, + Body = body + }; + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + } +} diff --git a/Proxy/ProxyServiceCollectionExtensions.cs b/Proxy/ProxyServiceCollectionExtensions.cs new file mode 100644 index 0000000..55d16a5 --- /dev/null +++ b/Proxy/ProxyServiceCollectionExtensions.cs @@ -0,0 +1,25 @@ +using MikuSB.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace MikuSB.Proxy; + +public static class ProxyServiceCollectionExtensions +{ + public static IServiceCollection AddMikuSbProxy(this IServiceCollection services, ProxyOptions options) + { + services.AddSingleton>(Microsoft.Extensions.Options.Options.Create(options)); + services.AddSingleton(); + services.AddSingleton(new HttpClient(new SocketsHttpHandler + { + AllowAutoRedirect = false, + AutomaticDecompression = System.Net.DecompressionMethods.None, + UseCookies = false, + UseProxy = false + })); + services.AddSingleton(); + services.AddHostedService(sp => sp.GetRequiredService()); + services.AddHostedService(); + return services; + } +} diff --git a/Proxy/WindowsSystemProxyService.cs b/Proxy/WindowsSystemProxyService.cs new file mode 100644 index 0000000..9433877 --- /dev/null +++ b/Proxy/WindowsSystemProxyService.cs @@ -0,0 +1,132 @@ +using System.Runtime.InteropServices; +using MikuSB.Configuration; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Win32; + +namespace MikuSB.Proxy; + +public sealed class WindowsSystemProxyService( + IOptions options, + ILogger logger) : IHostedService, IDisposable +{ + private const string InternetSettingsPath = @"Software\Microsoft\Windows\CurrentVersion\Internet Settings"; + private readonly ProxyOptions _options = options.Value; + private ConsoleCtrlHandler? _consoleCtrlHandler; + private int _proxyDisabled; + + public Task StartAsync(CancellationToken cancellationToken) + { + if (!_options.Enabled || !_options.ManageSystemProxy) + return Task.CompletedTask; + + if (!OperatingSystem.IsWindows()) + { + logger.LogWarning("System proxy management is only supported on Windows"); + return Task.CompletedTask; + } + + using var key = Registry.CurrentUser.OpenSubKey(InternetSettingsPath, writable: true); + if (key is null) + { + logger.LogWarning("Unable to open Windows Internet Settings registry key"); + return Task.CompletedTask; + } + + var proxyServer = $"http=127.0.0.1:{_options.Port};https=127.0.0.1:{_options.Port}"; + + key.SetValue("ProxyEnable", 1, RegistryValueKind.DWord); + key.SetValue("ProxyServer", proxyServer, RegistryValueKind.String); + key.SetValue("ProxyOverride", _options.ProxyOverride, RegistryValueKind.String); + NotifyProxySettingsChanged(); + AppDomain.CurrentDomain.ProcessExit += OnProcessExit; + RegisterConsoleCtrlHandler(); + + logger.LogWarning("Windows system proxy enabled for MikuSB: {ProxyServer}", proxyServer); + return Task.CompletedTask; + } + + public Task StopAsync(CancellationToken cancellationToken) + { + if (!_options.Enabled || !_options.ManageSystemProxy || !_options.RestoreSystemProxyOnStop) + return Task.CompletedTask; + + DisableSystemProxy(); + return Task.CompletedTask; + } + + public void Dispose() + { + AppDomain.CurrentDomain.ProcessExit -= OnProcessExit; + UnregisterConsoleCtrlHandler(); + + if (_options.Enabled && _options.ManageSystemProxy && _options.RestoreSystemProxyOnStop) + DisableSystemProxy(); + } + + private void OnProcessExit(object? sender, EventArgs e) + { + if (_options.Enabled && _options.ManageSystemProxy && _options.RestoreSystemProxyOnStop) + DisableSystemProxy(); + } + + private void DisableSystemProxy() + { + if (!OperatingSystem.IsWindows()) + return; + + if (Interlocked.Exchange(ref _proxyDisabled, 1) == 1) + return; + + using var key = Registry.CurrentUser.OpenSubKey(InternetSettingsPath, writable: true); + if (key is null) + return; + + key.SetValue("ProxyEnable", 0, RegistryValueKind.DWord); + key.DeleteValue("ProxyServer", throwOnMissingValue: false); + key.DeleteValue("ProxyOverride", throwOnMissingValue: false); + NotifyProxySettingsChanged(); + logger.LogWarning("Windows system proxy disabled for MikuSB shutdown"); + } + + private void RegisterConsoleCtrlHandler() + { + if (!OperatingSystem.IsWindows()) + return; + + _consoleCtrlHandler = OnConsoleCtrl; + SetConsoleCtrlHandler(_consoleCtrlHandler, add: true); + } + + private void UnregisterConsoleCtrlHandler() + { + if (!OperatingSystem.IsWindows() || _consoleCtrlHandler is null) + return; + + SetConsoleCtrlHandler(_consoleCtrlHandler, add: false); + _consoleCtrlHandler = null; + } + + private bool OnConsoleCtrl(int signal) + { + if (_options.Enabled && _options.ManageSystemProxy && _options.RestoreSystemProxyOnStop) + DisableSystemProxy(); + + return false; + } + + private static void NotifyProxySettingsChanged() + { + InternetSetOption(IntPtr.Zero, 39, IntPtr.Zero, 0); + InternetSetOption(IntPtr.Zero, 37, IntPtr.Zero, 0); + } + + private delegate bool ConsoleCtrlHandler(int signal); + + [DllImport("kernel32.dll", SetLastError = true)] + private static extern bool SetConsoleCtrlHandler(ConsoleCtrlHandler handler, bool add); + + [DllImport("wininet.dll", SetLastError = true)] + private static extern bool InternetSetOption(IntPtr internet, int option, IntPtr buffer, int bufferLength); +} diff --git a/SdkServer/SdkServer.cs b/SdkServer/SdkServer.cs index 68c74cc..6955abe 100644 --- a/SdkServer/SdkServer.cs +++ b/SdkServer/SdkServer.cs @@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using MikuSB.Proxy; using MikuSB.SdkServer.Handlers; using MikuSB.SdkServer.Utils; using MikuSB.Util; @@ -91,5 +92,6 @@ public class Startup options.JsonSerializerOptions.PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower; }); services.AddSingleton(_ => new Logger("HttpServer")); + services.AddMikuSbProxy(ConfigManager.Config.Proxy); } } diff --git a/SdkServer/SdkServer.csproj b/SdkServer/SdkServer.csproj index b69edd1..c968bc8 100644 --- a/SdkServer/SdkServer.csproj +++ b/SdkServer/SdkServer.csproj @@ -13,6 +13,7 @@ +