From 8c08c8947e73d7ce6ed298b2d71606731256fa5f Mon Sep 17 00:00:00 2001 From: Jerry Yan <792602257@qq.com> Date: Wed, 3 Dec 2025 18:49:03 +0800 Subject: [PATCH] =?UTF-8?q?feat(pipeline):=20=E5=A2=9E=E5=BC=BA=E4=BA=BA?= =?UTF-8?q?=E8=84=B8=E5=8C=B9=E9=85=8D=E6=B5=81=E6=B0=B4=E7=BA=BF=E7=9A=84?= =?UTF-8?q?=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在BuildSourceRelationStage中增加sampleListIds空值检查与降级处理 - 在PersistRelationsStage中增加memberSourceList空值检查与提前跳过逻辑 - 为BuildSourceRelationStage、DeleteOldRelationsStage和PersistRelationsStage添加完整的单元测试覆盖 - 实现异常情况下的优雅降级与错误日志记录 - 完善上下文状态管理与阶段跳过机制 --- .../stages/BuildSourceRelationStage.java | 17 ++ .../stages/PersistRelationsStage.java | 6 + .../stages/BuildSourceRelationStageTest.java | 245 ++++++++++++++++ .../stages/DeleteOldRelationsStageTest.java | 175 ++++++++++++ .../stages/PersistRelationsStageTest.java | 262 ++++++++++++++++++ 5 files changed, 705 insertions(+) create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/DeleteOldRelationsStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStageTest.java diff --git a/src/main/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStage.java b/src/main/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStage.java index f1f56355..013c469e 100644 --- a/src/main/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStage.java +++ b/src/main/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStage.java @@ -67,6 +67,23 @@ public class BuildSourceRelationStage extends AbstractFaceMatchingStage sampleListIds = context.getSampleListIds(); Long faceId = context.getFaceId(); + // 防御性检查:sampleListIds为空 + if (sampleListIds == null || sampleListIds.isEmpty()) { + // 尝试从searchResult中获取 + if (context.getSearchResult() != null) { + sampleListIds = context.getSearchResult().getSampleListIds(); + if (sampleListIds != null && !sampleListIds.isEmpty()) { + context.setSampleListIds(sampleListIds); + } else { + log.debug("sampleListIds为空,跳过源文件关联,faceId={}", faceId); + return StageResult.skipped("sampleListIds为空"); + } + } else { + log.debug("sampleListIds为空,跳过源文件关联,faceId={}", faceId); + return StageResult.skipped("sampleListIds为空"); + } + } + try { // 处理源文件关联 List memberSourceEntityList = diff --git a/src/main/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStage.java b/src/main/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStage.java index 0b981760..937f78a3 100644 --- a/src/main/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStage.java +++ b/src/main/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStage.java @@ -60,6 +60,12 @@ public class PersistRelationsStage extends AbstractFaceMatchingStage memberSourceEntityList = context.getMemberSourceList(); Long faceId = context.getFaceId(); + // 防御性检查:memberSourceList为空 + if (memberSourceEntityList == null || memberSourceEntityList.isEmpty()) { + log.debug("memberSourceList为空,跳过持久化,faceId={}", faceId); + return StageResult.skipped("memberSourceList为空"); + } + try { // 1. 过滤已存在的关联关系 List existingFiltered = sourceMapper.filterExistingRelations(memberSourceEntityList); diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStageTest.java new file mode 100644 index 00000000..8a32f105 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/BuildSourceRelationStageTest.java @@ -0,0 +1,245 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.model.pc.face.entity.FaceEntity; +import com.ycwl.basic.model.pc.source.entity.MemberSourceEntity; +import com.ycwl.basic.model.task.resp.SearchFaceRespVo; +import com.ycwl.basic.service.pc.processor.SourceRelationProcessor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.*; + +/** + * BuildSourceRelationStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class BuildSourceRelationStageTest { + + @Mock + private SourceRelationProcessor sourceRelationProcessor; + + @InjectMocks + private BuildSourceRelationStage stage; + + private FaceMatchingContext context; + private FaceEntity face; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forAutoMatching(1L, true); + + face = new FaceEntity(); + face.setId(1L); + face.setMemberId(100L); + face.setScenicId(10L); + + context.setFace(face); + } + + @Test + void testExecute_Success() { + // Given + List sampleListIds = Arrays.asList(101L, 102L, 103L); + context.setSampleListIds(sampleListIds); + + List memberSourceList = createMemberSourceList(3); + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("构建了3个源文件关联")); + assertEquals(memberSourceList, context.getMemberSourceList()); + verify(sourceRelationProcessor, times(1)).processMemberSources(sampleListIds, face); + } + + @Test + void testExecute_SampleListIdsNull_FromSearchResult() { + // Given: sampleListIds为null,但searchResult有值 + context.setSampleListIds(null); + + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setSampleListIds(Arrays.asList(101L, 102L)); + context.setSearchResult(searchResult); + + List memberSourceList = createMemberSourceList(2); + when(sourceRelationProcessor.processMemberSources(anyList(), any())) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(Arrays.asList(101L, 102L), context.getSampleListIds()); // 从searchResult复制过来 + verify(sourceRelationProcessor, times(1)).processMemberSources(anyList(), any()); + } + + @Test + void testExecute_SampleListIdsEmpty_Skipped() { + // Given + context.setSampleListIds(new ArrayList<>()); + context.setSearchResult(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("sampleListIds为空")); + verify(sourceRelationProcessor, never()).processMemberSources(anyList(), any()); + } + + @Test + void testExecute_BothSampleListIdsAndSearchResultEmpty_Skipped() { + // Given + context.setSampleListIds(null); + + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setSampleListIds(null); + context.setSearchResult(searchResult); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + verify(sourceRelationProcessor, never()).processMemberSources(anyList(), any()); + } + + @Test + void testExecute_ProcessorReturnsNull_Skipped() { + // Given + List sampleListIds = Arrays.asList(101L, 102L); + context.setSampleListIds(sampleListIds); + + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("未找到有效的源文件")); + assertNull(context.getMemberSourceList()); + } + + @Test + void testExecute_ProcessorReturnsEmpty_Skipped() { + // Given + List sampleListIds = Arrays.asList(101L, 102L); + context.setSampleListIds(sampleListIds); + + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(new ArrayList<>()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("未找到有效的源文件")); + } + + @Test + void testExecute_ProcessorThrowsException_Degraded() { + // Given + List sampleListIds = Arrays.asList(101L, 102L); + context.setSampleListIds(sampleListIds); + + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenThrow(new RuntimeException("Processing error")); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); // 降级处理 + assertTrue(result.getMessage().contains("构建源文件关联失败")); + verify(sourceRelationProcessor, times(1)).processMemberSources(sampleListIds, face); + } + + @Test + void testExecute_SingleSample() { + // Given: 只有1个样本 + List sampleListIds = Arrays.asList(101L); + context.setSampleListIds(sampleListIds); + + List memberSourceList = createMemberSourceList(1); + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("构建了1个源文件关联")); + } + + @Test + void testExecute_ManySamples() { + // Given: 大量样本 + List sampleListIds = Arrays.asList(101L, 102L, 103L, 104L, 105L, 106L, 107L, 108L, 109L, 110L); + context.setSampleListIds(sampleListIds); + + List memberSourceList = createMemberSourceList(10); + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("构建了10个源文件关联")); + assertEquals(10, context.getMemberSourceList().size()); + } + + @Test + void testExecute_SampleListIdsDifferentFromResult() { + // Given: sampleListIds与processMemberSources返回的数量不同(部分无效) + List sampleListIds = Arrays.asList(101L, 102L, 103L, 104L, 105L); + context.setSampleListIds(sampleListIds); + + List memberSourceList = createMemberSourceList(3); // 只有3个有效 + when(sourceRelationProcessor.processMemberSources(sampleListIds, face)) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("构建了3个源文件关联")); + assertEquals(3, context.getMemberSourceList().size()); + } + + private List createMemberSourceList(int count) { + List list = new ArrayList<>(); + for (int i = 0; i < count; i++) { + MemberSourceEntity entity = new MemberSourceEntity(); + entity.setMemberId(100L); + entity.setSourceId((long) (i + 1)); + list.add(entity); + } + return list; + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/DeleteOldRelationsStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/DeleteOldRelationsStageTest.java new file mode 100644 index 00000000..e4a29364 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/DeleteOldRelationsStageTest.java @@ -0,0 +1,175 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.mapper.SourceMapper; +import com.ycwl.basic.mapper.VideoMapper; +import com.ycwl.basic.model.pc.face.entity.FaceEntity; +import com.ycwl.basic.repository.MemberRelationRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; + +/** + * DeleteOldRelationsStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class DeleteOldRelationsStageTest { + + @Mock + private SourceMapper sourceMapper; + + @Mock + private VideoMapper videoMapper; + + @Mock + private MemberRelationRepository memberRelationRepository; + + @InjectMocks + private DeleteOldRelationsStage stage; + + private FaceMatchingContext context; + private FaceEntity face; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forAutoMatching(1L, true); + + face = new FaceEntity(); + face.setId(1L); + face.setMemberId(100L); + face.setScenicId(10L); + + context.setFace(face); + } + + @Test + void testExecute_Success() { + // Given - mock方法默认不做任何事,无需doNothing() + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("旧关系数据已删除")); + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(100L, 1L); + verify(videoMapper, times(1)).deleteNotBuyFaceRelations(100L, 1L); + verify(memberRelationRepository, times(1)).clearSCacheByFace(1L); + } + + @Test + void testExecute_SourceMapperFailed_Degraded() { + // Given + doThrow(new RuntimeException("Database error")) + .when(sourceMapper).deleteNotBuyFaceRelation(100L, 1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); // 降级处理,不影响主流程 + assertTrue(result.getMessage().contains("删除旧关系数据失败")); + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(100L, 1L); + verify(videoMapper, never()).deleteNotBuyFaceRelations(anyLong(), anyLong()); + verify(memberRelationRepository, never()).clearSCacheByFace(anyLong()); + } + + @Test + void testExecute_VideoMapperFailed_Degraded() { + // Given - sourceMapper正常执行(无需doNothing) + doThrow(new RuntimeException("Delete video error")) + .when(videoMapper).deleteNotBuyFaceRelations(100L, 1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(100L, 1L); + verify(videoMapper, times(1)).deleteNotBuyFaceRelations(100L, 1L); + verify(memberRelationRepository, never()).clearSCacheByFace(anyLong()); + } + + @Test + void testExecute_CacheClearFailed_Degraded() { + // Given - sourceMapper和videoMapper正常执行 + doThrow(new RuntimeException("Cache error")) + .when(memberRelationRepository).clearSCacheByFace(1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + verify(memberRelationRepository, times(1)).clearSCacheByFace(1L); + } + + @Test + void testExecute_DifferentMemberId() { + // Given: 不同的memberId + face.setMemberId(999L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(999L, 1L); + verify(videoMapper, times(1)).deleteNotBuyFaceRelations(999L, 1L); + } + + @Test + void testExecute_DifferentFaceId() { + // Given: 不同的faceId + context = FaceMatchingContext.forAutoMatching(888L, true); + face.setId(888L); + context.setFace(face); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(100L, 888L); + verify(videoMapper, times(1)).deleteNotBuyFaceRelations(100L, 888L); + verify(memberRelationRepository, times(1)).clearSCacheByFace(888L); + } + + @Test + void testExecute_NullPointerException_Degraded() { + // Given + doThrow(new NullPointerException("Null member")) + .when(sourceMapper).deleteNotBuyFaceRelation(100L, 1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + assertTrue(result.getMessage().contains("删除旧关系数据失败")); + } + + @Test + void testExecute_PartialSuccess_Degraded() { + // Given: sourceMapper成功,但videoMapper失败 + doThrow(new RuntimeException("Video deletion failed")) + .when(videoMapper).deleteNotBuyFaceRelations(100L, 1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + // source删除成功,但video删除失败 + verify(sourceMapper, times(1)).deleteNotBuyFaceRelation(100L, 1L); + verify(videoMapper, times(1)).deleteNotBuyFaceRelations(100L, 1L); + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStageTest.java new file mode 100644 index 00000000..837bc5d0 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/PersistRelationsStageTest.java @@ -0,0 +1,262 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.mapper.SourceMapper; +import com.ycwl.basic.model.pc.source.entity.MemberSourceEntity; +import com.ycwl.basic.repository.MemberRelationRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; + +/** + * PersistRelationsStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class PersistRelationsStageTest { + + @Mock + private SourceMapper sourceMapper; + + @Mock + private MemberRelationRepository memberRelationRepository; + + @InjectMocks + private PersistRelationsStage stage; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forAutoMatching(1L, true); + } + + @Test + void testExecute_Success() { + // Given + List memberSourceList = createMemberSourceList(3); + context.setMemberSourceList(memberSourceList); + + List afterExistingFilter = createMemberSourceList(2); + List afterValidFilter = createMemberSourceList(2); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(afterExistingFilter); + when(sourceMapper.filterValidSourceRelations(afterExistingFilter)) + .thenReturn(afterValidFilter); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("持久化了2条关联关系")); + verify(sourceMapper, times(1)).filterExistingRelations(memberSourceList); + verify(sourceMapper, times(1)).filterValidSourceRelations(afterExistingFilter); + verify(sourceMapper, times(1)).addRelations(afterValidFilter); + verify(memberRelationRepository, times(1)).clearSCacheByFace(1L); + } + + @Test + void testExecute_MemberSourceListNull_Skipped() { + // Given + context.setMemberSourceList(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("memberSourceList为空")); + verify(sourceMapper, never()).filterExistingRelations(anyList()); + verify(sourceMapper, never()).addRelations(anyList()); + } + + @Test + void testExecute_MemberSourceListEmpty_Skipped() { + // Given + context.setMemberSourceList(new ArrayList<>()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("memberSourceList为空")); + verify(sourceMapper, never()).filterExistingRelations(anyList()); + } + + @Test + void testExecute_AllFilteredByExisting_Skipped() { + // Given: 所有关系都已存在 + List memberSourceList = createMemberSourceList(5); + context.setMemberSourceList(memberSourceList); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(new ArrayList<>()); // 全部被过滤 + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("没有有效的关联关系可创建")); + verify(sourceMapper, times(1)).filterExistingRelations(memberSourceList); + verify(sourceMapper, never()).addRelations(anyList()); + verify(memberRelationRepository, never()).clearSCacheByFace(anyLong()); + } + + @Test + void testExecute_AllFilteredByInvalid_Skipped() { + // Given: 过滤掉已存在的后,剩余的都是无效引用 + List memberSourceList = createMemberSourceList(5); + context.setMemberSourceList(memberSourceList); + + List afterExistingFilter = createMemberSourceList(3); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(afterExistingFilter); + when(sourceMapper.filterValidSourceRelations(afterExistingFilter)) + .thenReturn(new ArrayList<>()); // 全部无效 + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("没有有效的关联关系可创建")); + verify(sourceMapper, times(1)).filterValidSourceRelations(afterExistingFilter); + verify(sourceMapper, never()).addRelations(anyList()); + } + + @Test + void testExecute_PartialFiltered() { + // Given: 部分被过滤,部分有效 + List memberSourceList = createMemberSourceList(10); + context.setMemberSourceList(memberSourceList); + + List afterExistingFilter = createMemberSourceList(7); + List afterValidFilter = createMemberSourceList(5); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(afterExistingFilter); + when(sourceMapper.filterValidSourceRelations(afterExistingFilter)) + .thenReturn(afterValidFilter); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("持久化了5条关联关系")); + verify(sourceMapper, times(1)).addRelations(afterValidFilter); + } + + @Test + void testExecute_FilterExistingThrowsException_Failed() { + // Given + List memberSourceList = createMemberSourceList(3); + context.setMemberSourceList(memberSourceList); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenThrow(new RuntimeException("Database error")); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("保存关联关系失败")); + assertNotNull(result.getException()); + verify(sourceMapper, never()).addRelations(anyList()); + } + + @Test + void testExecute_AddRelationsThrowsException_Failed() { + // Given + List memberSourceList = createMemberSourceList(3); + context.setMemberSourceList(memberSourceList); + + List afterExistingFilter = createMemberSourceList(2); + List afterValidFilter = createMemberSourceList(2); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(afterExistingFilter); + when(sourceMapper.filterValidSourceRelations(afterExistingFilter)) + .thenReturn(afterValidFilter); + doThrow(new RuntimeException("Insert error")).when(sourceMapper).addRelations(afterValidFilter); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + verify(memberRelationRepository, never()).clearSCacheByFace(anyLong()); // 失败时不清缓存 + } + + @Test + void testExecute_CacheClearFailed() { + // Given + List memberSourceList = createMemberSourceList(3); + context.setMemberSourceList(memberSourceList); + + List afterExistingFilter = createMemberSourceList(2); + List afterValidFilter = createMemberSourceList(2); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(afterExistingFilter); + when(sourceMapper.filterValidSourceRelations(afterExistingFilter)) + .thenReturn(afterValidFilter); + doThrow(new RuntimeException("Cache clear error")) + .when(memberRelationRepository).clearSCacheByFace(1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); // 缓存清理失败导致整体失败 + verify(sourceMapper, times(1)).addRelations(afterValidFilter); // 但关系已保存 + } + + @Test + void testExecute_SingleRelation() { + // Given: 只有1条关系 + List memberSourceList = createMemberSourceList(1); + context.setMemberSourceList(memberSourceList); + + when(sourceMapper.filterExistingRelations(memberSourceList)) + .thenReturn(memberSourceList); + when(sourceMapper.filterValidSourceRelations(memberSourceList)) + .thenReturn(memberSourceList); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("持久化了1条关联关系")); + } + + private List createMemberSourceList(int count) { + List list = new ArrayList<>(); + for (int i = 0; i < count; i++) { + MemberSourceEntity entity = new MemberSourceEntity(); + entity.setMemberId(100L); + entity.setSourceId((long) (i + 1)); + list.add(entity); + } + return list; + } +}