From b3fa10e8fd96644b06373f91f6213724ba116e1c Mon Sep 17 00:00:00 2001 From: Jerry Yan <792602257@qq.com> Date: Wed, 3 Dec 2025 18:17:34 +0800 Subject: [PATCH] =?UTF-8?q?fix(pipeline):=20=E5=A2=9E=E5=BC=BA=E4=BA=BA?= =?UTF-8?q?=E8=84=B8=E5=8C=B9=E9=85=8D=E6=B5=81=E6=B0=B4=E7=BA=BF=E7=9A=84?= =?UTF-8?q?=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在FilterByTimeRangeStage中增加空值检查和配置验证 - 在LoadMatchedSamplesStage中增加sampleListIds空值检查 - 添加完整的集成测试覆盖Pipeline工厂和Context构建 - 为FilterByDevicePhotoLimitStage添加全面的单元测试 - 为FilterByTimeRangeStage添加边界条件和异常处理测试 - 为LoadMatchedSamplesStage添加异常路径测试 --- .../stages/FilterByTimeRangeStage.java | 16 +- .../stages/LoadMatchedSamplesStage.java | 6 + .../FaceMatchingPipelineIntegrationTest.java | 200 ++++++++++++++++ .../FilterByDevicePhotoLimitStageTest.java | 213 ++++++++++++++++++ .../stages/FilterByTimeRangeStageTest.java | 211 +++++++++++++++++ .../stages/LoadMatchedSamplesStageTest.java | 155 +++++++++++++ 6 files changed, 800 insertions(+), 1 deletion(-) create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/integration/FaceMatchingPipelineIntegrationTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByDevicePhotoLimitStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStageTest.java diff --git a/src/main/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStage.java b/src/main/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStage.java index e316442e..9e32ebff 100644 --- a/src/main/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStage.java +++ b/src/main/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStage.java @@ -69,9 +69,23 @@ public class FilterByTimeRangeStage extends AbstractFaceMatchingStage doExecute(FaceMatchingContext context) { List faceSamples = context.getFaceSamples(); List sampleListIds = context.getSampleListIds(); - Integer tourMinutes = context.getScenicConfig().getInteger("tour_time"); Long faceId = context.getFaceId(); + // 防御性检查:faceSamples为空 + if (faceSamples == null || faceSamples.isEmpty()) { + log.debug("faceSamples为空,跳过时间范围筛选,faceId={}", faceId); + return StageResult.skipped("faceSamples为空"); + } + + // 防御性检查:tour_time配置 + Integer tourMinutes = context.getScenicConfig() != null + ? context.getScenicConfig().getInteger("tour_time") + : null; + if (tourMinutes == null || tourMinutes <= 0) { + log.debug("景区未配置tour_time或配置为0,跳过时间范围筛选,faceId={}", faceId); + return StageResult.skipped("未配置tour_time"); + } + try { // 1. 构建样本ID到实体的映射 Map sampleMap = faceSamples.stream() diff --git a/src/main/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStage.java b/src/main/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStage.java index 27a27bfe..bfb9381e 100644 --- a/src/main/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStage.java +++ b/src/main/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStage.java @@ -64,6 +64,12 @@ public class LoadMatchedSamplesStage extends AbstractFaceMatchingStage sampleListIds = context.getSampleListIds(); Long faceId = context.getFaceId(); + // 防御性检查:如果sampleListIds为空,直接跳过 + if (sampleListIds == null || sampleListIds.isEmpty()) { + log.debug("sampleListIds为空,跳过加载匹配样本,faceId={}", faceId); + return StageResult.skipped("sampleListIds为空"); + } + try { // 批量加载样本实体 List faceSamples = faceSampleMapper.listByIds(sampleListIds); diff --git a/src/test/java/com/ycwl/basic/face/pipeline/integration/FaceMatchingPipelineIntegrationTest.java b/src/test/java/com/ycwl/basic/face/pipeline/integration/FaceMatchingPipelineIntegrationTest.java new file mode 100644 index 00000000..4bcf4d9c --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/integration/FaceMatchingPipelineIntegrationTest.java @@ -0,0 +1,200 @@ +package com.ycwl.basic.face.pipeline.integration; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.Pipeline; +import com.ycwl.basic.face.pipeline.enums.FaceMatchingScene; +import com.ycwl.basic.face.pipeline.factory.FaceMatchingPipelineFactory; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ActiveProfiles; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Pipeline集成测试 + * 测试Pipeline的完整流程和Stage协作 + */ +@SpringBootTest +@ActiveProfiles("test") +class FaceMatchingPipelineIntegrationTest { + + @Autowired + private FaceMatchingPipelineFactory pipelineFactory; + + /** + * 测试Pipeline工厂能够成功创建Pipeline + */ + @Test + void testCreatePipelines() { + // When: 创建各种场景的Pipeline + Pipeline autoMatchingNew = pipelineFactory.createAutoMatchingPipeline(true); + Pipeline autoMatchingOld = pipelineFactory.createAutoMatchingPipeline(false); + Pipeline customMatching = pipelineFactory.createCustomMatchingPipeline(); + Pipeline recognitionOnly = pipelineFactory.createRecognitionOnlyPipeline(); + + // Then: 验证Pipeline创建成功 + assertNotNull(autoMatchingNew); + assertNotNull(autoMatchingOld); + assertNotNull(customMatching); + assertNotNull(recognitionOnly); + + // 验证Stage数量符合预期 + assertEquals(13, autoMatchingNew.getStageCount()); + assertEquals(13, autoMatchingOld.getStageCount()); + assertEquals(15, customMatching.getStageCount()); + assertEquals(3, recognitionOnly.getStageCount()); + } + + /** + * 测试通过场景和isNew参数创建Pipeline + */ + @Test + void testCreatePipelineByScene() { + // When + Pipeline autoNew = pipelineFactory.createPipeline(FaceMatchingScene.AUTO_MATCHING, true); + Pipeline autoOld = pipelineFactory.createPipeline(FaceMatchingScene.AUTO_MATCHING, false); + Pipeline custom = pipelineFactory.createPipeline(FaceMatchingScene.CUSTOM_MATCHING, false); + Pipeline recognition = pipelineFactory.createPipeline(FaceMatchingScene.RECOGNITION_ONLY, false); + + // Then + assertNotNull(autoNew); + assertNotNull(autoOld); + assertNotNull(custom); + assertNotNull(recognition); + } + + /** + * 测试通过Context创建Pipeline + */ + @Test + void testCreatePipelineByContext() { + // Given + FaceMatchingContext autoContext = FaceMatchingContext.forAutoMatching(1L, true); + FaceMatchingContext customContext = FaceMatchingContext.forCustomMatching(2L, Arrays.asList(101L, 102L)); + FaceMatchingContext recognitionContext = FaceMatchingContext.forRecognitionOnly(3L); + + // When + Pipeline autoPipeline = pipelineFactory.createPipeline(autoContext); + Pipeline customPipeline = pipelineFactory.createPipeline(customContext); + Pipeline recognitionPipeline = pipelineFactory.createPipeline(recognitionContext); + + // Then + assertNotNull(autoPipeline); + assertNotNull(customPipeline); + assertNotNull(recognitionPipeline); + } + + /** + * 测试Pipeline名称 + */ + @Test + void testPipelineNames() { + // When + Pipeline autoNew = pipelineFactory.createAutoMatchingPipeline(true); + Pipeline autoOld = pipelineFactory.createAutoMatchingPipeline(false); + Pipeline custom = pipelineFactory.createCustomMatchingPipeline(); + Pipeline recognition = pipelineFactory.createRecognitionOnlyPipeline(); + + // Then + assertTrue(autoNew.getName().contains("AutoMatching")); + assertTrue(autoNew.getName().contains("New")); + assertTrue(autoOld.getName().contains("AutoMatching")); + assertTrue(autoOld.getName().contains("Old")); + assertTrue(custom.getName().contains("CustomMatching")); + assertTrue(recognition.getName().contains("RecognitionOnly")); + } + + /** + * 测试Context构建 + */ + @Test + void testContextCreation() { + // When: 使用不同的工厂方法创建Context + FaceMatchingContext autoContext = FaceMatchingContext.forAutoMatching(100L, true); + FaceMatchingContext customContext = FaceMatchingContext.forCustomMatching(200L, Arrays.asList(1L, 2L, 3L)); + FaceMatchingContext recognitionContext = FaceMatchingContext.forRecognitionOnly(300L); + + // Then: 验证Context属性 + assertEquals(100L, autoContext.getFaceId()); + assertTrue(autoContext.isNew()); + assertEquals(FaceMatchingScene.AUTO_MATCHING, autoContext.getScene()); + + assertEquals(200L, customContext.getFaceId()); + assertFalse(customContext.isNew()); + assertEquals(FaceMatchingScene.CUSTOM_MATCHING, customContext.getScene()); + assertEquals(3, customContext.getFaceSampleIds().size()); + + assertEquals(300L, recognitionContext.getFaceId()); + assertFalse(recognitionContext.isNew()); + assertEquals(FaceMatchingScene.RECOGNITION_ONLY, recognitionContext.getScene()); + } + + /** + * 测试Context的Stage开关配置 + */ + @Test + void testContextStageConfiguration() { + // Given + FaceMatchingContext context = FaceMatchingContext.forAutoMatching(1L, true); + + // When: 配置Stage开关 + context.enableStage("stage1"); + context.disableStage("stage2"); + context.setStageState("stage3", true); + + // Then + assertTrue(context.isStageEnabled("stage1")); + assertFalse(context.isStageEnabled("stage2")); + assertTrue(context.isStageEnabled("stage3")); + assertFalse(context.isStageEnabled("non_exist_stage")); // 默认false + } + + /** + * 测试Context Builder + */ + @Test + void testContextBuilder() { + // When + FaceMatchingContext context = FaceMatchingContext.builder() + .faceId(999L) + .isNew(true) + .scene(FaceMatchingScene.AUTO_MATCHING) + .build(); + + // Then + assertEquals(999L, context.getFaceId()); + assertTrue(context.isNew()); + assertEquals(FaceMatchingScene.AUTO_MATCHING, context.getScene()); + } + + /** + * 测试Builder参数校验 + */ + @Test + void testContextBuilderValidation() { + // When & Then: faceId为null应该抛异常 + assertThrows(IllegalArgumentException.class, () -> { + FaceMatchingContext.builder() + .scene(FaceMatchingScene.AUTO_MATCHING) + .build(); + }); + + // When & Then: scene为null应该抛异常 + assertThrows(IllegalArgumentException.class, () -> { + FaceMatchingContext.builder() + .faceId(1L) + .build(); + }); + + // When & Then: CUSTOM_MATCHING场景必须提供faceSampleIds + assertThrows(IllegalArgumentException.class, () -> { + FaceMatchingContext.builder() + .faceId(1L) + .scene(FaceMatchingScene.CUSTOM_MATCHING) + .build(); + }); + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByDevicePhotoLimitStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByDevicePhotoLimitStageTest.java new file mode 100644 index 00000000..8b1c4729 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByDevicePhotoLimitStageTest.java @@ -0,0 +1,213 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.integration.common.manager.DeviceConfigManager; +import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity; +import com.ycwl.basic.repository.DeviceRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; + +/** + * FilterByDevicePhotoLimitStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class FilterByDevicePhotoLimitStageTest { + + @InjectMocks + private FilterByDevicePhotoLimitStage stage; + + @Mock + private DeviceRepository deviceRepository; + + @Mock + private DeviceConfigManager deviceConfigManager; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L)); + } + + @Test + void testExecute_Success_FilterByLimit() { + // Given: 设备1有5个样本,限制为2张 + when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager); + when(deviceConfigManager.getInteger("limit_photo")).thenReturn(2); + + Date baseTime = new Date(); + FaceSampleEntity sample1 = createSample(101L, 1L, new Date(baseTime.getTime() - 5000)); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date(baseTime.getTime() - 4000)); + FaceSampleEntity sample3 = createSample(103L, 1L, new Date(baseTime.getTime() - 3000)); + FaceSampleEntity sample4 = createSample(104L, 1L, new Date(baseTime.getTime() - 2000)); + FaceSampleEntity sample5 = createSample(105L, 1L, new Date(baseTime.getTime() - 1000)); + + context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4, sample5)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L, 105L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + List filteredIds = context.getSampleListIds(); + assertEquals(2, filteredIds.size()); // 限制为2张 + } + + @Test + void testExecute_EmptyFaceSamples_Skip() { + // Given + context.setFaceSamples(Collections.emptyList()); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_NullFaceSamples_Skip() { + // Given + context.setFaceSamples(null); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_NoLimitPhoto_KeepAll() { + // Given: 设备配置limit_photo为null,不限制 + when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager); + when(deviceConfigManager.getInteger("limit_photo")).thenReturn(null); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date()); + FaceSampleEntity sample3 = createSample(103L, 1L, new Date()); + + context.setFaceSamples(Arrays.asList(sample1, sample2, sample3)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(3, context.getSampleListIds().size()); // 全部保留 + } + + @Test + void testExecute_LimitPhotoZero_KeepAll() { + // Given: limit_photo=0,不限制 + when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager); + when(deviceConfigManager.getInteger("limit_photo")).thenReturn(0); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date()); + + context.setFaceSamples(Arrays.asList(sample1, sample2)); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(2, context.getSampleListIds().size()); + } + + @Test + void testExecute_MultipleDevices() { + // Given: 设备1限制2张,设备2限制3张 + when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager); + when(deviceConfigManager.getInteger("limit_photo")).thenReturn(2); + + DeviceConfigManager deviceConfig2 = mock(DeviceConfigManager.class); + when(deviceRepository.getDeviceConfigManager(2L)).thenReturn(deviceConfig2); + when(deviceConfig2.getInteger("limit_photo")).thenReturn(3); + + Date baseTime = new Date(); + FaceSampleEntity sample1 = createSample(101L, 1L, new Date(baseTime.getTime() - 3000)); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date(baseTime.getTime() - 2000)); + FaceSampleEntity sample3 = createSample(103L, 1L, new Date(baseTime.getTime() - 1000)); + FaceSampleEntity sample4 = createSample(104L, 2L, new Date(baseTime.getTime() - 5000)); + FaceSampleEntity sample5 = createSample(105L, 2L, new Date(baseTime.getTime() - 4000)); + FaceSampleEntity sample6 = createSample(106L, 2L, new Date(baseTime.getTime() - 3000)); + FaceSampleEntity sample7 = createSample(107L, 2L, new Date(baseTime.getTime() - 2000)); + + context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4, sample5, sample6, sample7)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L, 105L, 106L, 107L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + List filteredIds = context.getSampleListIds(); + assertTrue(filteredIds.size() <= 5); // 设备1最多2张 + 设备2最多3张 + } + + @Test + void testExecute_SampleWithoutDeviceId_Passthrough() { + // Given: 有样本没有deviceId,应该直接保留 + when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager); + when(deviceConfigManager.getInteger("limit_photo")).thenReturn(1); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date()); + FaceSampleEntity sample3 = createSample(103L, null, new Date()); // 无deviceId + + context.setFaceSamples(Arrays.asList(sample1, sample2, sample3)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + List filteredIds = context.getSampleListIds(); + assertTrue(filteredIds.contains(103L)); // 无deviceId的样本被保留 + assertTrue(filteredIds.size() >= 2); // 至少包含103和设备1的1张样本 + } + + @Test + void testExecute_Exception_Degraded() { + // Given + when(deviceRepository.getDeviceConfigManager(anyLong())).thenThrow(new RuntimeException("DB error")); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + context.setFaceSamples(Arrays.asList(sample1)); + context.setSampleListIds(Arrays.asList(101L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + assertTrue(result.getMessage().contains("设备照片数量限制筛选失败")); + } + + private FaceSampleEntity createSample(Long id, Long deviceId, Date createAt) { + FaceSampleEntity sample = new FaceSampleEntity(); + sample.setId(id); + sample.setDeviceId(deviceId); + sample.setCreateAt(createAt); + return sample; + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStageTest.java new file mode 100644 index 00000000..c9c3f3cf --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/FilterByTimeRangeStageTest.java @@ -0,0 +1,211 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.integration.common.manager.ScenicConfigManager; +import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * FilterByTimeRangeStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class FilterByTimeRangeStageTest { + + @InjectMocks + private FilterByTimeRangeStage stage; + + @Mock + private ScenicConfigManager scenicConfig; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L)); + context.setScenicConfig(scenicConfig); + } + + @Test + void testExecute_Success_FilterBySampleTime() { + // Given: 设置tour_time为30分钟 + when(scenicConfig.getInteger("tour_time")).thenReturn(30); + + Date baseTime = new Date(); + Date time10MinBefore = new Date(baseTime.getTime() - 10 * 60 * 1000); // 10分钟前 + Date time20MinBefore = new Date(baseTime.getTime() - 20 * 60 * 1000); // 20分钟前 + Date time40MinBefore = new Date(baseTime.getTime() - 40 * 60 * 1000); // 40分钟前 (超出范围) + + FaceSampleEntity sample1 = createSample(101L, baseTime); + FaceSampleEntity sample2 = createSample(102L, time10MinBefore); + FaceSampleEntity sample3 = createSample(103L, time20MinBefore); + FaceSampleEntity sample4 = createSample(104L, time40MinBefore); + + context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + List filteredIds = context.getSampleListIds(); + assertEquals(3, filteredIds.size()); + assertTrue(filteredIds.contains(101L)); + assertTrue(filteredIds.contains(102L)); + assertTrue(filteredIds.contains(103L)); + assertFalse(filteredIds.contains(104L)); // 40分钟前的被过滤 + } + + @Test + void testExecute_EmptyFaceSamples_Skip() { + // Given + context.setFaceSamples(Collections.emptyList()); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_NullFaceSamples_Skip() { + // Given + context.setFaceSamples(null); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_TourTimeZero_Skip() { + // Given + when(scenicConfig.getInteger("tour_time")).thenReturn(0); + + FaceSampleEntity sample1 = createSample(101L, new Date()); + context.setFaceSamples(Arrays.asList(sample1)); + context.setSampleListIds(Arrays.asList(101L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_TourTimeNull_Skip() { + // Given + when(scenicConfig.getInteger("tour_time")).thenReturn(null); + + FaceSampleEntity sample1 = createSample(101L, new Date()); + context.setFaceSamples(Arrays.asList(sample1)); + context.setSampleListIds(Arrays.asList(101L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + } + + @Test + void testExecute_NoValidCreateTime_Success() { + // Given + when(scenicConfig.getInteger("tour_time")).thenReturn(30); + + FaceSampleEntity sample1 = createSample(101L, null); // 无创建时间 + FaceSampleEntity sample2 = createSample(102L, null); + + context.setFaceSamples(Arrays.asList(sample1, sample2)); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("样本无拍摄时间")); + } + + @Test + void testExecute_SampleNotInCache_Filtered() { + // Given: sampleListIds中有样本ID,但faceSamples缓存中没有 + when(scenicConfig.getInteger("tour_time")).thenReturn(30); + + FaceSampleEntity sample1 = createSample(101L, new Date()); + context.setFaceSamples(Arrays.asList(sample1)); + context.setSampleListIds(Arrays.asList(101L, 102L, 103L)); // 102和103不在缓存中 + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + List filteredIds = context.getSampleListIds(); + assertEquals(1, filteredIds.size()); + assertEquals(101L, filteredIds.get(0)); + } + + @Test + void testExecute_AllSamplesWithinRange() { + // Given: 所有样本都在时间范围内 + when(scenicConfig.getInteger("tour_time")).thenReturn(60); + + Date baseTime = new Date(); + Date time30MinBefore = new Date(baseTime.getTime() - 30 * 60 * 1000); + + FaceSampleEntity sample1 = createSample(101L, baseTime); + FaceSampleEntity sample2 = createSample(102L, time30MinBefore); + + context.setFaceSamples(Arrays.asList(sample1, sample2)); + context.setSampleListIds(Arrays.asList(101L, 102L)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(2, context.getSampleListIds().size()); + } + + @Test + void testExecute_Exception_Degraded() { + // Given + when(scenicConfig.getInteger("tour_time")).thenReturn(30); + + FaceSampleEntity sample1 = createSample(101L, new Date()); + context.setFaceSamples(Arrays.asList(sample1)); + context.setSampleListIds(null); // 会导致异常 + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + assertTrue(result.getMessage().contains("时间范围筛选失败")); + } + + private FaceSampleEntity createSample(Long id, Date createAt) { + FaceSampleEntity sample = new FaceSampleEntity(); + sample.setId(id); + sample.setCreateAt(createAt); + return sample; + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStageTest.java new file mode 100644 index 00000000..4c2ab1a8 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadMatchedSamplesStageTest.java @@ -0,0 +1,155 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.face.pipeline.enums.FaceMatchingScene; +import com.ycwl.basic.mapper.FaceSampleMapper; +import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.*; + +/** + * LoadMatchedSamplesStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class LoadMatchedSamplesStageTest { + + @Mock + private FaceSampleMapper faceSampleMapper; + + @InjectMocks + private LoadMatchedSamplesStage stage; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L)); + } + + @Test + void testExecute_Success() { + // Given + List sampleIds = Arrays.asList(101L, 102L, 103L); + context.setSampleListIds(sampleIds); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date()); + FaceSampleEntity sample3 = createSample(103L, 2L, new Date()); + + when(faceSampleMapper.listByIds(sampleIds)) + .thenReturn(Arrays.asList(sample1, sample2, sample3)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(3, context.getFaceSamples().size()); + assertEquals(101L, context.getFaceSamples().get(0).getId()); + verify(faceSampleMapper, times(1)).listByIds(sampleIds); + } + + @Test + void testExecute_EmptySampleListIds_Skip() { + // Given + context.setSampleListIds(Collections.emptyList()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + verify(faceSampleMapper, never()).listByIds(anyList()); + } + + @Test + void testExecute_NullSampleListIds_Skip() { + // Given + context.setSampleListIds(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + verify(faceSampleMapper, never()).listByIds(anyList()); + } + + @Test + void testExecute_NoSamplesFound_Skipped() { + // Given + List sampleIds = Arrays.asList(101L, 102L); + context.setSampleListIds(sampleIds); + + when(faceSampleMapper.listByIds(sampleIds)) + .thenReturn(Collections.emptyList()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertNull(context.getFaceSamples()); + verify(faceSampleMapper, times(1)).listByIds(sampleIds); + } + + @Test + void testExecute_MapperThrowsException_Failure() { + // Given + List sampleIds = Arrays.asList(101L, 102L); + context.setSampleListIds(sampleIds); + + when(faceSampleMapper.listByIds(sampleIds)) + .thenThrow(new RuntimeException("Database error")); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("加载匹配样本实体失败")); + verify(faceSampleMapper, times(1)).listByIds(sampleIds); + } + + @Test + void testExecute_PartialResults() { + // Given: 请求3个样本,只返回2个 + List sampleIds = Arrays.asList(101L, 102L, 103L); + context.setSampleListIds(sampleIds); + + FaceSampleEntity sample1 = createSample(101L, 1L, new Date()); + FaceSampleEntity sample2 = createSample(102L, 1L, new Date()); + + when(faceSampleMapper.listByIds(sampleIds)) + .thenReturn(Arrays.asList(sample1, sample2)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(2, context.getFaceSamples().size()); + } + + private FaceSampleEntity createSample(Long id, Long deviceId, Date createAt) { + FaceSampleEntity sample = new FaceSampleEntity(); + sample.setId(id); + sample.setDeviceId(deviceId); + sample.setCreateAt(createAt); + return sample; + } +}