fix: 禁止账单跨租户查询

This commit is contained in:
root
2026-01-29 14:51:56 +00:00
parent 1622c38043
commit a0b77d4847
11 changed files with 61 additions and 27 deletions

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary>
/// 账单详情 DTO管理员端)。
/// 账单详情 DTO租户端)。
/// </summary>
public sealed record BillingDetailDto
{

View File

@@ -364,7 +364,7 @@ public sealed record PaymentRecordDto
public sealed record BillingStatisticsDto
{
/// <summary>
/// 租户 ID为空表示跨租户统计)。
/// 租户 ID当前租户)。
/// </summary>
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
public long? TenantId { get; init; }

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary>
/// 账单列表 DTO管理员端列表展示)。
/// 账单列表 DTO租户端列表展示)。
/// </summary>
public sealed record BillingListDto
{

View File

@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Dto;
public sealed record BillingStatisticsDto
{
/// <summary>
/// 租户 ID可选,管理员可跨租户统计)。
/// 租户 ID当前租户)。
/// </summary>
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
public long? TenantId { get; init; }

View File

@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
namespace TakeoutSaaS.Application.App.Billings.Dto;
/// <summary>
/// 支付记录 DTO管理员端)。
/// 支付记录 DTO租户端)。
/// </summary>
public sealed record PaymentRecordDto
{

View File

@@ -6,7 +6,9 @@ using TakeoutSaaS.Application.App.Billings.Queries;
using TakeoutSaaS.Domain.Tenants.Enums;
using TakeoutSaaS.Shared.Abstractions.Constants;
using TakeoutSaaS.Shared.Abstractions.Data;
using TakeoutSaaS.Shared.Abstractions.Exceptions;
using TakeoutSaaS.Shared.Abstractions.Results;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
namespace TakeoutSaaS.Application.App.Billings.Handlers;
@@ -14,7 +16,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
/// 分页查询账单列表处理器。
/// </summary>
public sealed class GetBillingListQueryHandler(
IDapperExecutor dapperExecutor)
IDapperExecutor dapperExecutor,
ITenantProvider tenantProvider)
: IRequestHandler<GetBillingListQuery, PagedResult<BillingListDto>>
{
/// <summary>
@@ -25,7 +28,21 @@ public sealed class GetBillingListQueryHandler(
/// <returns>分页账单列表 DTO。</returns>
public async Task<PagedResult<BillingListDto>> Handle(GetBillingListQuery request, CancellationToken cancellationToken)
{
// 1. 参数规范化
// 1. 校验租户上下文(租户端禁止跨租户)
var currentTenantId = tenantProvider.GetCurrentTenantId();
if (currentTenantId <= 0)
{
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
}
// 2. (空行后) 禁止跨租户查询
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
{
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户查询账单");
}
var tenantId = currentTenantId;
// 3. (空行后) 参数规范化
var page = request.PageNumber <= 0 ? 1 : request.PageNumber;
var pageSize = request.PageSize is <= 0 or > 200 ? 20 : request.PageSize;
var keyword = string.IsNullOrWhiteSpace(request.Keyword) ? null : request.Keyword.Trim();
@@ -61,7 +78,7 @@ public sealed class GetBillingListQueryHandler(
connection,
BuildCountSql(),
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate),
@@ -78,7 +95,7 @@ public sealed class GetBillingListQueryHandler(
connection,
listSql,
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate),
@@ -145,7 +162,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)
@@ -186,7 +203,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)

View File

@@ -6,6 +6,8 @@ using TakeoutSaaS.Application.App.Billings.Queries;
using TakeoutSaaS.Domain.Tenants.Enums;
using TakeoutSaaS.Shared.Abstractions.Constants;
using TakeoutSaaS.Shared.Abstractions.Data;
using TakeoutSaaS.Shared.Abstractions.Exceptions;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
namespace TakeoutSaaS.Application.App.Billings.Handlers;
@@ -13,7 +15,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
/// 查询账单统计数据处理器。
/// </summary>
public sealed class GetBillingStatisticsQueryHandler(
IDapperExecutor dapperExecutor)
IDapperExecutor dapperExecutor,
ITenantProvider tenantProvider)
: IRequestHandler<GetBillingStatisticsQuery, BillingStatisticsDto>
{
/// <summary>
@@ -24,7 +27,21 @@ public sealed class GetBillingStatisticsQueryHandler(
/// <returns>账单统计数据 DTO。</returns>
public async Task<BillingStatisticsDto> Handle(GetBillingStatisticsQuery request, CancellationToken cancellationToken)
{
// 1. 参数规范化
// 1. 校验租户上下文(租户端禁止跨租户)
var currentTenantId = tenantProvider.GetCurrentTenantId();
if (currentTenantId <= 0)
{
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
}
// 2. (空行后) 禁止跨租户统计
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
{
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户统计账单");
}
var tenantId = currentTenantId;
// 3. (空行后) 参数规范化
var startDate = request.StartDate ?? DateTime.UtcNow.AddMonths(-1);
var endDate = request.EndDate ?? DateTime.UtcNow;
var groupBy = NormalizeGroupBy(request.GroupBy);
@@ -40,7 +57,7 @@ public sealed class GetBillingStatisticsQueryHandler(
connection,
BuildSummarySql(),
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("startDate", startDate),
("endDate", endDate),
("now", DateTime.UtcNow)
@@ -64,7 +81,7 @@ public sealed class GetBillingStatisticsQueryHandler(
connection,
BuildTrendSql(groupBy),
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("startDate", startDate),
("endDate", endDate)
]);
@@ -86,7 +103,7 @@ public sealed class GetBillingStatisticsQueryHandler(
// 2.3 组装 DTO
return new BillingStatisticsDto
{
TenantId = request.TenantId,
TenantId = tenantId,
StartDate = startDate,
EndDate = endDate,
GroupBy = groupBy,
@@ -138,7 +155,7 @@ public sealed class GetBillingStatisticsQueryHandler(
), 0)::numeric as "TotalOverdueAmount"
from public.tenant_billing_statements b
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and b."PeriodStart" >= @startDate
and b."PeriodEnd" <= @endDate;
""";
@@ -161,7 +178,7 @@ public sealed class GetBillingStatisticsQueryHandler(
count(*)::int as "Count"
from public.tenant_billing_statements b
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and b."PeriodStart" >= @startDate
and b."PeriodEnd" <= @endDate
group by 1

View File

@@ -11,7 +11,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
public sealed record GetBillingListQuery : IRequest<PagedResult<BillingListDto>>
{
/// <summary>
/// 租户 ID可选管理员可查询所有租户)。
/// 租户 ID可选默认当前租户;禁止跨租户)。
/// </summary>
public long? TenantId { get; init; }

View File

@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
public sealed record GetBillingStatisticsQuery : IRequest<BillingStatisticsDto>
{
/// <summary>
/// 租户 ID可选管理员可查询所有租户)。
/// 租户 ID可选默认当前租户;禁止跨租户)。
/// </summary>
public long? TenantId { get; init; }