diff --git a/FastGithub.DomainResolve/DnsClient.cs b/FastGithub.DomainResolve/DnsClient.cs index e96bc88..c552769 100644 --- a/FastGithub.DomainResolve/DnsClient.cs +++ b/FastGithub.DomainResolve/DnsClient.cs @@ -1,4 +1,5 @@ -using DNS.Client; +using AsyncKeyedLock; +using DNS.Client; using DNS.Client.RequestResolver; using DNS.Protocol; using DNS.Protocol.ResourceRecords; @@ -7,7 +8,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Net; @@ -30,7 +30,7 @@ sealed class DnsClient private readonly FastGithubConfig fastGithubConfig; private readonly ILogger logger; - private readonly ConcurrentDictionary semaphoreSlims = new(); + private readonly AsyncKeyedLocker locks = new(); private readonly IMemoryCache dnsStateCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); private readonly IMemoryCache dnsLookupCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); @@ -123,9 +123,8 @@ private async ValueTask IsDnsAvailableAsync(IPEndPoint dns, CancellationTo } var key = dns.ToString(); - var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1)); - await semaphore.WaitAsync(CancellationToken.None); + using var _ = await locks.LockAsync(key); try { using var timeoutTokenSource = new CancellationTokenSource(tcpConnectTimeout); @@ -139,10 +138,6 @@ private async ValueTask IsDnsAvailableAsync(IPEndPoint dns, CancellationTo cancellationToken.ThrowIfCancellationRequested(); return this.dnsStateCache.Set(dns, false, this.stateExpiration); } - finally - { - semaphore.Release(); - } } /// @@ -156,9 +151,7 @@ private async ValueTask IsDnsAvailableAsync(IPEndPoint dns, CancellationTo private async Task> LookupAsync(IPEndPoint dns, DnsEndPoint endPoint, bool fastSort, CancellationToken cancellationToken = default) { var key = $"{dns}/{endPoint}"; - var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1)); - await semaphore.WaitAsync(CancellationToken.None); - + using var _ = await locks.LockAsync(key); try { if (this.dnsLookupCache.TryGetValue>(key, out var value)) @@ -178,10 +171,6 @@ private async Task> LookupAsync(IPEndPoint dns, DnsEndPoint end var expiration = IsSocketException(ex) ? this.maxTimeToLive : this.minTimeToLive; return this.dnsLookupCache.Set(key, Array.Empty(), expiration); } - finally - { - semaphore.Release(); - } } /// diff --git a/FastGithub.DomainResolve/FastGithub.DomainResolve.csproj b/FastGithub.DomainResolve/FastGithub.DomainResolve.csproj index 7ae4c7a..54cfbe6 100644 --- a/FastGithub.DomainResolve/FastGithub.DomainResolve.csproj +++ b/FastGithub.DomainResolve/FastGithub.DomainResolve.csproj @@ -1,6 +1,7 @@  + diff --git a/FastGithub.DomainResolve/PersistenceService.cs b/FastGithub.DomainResolve/PersistenceService.cs index 7f844fc..c70380d 100644 --- a/FastGithub.DomainResolve/PersistenceService.cs +++ b/FastGithub.DomainResolve/PersistenceService.cs @@ -1,4 +1,5 @@ -using FastGithub.Configuration; +using AsyncKeyedLock; +using FastGithub.Configuration; using Microsoft.Extensions.Logging; using System; using System.Collections.Generic; @@ -18,7 +19,7 @@ namespace FastGithub.DomainResolve sealed partial class PersistenceService { private static readonly string dataFile = "dnsendpoints.json"; - private static readonly SemaphoreSlim dataLocker = new(1, 1); + private static readonly AsyncNonKeyedLocker dataLocker = new(1); private readonly FastGithubConfig fastGithubConfig; private readonly ILogger logger; @@ -60,10 +61,9 @@ public IList ReadDnsEndPoints() return Array.Empty(); } + using var _ = dataLocker.Lock(); try { - dataLocker.Wait(); - var utf8Json = File.ReadAllBytes(dataFile); var endPointItems = JsonSerializer.Deserialize(utf8Json, EndPointItemsContext.Default.EndPointItemArray); if (endPointItems == null) @@ -86,10 +86,6 @@ public IList ReadDnsEndPoints() this.logger.LogWarning(ex.Message, "读取dns记录异常"); return Array.Empty(); } - finally - { - dataLocker.Release(); - } } /// @@ -100,10 +96,9 @@ public IList ReadDnsEndPoints() /// public async Task WriteDnsEndPointsAsync(IEnumerable dnsEndPoints, CancellationToken cancellationToken) { + using var _ = await dataLocker.LockAsync(); try { - await dataLocker.WaitAsync(CancellationToken.None); - var endPointItems = dnsEndPoints.Select(item => new EndPointItem(item.Host, item.Port)).ToArray(); var utf8Json = JsonSerializer.SerializeToUtf8Bytes(endPointItems, EndPointItemsContext.Default.EndPointItemArray); await File.WriteAllBytesAsync(dataFile, utf8Json, cancellationToken); @@ -112,10 +107,6 @@ public async Task WriteDnsEndPointsAsync(IEnumerable dnsEndPoints, { this.logger.LogWarning(ex.Message, "保存dns记录异常"); } - finally - { - dataLocker.Release(); - } } } }