fix: 禁止账单跨租户查询

This commit is contained in:
root
2026-01-29 14:51:56 +00:00
parent 1622c38043
commit a0b77d4847
11 changed files with 61 additions and 27 deletions

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto; namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary> /// <summary>
/// 账单详情 DTO管理员端)。 /// 账单详情 DTO租户端)。
/// </summary> /// </summary>
public sealed record BillingDetailDto public sealed record BillingDetailDto
{ {

View File

@@ -364,7 +364,7 @@ public sealed record PaymentRecordDto
public sealed record BillingStatisticsDto public sealed record BillingStatisticsDto
{ {
/// <summary> /// <summary>
/// 租户 ID为空表示跨租户统计)。 /// 租户 ID当前租户)。
/// </summary> /// </summary>
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))] [JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
public long? TenantId { get; init; } public long? TenantId { get; init; }

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto; namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary> /// <summary>
/// 账单列表 DTO管理员端列表展示)。 /// 账单列表 DTO租户端列表展示)。
/// </summary> /// </summary>
public sealed record BillingListDto public sealed record BillingListDto
{ {

View File

@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Dto;
public sealed record BillingStatisticsDto public sealed record BillingStatisticsDto
{ {
/// <summary> /// <summary>
/// 租户 ID可选,管理员可跨租户统计)。 /// 租户 ID当前租户)。
/// </summary> /// </summary>
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))] [JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
public long? TenantId { get; init; } public long? TenantId { get; init; }

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto; namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary> /// <summary>
/// 支付记录 DTO管理员端)。 /// 支付记录 DTO租户端)。
/// </summary> /// </summary>
public sealed record PaymentRecordDto public sealed record PaymentRecordDto
{ {

View File

@@ -6,7 +6,9 @@ using TakeoutSaaS.Application.App.Billings.Queries;
using TakeoutSaaS.Domain.Tenants.Enums; using TakeoutSaaS.Domain.Tenants.Enums;
using TakeoutSaaS.Shared.Abstractions.Constants; using TakeoutSaaS.Shared.Abstractions.Constants;
using TakeoutSaaS.Shared.Abstractions.Data; using TakeoutSaaS.Shared.Abstractions.Data;
using TakeoutSaaS.Shared.Abstractions.Exceptions;
using TakeoutSaaS.Shared.Abstractions.Results; using TakeoutSaaS.Shared.Abstractions.Results;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
namespace TakeoutSaaS.Application.App.Billings.Handlers; namespace TakeoutSaaS.Application.App.Billings.Handlers;
@@ -14,7 +16,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
/// 分页查询账单列表处理器。 /// 分页查询账单列表处理器。
/// </summary> /// </summary>
public sealed class GetBillingListQueryHandler( public sealed class GetBillingListQueryHandler(
IDapperExecutor dapperExecutor) IDapperExecutor dapperExecutor,
ITenantProvider tenantProvider)
: IRequestHandler<GetBillingListQuery, PagedResult<BillingListDto>> : IRequestHandler<GetBillingListQuery, PagedResult<BillingListDto>>
{ {
/// <summary> /// <summary>
@@ -25,7 +28,21 @@ public sealed class GetBillingListQueryHandler(
/// <returns>分页账单列表 DTO。</returns> /// <returns>分页账单列表 DTO。</returns>
public async Task<PagedResult<BillingListDto>> Handle(GetBillingListQuery request, CancellationToken cancellationToken) public async Task<PagedResult<BillingListDto>> Handle(GetBillingListQuery request, CancellationToken cancellationToken)
{ {
// 1. 参数规范化 // 1. 校验租户上下文(租户端禁止跨租户)
var currentTenantId = tenantProvider.GetCurrentTenantId();
if (currentTenantId <= 0)
{
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
}
// 2. (空行后) 禁止跨租户查询
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
{
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户查询账单");
}
var tenantId = currentTenantId;
// 3. (空行后) 参数规范化
var page = request.PageNumber <= 0 ? 1 : request.PageNumber; var page = request.PageNumber <= 0 ? 1 : request.PageNumber;
var pageSize = request.PageSize is <= 0 or > 200 ? 20 : request.PageSize; var pageSize = request.PageSize is <= 0 or > 200 ? 20 : request.PageSize;
var keyword = string.IsNullOrWhiteSpace(request.Keyword) ? null : request.Keyword.Trim(); var keyword = string.IsNullOrWhiteSpace(request.Keyword) ? null : request.Keyword.Trim();
@@ -61,7 +78,7 @@ public sealed class GetBillingListQueryHandler(
connection, connection,
BuildCountSql(), BuildCountSql(),
[ [
("tenantId", request.TenantId), ("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null), ("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null), ("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate), ("startDate", request.StartDate),
@@ -78,7 +95,7 @@ public sealed class GetBillingListQueryHandler(
connection, connection,
listSql, listSql,
[ [
("tenantId", request.TenantId), ("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null), ("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null), ("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate), ("startDate", request.StartDate),
@@ -145,7 +162,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId) and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status) and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType) and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate) and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)
@@ -186,7 +203,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId) and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status) and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType) and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate) and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)

View File

@@ -6,6 +6,8 @@ using TakeoutSaaS.Application.App.Billings.Queries;
using TakeoutSaaS.Domain.Tenants.Enums; using TakeoutSaaS.Domain.Tenants.Enums;
using TakeoutSaaS.Shared.Abstractions.Constants; using TakeoutSaaS.Shared.Abstractions.Constants;
using TakeoutSaaS.Shared.Abstractions.Data; using TakeoutSaaS.Shared.Abstractions.Data;
using TakeoutSaaS.Shared.Abstractions.Exceptions;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
namespace TakeoutSaaS.Application.App.Billings.Handlers; namespace TakeoutSaaS.Application.App.Billings.Handlers;
@@ -13,7 +15,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
/// 查询账单统计数据处理器。 /// 查询账单统计数据处理器。
/// </summary> /// </summary>
public sealed class GetBillingStatisticsQueryHandler( public sealed class GetBillingStatisticsQueryHandler(
IDapperExecutor dapperExecutor) IDapperExecutor dapperExecutor,
ITenantProvider tenantProvider)
: IRequestHandler<GetBillingStatisticsQuery, BillingStatisticsDto> : IRequestHandler<GetBillingStatisticsQuery, BillingStatisticsDto>
{ {
/// <summary> /// <summary>
@@ -24,7 +27,21 @@ public sealed class GetBillingStatisticsQueryHandler(
/// <returns>账单统计数据 DTO。</returns> /// <returns>账单统计数据 DTO。</returns>
public async Task<BillingStatisticsDto> Handle(GetBillingStatisticsQuery request, CancellationToken cancellationToken) public async Task<BillingStatisticsDto> Handle(GetBillingStatisticsQuery request, CancellationToken cancellationToken)
{ {
// 1. 参数规范化 // 1. 校验租户上下文(租户端禁止跨租户)
var currentTenantId = tenantProvider.GetCurrentTenantId();
if (currentTenantId <= 0)
{
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
}
// 2. (空行后) 禁止跨租户统计
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
{
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户统计账单");
}
var tenantId = currentTenantId;
// 3. (空行后) 参数规范化
var startDate = request.StartDate ?? DateTime.UtcNow.AddMonths(-1); var startDate = request.StartDate ?? DateTime.UtcNow.AddMonths(-1);
var endDate = request.EndDate ?? DateTime.UtcNow; var endDate = request.EndDate ?? DateTime.UtcNow;
var groupBy = NormalizeGroupBy(request.GroupBy); var groupBy = NormalizeGroupBy(request.GroupBy);
@@ -40,7 +57,7 @@ public sealed class GetBillingStatisticsQueryHandler(
connection, connection,
BuildSummarySql(), BuildSummarySql(),
[ [
("tenantId", request.TenantId), ("tenantId", tenantId),
("startDate", startDate), ("startDate", startDate),
("endDate", endDate), ("endDate", endDate),
("now", DateTime.UtcNow) ("now", DateTime.UtcNow)
@@ -64,7 +81,7 @@ public sealed class GetBillingStatisticsQueryHandler(
connection, connection,
BuildTrendSql(groupBy), BuildTrendSql(groupBy),
[ [
("tenantId", request.TenantId), ("tenantId", tenantId),
("startDate", startDate), ("startDate", startDate),
("endDate", endDate) ("endDate", endDate)
]); ]);
@@ -86,7 +103,7 @@ public sealed class GetBillingStatisticsQueryHandler(
// 2.3 组装 DTO // 2.3 组装 DTO
return new BillingStatisticsDto return new BillingStatisticsDto
{ {
TenantId = request.TenantId, TenantId = tenantId,
StartDate = startDate, StartDate = startDate,
EndDate = endDate, EndDate = endDate,
GroupBy = groupBy, GroupBy = groupBy,
@@ -138,7 +155,7 @@ public sealed class GetBillingStatisticsQueryHandler(
), 0)::numeric as "TotalOverdueAmount" ), 0)::numeric as "TotalOverdueAmount"
from public.tenant_billing_statements b from public.tenant_billing_statements b
where b."DeletedAt" is null where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId) and b."TenantId" = @tenantId
and b."PeriodStart" >= @startDate and b."PeriodStart" >= @startDate
and b."PeriodEnd" <= @endDate; and b."PeriodEnd" <= @endDate;
"""; """;
@@ -161,7 +178,7 @@ public sealed class GetBillingStatisticsQueryHandler(
count(*)::int as "Count" count(*)::int as "Count"
from public.tenant_billing_statements b from public.tenant_billing_statements b
where b."DeletedAt" is null where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId) and b."TenantId" = @tenantId
and b."PeriodStart" >= @startDate and b."PeriodStart" >= @startDate
and b."PeriodEnd" <= @endDate and b."PeriodEnd" <= @endDate
group by 1 group by 1

View File

@@ -11,7 +11,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
public sealed record GetBillingListQuery : IRequest<PagedResult<BillingListDto>> public sealed record GetBillingListQuery : IRequest<PagedResult<BillingListDto>>
{ {
/// <summary> /// <summary>
/// 租户 ID可选管理员可查询所有租户)。 /// 租户 ID可选默认当前租户;禁止跨租户)。
/// </summary> /// </summary>
public long? TenantId { get; init; } public long? TenantId { get; init; }

View File

@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
public sealed record GetBillingStatisticsQuery : IRequest<BillingStatisticsDto> public sealed record GetBillingStatisticsQuery : IRequest<BillingStatisticsDto>
{ {
/// <summary> /// <summary>
/// 租户 ID可选管理员可查询所有租户)。 /// 租户 ID可选默认当前租户;禁止跨租户)。
/// </summary> /// </summary>
public long? TenantId { get; init; } public long? TenantId { get; init; }

View File

@@ -43,7 +43,7 @@ public interface ITenantBillingRepository
Task<TenantBillingStatement?> FindByStatementNoAsync(long tenantId, string statementNo, CancellationToken cancellationToken = default); Task<TenantBillingStatement?> FindByStatementNoAsync(long tenantId, string statementNo, CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 按账单编号获取账单(不限租户,管理员端使用)。 /// 按账单编号获取账单(不限租户,系统任务使用)。
/// </summary> /// </summary>
/// <param name="statementNo">账单编号。</param> /// <param name="statementNo">账单编号。</param>
/// <param name="cancellationToken">取消标记。</param> /// <param name="cancellationToken">取消标记。</param>
@@ -86,7 +86,7 @@ public interface ITenantBillingRepository
Task<IReadOnlyList<TenantBillingStatement>> GetByTenantIdAsync(long tenantId, CancellationToken cancellationToken = default); Task<IReadOnlyList<TenantBillingStatement>> GetByTenantIdAsync(long tenantId, CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 按 ID 列表批量获取账单(管理员端/批量操作场景)。 /// 按 ID 列表批量获取账单(系统任务/批量操作场景)。
/// </summary> /// </summary>
/// <param name="billingIds">账单 ID 列表。</param> /// <param name="billingIds">账单 ID 列表。</param>
/// <param name="cancellationToken">取消标记。</param> /// <param name="cancellationToken">取消标记。</param>
@@ -119,7 +119,7 @@ public interface ITenantBillingRepository
Task SaveChangesAsync(CancellationToken cancellationToken = default); Task SaveChangesAsync(CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 管理员端分页查询账单列表(跨租户)。 /// 系统任务分页查询账单列表(跨租户)。
/// </summary> /// </summary>
/// <param name="tenantId">租户 ID 筛选(可选)。</param> /// <param name="tenantId">租户 ID 筛选(可选)。</param>
/// <param name="status">账单状态筛选(可选)。</param> /// <param name="status">账单状态筛选(可选)。</param>
@@ -147,7 +147,7 @@ public interface ITenantBillingRepository
/// <summary> /// <summary>
/// 获取账单统计数据(用于报表与仪表盘)。 /// 获取账单统计数据(用于报表与仪表盘)。
/// </summary> /// </summary>
/// <param name="tenantId">租户 ID可选管理员可查询所有租户)。</param> /// <param name="tenantId">租户 ID可选系统任务可跨租户统计)。</param>
/// <param name="startDate">统计开始时间UTC。</param> /// <param name="startDate">统计开始时间UTC。</param>
/// <param name="endDate">统计结束时间UTC。</param> /// <param name="endDate">统计结束时间UTC。</param>
/// <param name="groupBy">分组方式Day/Week/Month。</param> /// <param name="groupBy">分组方式Day/Week/Month。</param>
@@ -161,7 +161,7 @@ public interface ITenantBillingRepository
CancellationToken cancellationToken = default); CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 按 ID 获取账单(不限租户,管理员端使用)。 /// 按 ID 获取账单(不限租户,系统任务使用)。
/// </summary> /// </summary>
/// <param name="billingId">账单 ID。</param> /// <param name="billingId">账单 ID。</param>
/// <param name="cancellationToken">取消标记。</param> /// <param name="cancellationToken">取消标记。</param>

View File

@@ -151,7 +151,7 @@ public sealed class TenantBillingRepository(TakeoutAppDbContext context) : ITena
return Array.Empty<TenantBillingStatement>(); return Array.Empty<TenantBillingStatement>();
} }
// 1. 忽略全局过滤器以支持管理员端跨租户导出/批量操作 // 1. 忽略全局过滤器以支持系统任务跨租户导出/批量操作
return await context.TenantBillingStatements return await context.TenantBillingStatements
.IgnoreQueryFilters() .IgnoreQueryFilters()
.AsNoTracking() .AsNoTracking()
@@ -192,7 +192,7 @@ public sealed class TenantBillingRepository(TakeoutAppDbContext context) : ITena
int pageSize, int pageSize,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
// 1. 构建基础查询(管理员端跨租户查询,忽略过滤器) // 1. 构建基础查询(系统任务跨租户查询,忽略过滤器)
var query = context.TenantBillingStatements var query = context.TenantBillingStatements
.IgnoreQueryFilters() .IgnoreQueryFilters()
.AsNoTracking() .AsNoTracking()