fix: 禁止账单跨租户查询

This commit is contained in:
root
2026-01-29 14:51:56 +00:00
parent 1622c38043
commit a0b77d4847
11 changed files with 61 additions and 27 deletions

View File

@@ -6,7 +6,9 @@ using TakeoutSaaS.Application.App.Billings.Queries;
using TakeoutSaaS.Domain.Tenants.Enums;
using TakeoutSaaS.Shared.Abstractions.Constants;
using TakeoutSaaS.Shared.Abstractions.Data;
using TakeoutSaaS.Shared.Abstractions.Exceptions;
using TakeoutSaaS.Shared.Abstractions.Results;
using TakeoutSaaS.Shared.Abstractions.Tenancy;
namespace TakeoutSaaS.Application.App.Billings.Handlers;
@@ -14,7 +16,8 @@ namespace TakeoutSaaS.Application.App.Billings.Handlers;
/// 分页查询账单列表处理器。
/// </summary>
public sealed class GetBillingListQueryHandler(
IDapperExecutor dapperExecutor)
IDapperExecutor dapperExecutor,
ITenantProvider tenantProvider)
: IRequestHandler<GetBillingListQuery, PagedResult<BillingListDto>>
{
/// <summary>
@@ -25,7 +28,21 @@ public sealed class GetBillingListQueryHandler(
/// <returns>分页账单列表 DTO。</returns>
public async Task<PagedResult<BillingListDto>> Handle(GetBillingListQuery request, CancellationToken cancellationToken)
{
// 1. 参数规范化
// 1. 校验租户上下文(租户端禁止跨租户)
var currentTenantId = tenantProvider.GetCurrentTenantId();
if (currentTenantId <= 0)
{
throw new BusinessException(ErrorCodes.BadRequest, "缺少租户标识");
}
// 2. (空行后) 禁止跨租户查询
if (request.TenantId.HasValue && request.TenantId.Value != currentTenantId)
{
throw new BusinessException(ErrorCodes.Forbidden, "禁止跨租户查询账单");
}
var tenantId = currentTenantId;
// 3. (空行后) 参数规范化
var page = request.PageNumber <= 0 ? 1 : request.PageNumber;
var pageSize = request.PageSize is <= 0 or > 200 ? 20 : request.PageSize;
var keyword = string.IsNullOrWhiteSpace(request.Keyword) ? null : request.Keyword.Trim();
@@ -61,7 +78,7 @@ public sealed class GetBillingListQueryHandler(
connection,
BuildCountSql(),
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate),
@@ -78,7 +95,7 @@ public sealed class GetBillingListQueryHandler(
connection,
listSql,
[
("tenantId", request.TenantId),
("tenantId", tenantId),
("status", request.Status.HasValue ? (int)request.Status.Value : null),
("billingType", request.BillingType.HasValue ? (int)request.BillingType.Value : null),
("startDate", request.StartDate),
@@ -145,7 +162,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)
@@ -186,7 +203,7 @@ public sealed class GetBillingListQueryHandler(
from public.tenant_billing_statements b
join public.tenants t on t."Id" = b."TenantId" and t."DeletedAt" is null
where b."DeletedAt" is null
and (@tenantId::bigint is null or b."TenantId" = @tenantId)
and b."TenantId" = @tenantId
and (@status::int is null or b."Status" = @status)
and (@billingType::int is null or b."BillingType" = @billingType)
and (@startDate::timestamp with time zone is null or b."PeriodStart" >= @startDate)