using Microsoft.EntityFrameworkCore;
using System.Reflection;
using System.Linq;
using TakeoutSaaS.Shared.Abstractions.Entities;
using TakeoutSaaS.Shared.Abstractions.Ids;
using TakeoutSaaS.Shared.Abstractions.Security;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
using Microsoft.AspNetCore.Http;
namespace TakeoutSaaS.Infrastructure.Common.Persistence;
/// <summary>
/// Multi-tenant aware DbContext: applies a tenant query filter to multi-tenant entities and
/// stamps the current tenant ID onto newly added entities before saving.
/// </summary>
/// <param name="options">Context options, forwarded to the base context.</param>
/// <param name="tenantProvider">Resolves the tenant ID of the current request. Required.</param>
/// <param name="currentUserAccessor">Optional; forwarded to the base context.</param>
/// <param name="idGenerator">Optional; forwarded to the base context.</param>
/// <param name="httpContextAccessor">
/// Optional; used to detect platform administrators. When null (or there is no authenticated
/// user), nobody is treated as a platform administrator.
/// </param>
public abstract class TenantAwareDbContext(
    DbContextOptions options,
    ITenantProvider tenantProvider,
    ICurrentUserAccessor? currentUserAccessor = null,
    IIdGenerator? idGenerator = null,
    IHttpContextAccessor? httpContextAccessor = null) : AppDbContext(options, currentUserAccessor, idGenerator)
{
    private readonly ITenantProvider _tenantProvider = tenantProvider ?? throw new ArgumentNullException(nameof(tenantProvider));
    private readonly IHttpContextAccessor? _httpContextAccessor = httpContextAccessor;

    // Role codes identifying platform-level administrators. Several spellings are listed,
    // presumably because different parts of the system use different conventions; IsInRole
    // on the default ClaimsIdentity compares role values exactly, so each variant must appear.
    private static readonly string[] PlatformRoleCodes =
    {
        "super-admin",
        "SUPER_ADMIN",
        "PlatformAdmin",
        "platform-admin"
    };

    /// <summary>
    /// Tenant ID of the current request, read from the tenant provider on every access.
    /// </summary>
    protected long CurrentTenantId => _tenantProvider.GetCurrentTenantId();

    /// <summary>
    /// Fills tenant metadata on pending entities, then runs the base pre-save processing.
    /// </summary>
    protected override void OnBeforeSaving()
    {
        ApplyTenantMetadata();
        base.OnBeforeSaving();
    }

    /// <summary>
    /// Applies the tenant query filter to every entity type implementing <see cref="IMultiTenantEntity"/>.
    /// </summary>
    /// <param name="modelBuilder">The model builder.</param>
    protected void ApplyTenantQueryFilters(ModelBuilder modelBuilder)
    {
        // Resolve the open generic method once; only the closed generic differs per entity type.
        var setTenantFilter = typeof(TenantAwareDbContext)
            .GetMethod(nameof(SetTenantFilter), BindingFlags.Instance | BindingFlags.NonPublic)!;

        foreach (var entityType in modelBuilder.Model.GetEntityTypes())
        {
            var clrType = entityType.ClrType;
            if (!typeof(IMultiTenantEntity).IsAssignableFrom(clrType))
            {
                continue;
            }

            // EF Core only accepts a query filter on the root of an inheritance hierarchy, and derived
            // types inherit it. Skip derived types whose base is already tenant-filtered; any other
            // hierarchy shape is left to EF Core to reject so a missing filter is never silent.
            if (entityType.BaseType is { } baseType
                && typeof(IMultiTenantEntity).IsAssignableFrom(baseType.ClrType))
            {
                continue;
            }

            setTenantFilter.MakeGenericMethod(clrType).Invoke(this, new object[] { modelBuilder });
        }
    }

    /// <summary>
    /// Sets the tenant query filter for a concrete entity type.
    /// </summary>
    /// <typeparam name="TEntity">The entity type.</typeparam>
    /// <param name="modelBuilder">The model builder.</param>
    private void SetTenantFilter<TEntity>(ModelBuilder modelBuilder)
        where TEntity : class, IMultiTenantEntity
    {
        // The filter references a member of the context rather than a captured value, so the tenant
        // ID is taken from the context instance executing the query.
        modelBuilder.Entity<TEntity>().HasQueryFilter(entity => entity.TenantId == CurrentTenantId);
    }

    /// <summary>
    /// Stamps the current tenant ID onto newly added multi-tenant entities whose TenantId is still 0.
    /// Nothing is stamped when no tenant is resolved (ID 0) or when the current user is a platform
    /// administrator; in that case TenantId is left as set by the caller.
    /// </summary>
    private void ApplyTenantMetadata()
    {
        var tenantId = CurrentTenantId;

        // The admin check does not depend on the entry, so evaluate it once instead of per entity.
        if (tenantId == 0 || IsPlatformAdmin())
        {
            return;
        }

        foreach (var entry in ChangeTracker.Entries<IMultiTenantEntity>())
        {
            if (entry.State == EntityState.Added && entry.Entity.TenantId == 0)
            {
                entry.Entity.TenantId = tenantId;
            }
        }
    }

    /// <summary>
    /// Returns true when the current HTTP user is authenticated and holds one of the platform
    /// administrator role codes; false when there is no HTTP context or user.
    /// </summary>
    private bool IsPlatformAdmin()
    {
        var user = _httpContextAccessor?.HttpContext?.User;
        if (user?.Identity?.IsAuthenticated != true)
        {
            return false;
        }

        return PlatformRoleCodes.Any(user.IsInRole);
    }
}