fix: 禁止账单跨租户查询
This commit is contained in:
@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
|
||||
namespace TakeoutSaaS.Application.App.Billings.Dto;
|
||||
|
||||
/// <summary>
|
||||
/// 账单详情 DTO(管理员端)。
|
||||
/// 账单详情 DTO(租户端)。
|
||||
/// </summary>
|
||||
public sealed record BillingDetailDto
|
||||
{
|
||||
|
||||
@@ -364,7 +364,7 @@ public sealed record PaymentRecordDto
|
||||
public sealed record BillingStatisticsDto
|
||||
{
|
||||
/// <summary>
|
||||
/// 租户 ID(为空表示跨租户统计)。
|
||||
/// 租户 ID(当前租户)。
|
||||
/// </summary>
|
||||
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
|
||||
public long? TenantId { get; init; }
|
||||
|
||||
@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
|
||||
namespace TakeoutSaaS.Application.App.Billings.Dto;
|
||||
|
||||
/// <summary>
|
||||
/// 账单列表 DTO(管理员端列表展示)。
|
||||
/// 账单列表 DTO(租户端列表展示)。
|
||||
/// </summary>
|
||||
public sealed record BillingListDto
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Dto;
|
||||
public sealed record BillingStatisticsDto
|
||||
{
|
||||
/// <summary>
|
||||
/// 租户 ID(可选,管理员可跨租户统计)。
|
||||
/// 租户 ID(当前租户)。
|
||||
/// </summary>
|
||||
[JsonConverter(typeof(NullableSnowflakeIdJsonConverter))]
|
||||
public long? TenantId { get; init; }
|
||||
|
||||
@@ -5,7 +5,7 @@ using TakeoutSaaS.Shared.Abstractions.Serialization;
|
||||
namespace TakeoutSaaS.Application.App.Billings.Dto;
|
||||
|
||||
/// <summary>
|
||||
/// 支付记录 DTO(管理员端)。
|
||||
/// 支付记录 DTO(租户端)。
|
||||
/// </summary>
|
||||
public sealed record PaymentRecordDto
|
||||
{
|
||||
|
||||
@@ -6,7 +6,9 @@ using TakeoutSaaS.Application.App.Billings.Queries;
|
||||
using TakeoutSaaS.Domain.Tenants.Enums;
|
||||
using TakeoutSaaS.Shared.Abstractions.Constants;
|
||||
using TakeoutSaaS.Shared.Abstractions.Data;
|
||||
using TakeoutSaaS.Shared.Abstractions.Exceptions;
|
||||
using TakeoutSaaS.Shared.Abstractions.Results;
|
||||
using TakeoutSaaS.Shared.Abstractions.Tenancy;
|
||||
|
||||
namespace TakeoutSaaS.Application.App.Billings.Handlers;
|
||||
|
||||
@@ -14,7 +16,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
|
||||
/// 分页查询账单列表处理器。
|
||||
/// </summary>
|
||||
public sealed class GetBillingListQueryHandler(
|
||||
IDapperExecutor dapperExecutor)
|
||||
IDapperExecutor dapperExecutor,
|
||||
ITenantProvider tenantProvider)
|
||||
: IRequestHandler<GetBillingListQuery, PagedResult<BillingListDto>>
|
||||
{
|
||||
/// <summary>
|
||||
@@ -25,7 +28,21 @@ public sealed class GetBillingListQueryHandler(
|
||||
/// <returns>分页账单列表 DTO。</returns>
|
||||
public async Task<PagedResult<BillingListDto>> Handle(GetBillingListQuery request, CancellationToken cancellationToken)
|
||||
{
|
||||
// 1. 参数规范化
|
||||
// 1. 校验租户上下文(租户端禁止跨租户)
|
||||
var currentTenantId = tenantProvider.GetCurrentTenantId();
|
||||
if (currentTenantId <= 0)
|
||||
{
|
||||
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
|
||||
}
|
||||
|
||||
// 2. (空行后) 禁止跨租户查询
|
||||
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
|
||||
{
|
||||
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户查询账单");
|
||||
}
|
||||
var tenantId = currentTenantId;
|
||||
|
||||
// 3. (空行后) 参数规范化
|
||||
var page = request.PageNumber <= 0 ? 1 : request.PageNumber;
|
||||
var pageSize = request.PageSize is <= 0 or > 200 ? 20 : request.PageSize;
|
||||
var keyword = string.IsNullOrWhiteSpace(request.Keyword) ? null : request.Keyword.Trim();
|
||||
@@ -61,7 +78,7 @@ public sealed class GetBillingListQueryHandler(
|
||||
connection,
|
||||
BuildCountSql(),
|
||||
[
|
||||
("tenantId", request.TenantId),
|
||||
("tenantId", tenantId),
|
||||
("status", request.Status.HasValue ? (int)request.Status.Value : null),
|
||||
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
|
||||
("startDate", request.StartDate),
|
||||
@@ -78,7 +95,7 @@ public sealed class GetBillingListQueryHandler(
|
||||
connection,
|
||||
listSql,
|
||||
[
|
||||
("tenantId", request.TenantId),
|
||||
("tenantId", tenantId),
|
||||
("status", request.Status.HasValue ? (int)request.Status.Value : null),
|
||||
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
|
||||
("startDate", request.StartDate),
|
||||
@@ -145,7 +162,7 @@ public sealed class GetBillingListQueryHandler(
|
||||
from public.tenant_billing_statements b
|
||||
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
|
||||
where b."DeletedAt" is null
|
||||
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
|
||||
and b."TenantId" = @tenantId
|
||||
and (@status::int is null or b."Status" = @status)
|
||||
and (@billingType::int is null or b."BillingType" = @billingType)
|
||||
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)
|
||||
@@ -186,7 +203,7 @@ public sealed class GetBillingListQueryHandler(
|
||||
from public.tenant_billing_statements b
|
||||
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
|
||||
where b."DeletedAt" is null
|
||||
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
|
||||
and b."TenantId" = @tenantId
|
||||
and (@status::int is null or b."Status" = @status)
|
||||
and (@billingType::int is null or b."BillingType" = @billingType)
|
||||
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)
|
||||
|
||||
@@ -6,6 +6,8 @@ using TakeoutSaaS.Application.App.Billings.Queries;
|
||||
using TakeoutSaaS.Domain.Tenants.Enums;
|
||||
using TakeoutSaaS.Shared.Abstractions.Constants;
|
||||
using TakeoutSaaS.Shared.Abstractions.Data;
|
||||
using TakeoutSaaS.Shared.Abstractions.Exceptions;
|
||||
using TakeoutSaaS.Shared.Abstractions.Tenancy;
|
||||
|
||||
namespace TakeoutSaaS.Application.App.Billings.Handlers;
|
||||
|
||||
@@ -13,7 +15,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
|
||||
/// 查询账单统计数据处理器。
|
||||
/// </summary>
|
||||
public sealed class GetBillingStatisticsQueryHandler(
|
||||
IDapperExecutor dapperExecutor)
|
||||
IDapperExecutor dapperExecutor,
|
||||
ITenantProvider tenantProvider)
|
||||
: IRequestHandler<GetBillingStatisticsQuery, BillingStatisticsDto>
|
||||
{
|
||||
/// <summary>
|
||||
@@ -24,7 +27,21 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
/// <returns>账单统计数据 DTO。</returns>
|
||||
public async Task<BillingStatisticsDto> Handle(GetBillingStatisticsQuery request, CancellationToken cancellationToken)
|
||||
{
|
||||
// 1. 参数规范化
|
||||
// 1. 校验租户上下文(租户端禁止跨租户)
|
||||
var currentTenantId = tenantProvider.GetCurrentTenantId();
|
||||
if (currentTenantId <= 0)
|
||||
{
|
||||
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
|
||||
}
|
||||
|
||||
// 2. (空行后) 禁止跨租户统计
|
||||
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
|
||||
{
|
||||
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户统计账单");
|
||||
}
|
||||
var tenantId = currentTenantId;
|
||||
|
||||
// 3. (空行后) 参数规范化
|
||||
var startDate = request.StartDate ?? DateTime.UtcNow.AddMonths(-1);
|
||||
var endDate = request.EndDate ?? DateTime.UtcNow;
|
||||
var groupBy = NormalizeGroupBy(request.GroupBy);
|
||||
@@ -40,7 +57,7 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
connection,
|
||||
BuildSummarySql(),
|
||||
[
|
||||
("tenantId", request.TenantId),
|
||||
("tenantId", tenantId),
|
||||
("startDate", startDate),
|
||||
("endDate", endDate),
|
||||
("now", DateTime.UtcNow)
|
||||
@@ -64,7 +81,7 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
connection,
|
||||
BuildTrendSql(groupBy),
|
||||
[
|
||||
("tenantId", request.TenantId),
|
||||
("tenantId", tenantId),
|
||||
("startDate", startDate),
|
||||
("endDate", endDate)
|
||||
]);
|
||||
@@ -86,7 +103,7 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
// 2.3 组装 DTO
|
||||
return new BillingStatisticsDto
|
||||
{
|
||||
TenantId = request.TenantId,
|
||||
TenantId = tenantId,
|
||||
StartDate = startDate,
|
||||
EndDate = endDate,
|
||||
GroupBy = groupBy,
|
||||
@@ -138,7 +155,7 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
), 0)::numeric as "TotalOverdueAmount"
|
||||
from public.tenant_billing_statements b
|
||||
where b."DeletedAt" is null
|
||||
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
|
||||
and b."TenantId" = @tenantId
|
||||
and b."PeriodStart" >= @startDate
|
||||
and b."PeriodEnd" <= @endDate;
|
||||
""";
|
||||
@@ -161,7 +178,7 @@ public sealed class GetBillingStatisticsQueryHandler(
|
||||
count(*)::int as "Count"
|
||||
from public.tenant_billing_statements b
|
||||
where b."DeletedAt" is null
|
||||
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
|
||||
and b."TenantId" = @tenantId
|
||||
and b."PeriodStart" >= @startDate
|
||||
and b."PeriodEnd" <= @endDate
|
||||
group by 1
|
||||
|
||||
@@ -11,7 +11,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
|
||||
public sealed record GetBillingListQuery : IRequest<PagedResult<BillingListDto>>
|
||||
{
|
||||
/// <summary>
|
||||
/// 租户 ID(可选,管理员可查询所有租户)。
|
||||
/// 租户 ID(可选,默认当前租户;禁止跨租户)。
|
||||
/// </summary>
|
||||
public long? TenantId { get; init; }
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace TakeoutSaaS.Application.App.Billings.Queries;
|
||||
public sealed record GetBillingStatisticsQuery : IRequest<BillingStatisticsDto>
|
||||
{
|
||||
/// <summary>
|
||||
/// 租户 ID(可选,管理员可查询所有租户)。
|
||||
/// 租户 ID(可选,默认当前租户;禁止跨租户)。
|
||||
/// </summary>
|
||||
public long? TenantId { get; init; }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user