From b165840176bca5861eb186d4668721a4b3af5104 Mon Sep 17 00:00:00 2001 From: Jerry Yan <792602257@qq.com> Date: Wed, 3 Dec 2025 18:41:24 +0800 Subject: [PATCH] =?UTF-8?q?feat(face):=20=E6=B7=BB=E5=8A=A0=E6=96=B0?= =?UTF-8?q?=E4=BA=BA=E8=84=B8=E4=BB=BB=E5=8A=A1=E7=8A=B6=E6=80=81=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E9=80=BB=E8=BE=91=E5=8F=8A=E5=8D=95=E5=85=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在SetTaskStatusStage中增加新人脸用户判断逻辑,非新用户跳过任务状态设置 - 新增LoadFaceSamplesStage、SetTaskStatusStage和UpdateFaceResultStage的完整单元测试 - 完善各阶段异常处理和边界条件测试,提升代码健壮性 - 添加大量测试用例覆盖成功、失败、异常等多种执行路径 - 验证任务状态设置、人脸样本加载和识别结果更新的核心功能 --- .../pipeline/stages/SetTaskStatusStage.java | 6 + .../stages/LoadFaceSamplesStageTest.java | 216 +++++++++++++++ .../stages/SetTaskStatusStageTest.java | 167 ++++++++++++ .../stages/UpdateFaceResultStageTest.java | 248 ++++++++++++++++++ 4 files changed, 637 insertions(+) create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/LoadFaceSamplesStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStageTest.java create mode 100644 src/test/java/com/ycwl/basic/face/pipeline/stages/UpdateFaceResultStageTest.java diff --git a/src/main/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStage.java b/src/main/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStage.java index 49b6ceb3..0e8675c2 100644 --- a/src/main/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStage.java +++ b/src/main/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStage.java @@ -49,6 +49,12 @@ public class SetTaskStatusStage extends AbstractFaceMatchingStage doExecute(FaceMatchingContext context) { Long faceId = context.getFaceId(); + // 防御性检查:只有新用户才执行 + if (!context.isNew()) { + log.debug("非新用户,跳过设置任务状态,faceId={}", faceId); + return StageResult.skipped("非新用户"); + } + try { taskStatusBiz.setFaceCutStatus(faceId, 0); log.debug("设置新用户任务状态: faceId={}, status=0", faceId); diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadFaceSamplesStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadFaceSamplesStageTest.java new file mode 100644 index 00000000..e57a71ee --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/LoadFaceSamplesStageTest.java @@ -0,0 +1,216 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.exception.BaseException; +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.mapper.FaceSampleMapper; +import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.*; + +/** + * LoadFaceSamplesStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class LoadFaceSamplesStageTest { + + @Mock + private FaceSampleMapper faceSampleMapper; + + @InjectMocks + private LoadFaceSamplesStage stage; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L)); + } + + @Test + void testExecute_Success() { + // Given + List faceSampleIds = Arrays.asList(101L, 102L, 103L); + context.setFaceSampleIds(faceSampleIds); + + FaceSampleEntity sample1 = createSample(101L, "http://example.com/s1.jpg"); + FaceSampleEntity sample2 = createSample(102L, "http://example.com/s2.jpg"); + FaceSampleEntity sample3 = createSample(103L, "http://example.com/s3.jpg"); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(Arrays.asList(sample1, sample2, sample3)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(3, context.getFaceSamples().size()); + assertTrue(result.getMessage().contains("加载了3个人脸样本")); + verify(faceSampleMapper, times(1)).listByIds(faceSampleIds); + } + + @Test + void testExecute_FaceSampleIdsNull_Failed() { + // Given + context.setFaceSampleIds(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("faceSampleIds不能为空")); + verify(faceSampleMapper, never()).listByIds(anyList()); + } + + @Test + void testExecute_FaceSampleIdsEmpty_Failed() { + // Given + context.setFaceSampleIds(new ArrayList<>()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("faceSampleIds不能为空")); + verify(faceSampleMapper, never()).listByIds(anyList()); + } + + @Test + void testExecute_NoSamplesFound_ThrowException() { + // Given + List faceSampleIds = Arrays.asList(101L, 102L); + context.setFaceSampleIds(faceSampleIds); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(new ArrayList<>()); // 返回空列表 + + // When & Then + assertThrows(BaseException.class, () -> { + stage.execute(context); + }); + + verify(faceSampleMapper, times(1)).listByIds(faceSampleIds); + } + + @Test + void testExecute_PartialSamplesFound() { + // Given: 请求3个样本,只返回2个 + List faceSampleIds = Arrays.asList(101L, 102L, 103L); + context.setFaceSampleIds(faceSampleIds); + + FaceSampleEntity sample1 = createSample(101L, "http://example.com/s1.jpg"); + FaceSampleEntity sample2 = createSample(102L, "http://example.com/s2.jpg"); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(Arrays.asList(sample1, sample2)); // 只返回2个 + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); // 只要有样本就成功 + assertEquals(2, context.getFaceSamples().size()); + assertTrue(result.getMessage().contains("加载了2个人脸样本")); + } + + @Test + void testExecute_MapperThrowsException_Failed() { + // Given + List faceSampleIds = Arrays.asList(101L, 102L); + context.setFaceSampleIds(faceSampleIds); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenThrow(new RuntimeException("Database connection error")); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("加载人脸样本失败")); + assertNotNull(result.getException()); + } + + @Test + void testExecute_BaseException_Rethrow() { + // Given + List faceSampleIds = Arrays.asList(101L, 102L); + context.setFaceSampleIds(faceSampleIds); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(new ArrayList<>()); // 触发BaseException + + // When & Then + BaseException exception = assertThrows(BaseException.class, () -> { + stage.execute(context); + }); + + assertTrue(exception.getMessage().contains("未找到指定的人脸样本")); + } + + @Test + void testExecute_SingleSample() { + // Given: 只加载1个样本 + List faceSampleIds = Arrays.asList(101L); + context.setFaceSampleIds(faceSampleIds); + + FaceSampleEntity sample = createSample(101L, "http://example.com/s1.jpg"); + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(Arrays.asList(sample)); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(1, context.getFaceSamples().size()); + assertTrue(result.getMessage().contains("加载了1个人脸样本")); + } + + @Test + void testExecute_ManySamples() { + // Given: 加载大量样本 + List faceSampleIds = Arrays.asList(101L, 102L, 103L, 104L, 105L, 106L, 107L, 108L, 109L, 110L); + context.setFaceSampleIds(faceSampleIds); + + List samples = new ArrayList<>(); + for (Long id : faceSampleIds) { + samples.add(createSample(id, "http://example.com/s" + id + ".jpg")); + } + + when(faceSampleMapper.listByIds(faceSampleIds)) + .thenReturn(samples); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertEquals(10, context.getFaceSamples().size()); + assertTrue(result.getMessage().contains("加载了10个人脸样本")); + } + + private FaceSampleEntity createSample(Long id, String faceUrl) { + FaceSampleEntity sample = new FaceSampleEntity(); + sample.setId(id); + sample.setFaceUrl(faceUrl); + sample.setDeviceId(1L); + return sample; + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStageTest.java new file mode 100644 index 00000000..0f890d95 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/SetTaskStatusStageTest.java @@ -0,0 +1,167 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.biz.TaskStatusBiz; +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; + +/** + * SetTaskStatusStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class SetTaskStatusStageTest { + + @Mock + private TaskStatusBiz taskStatusBiz; + + @InjectMocks + private SetTaskStatusStage stage; + + private FaceMatchingContext context; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forAutoMatching(1L, true); // isNew=true + } + + @Test + void testExecute_NewUser_Success() { + // Given + context = FaceMatchingContext.forAutoMatching(1L, true); + + doNothing().when(taskStatusBiz).setFaceCutStatus(1L, 0); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + assertTrue(result.getMessage().contains("任务状态已设置")); + verify(taskStatusBiz, times(1)).setFaceCutStatus(1L, 0); + } + + @Test + void testExecute_OldUser_Skipped() { + // Given: 老用户 + context = FaceMatchingContext.forAutoMatching(1L, false); // isNew=false + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("非新用户")); + verify(taskStatusBiz, never()).setFaceCutStatus(anyLong(), anyInt()); + } + + @Test + void testExecute_SetStatusFailed_Degraded() { + // Given + context = FaceMatchingContext.forAutoMatching(1L, true); + + doThrow(new RuntimeException("Database error")) + .when(taskStatusBiz).setFaceCutStatus(1L, 0); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); // 降级处理,不影响主流程 + assertTrue(result.getMessage().contains("任务状态设置失败")); + verify(taskStatusBiz, times(1)).setFaceCutStatus(1L, 0); + } + + @Test + void testExecute_DifferentFaceId() { + // Given: 不同的faceId + context = FaceMatchingContext.forAutoMatching(999L, true); + + doNothing().when(taskStatusBiz).setFaceCutStatus(999L, 0); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(taskStatusBiz, times(1)).setFaceCutStatus(999L, 0); + } + + @Test + void testExecute_RecognitionOnlyScene_Skipped() { + // Given: 仅识别场景(非新用户场景) + context = FaceMatchingContext.forRecognitionOnly(1L); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + verify(taskStatusBiz, never()).setFaceCutStatus(anyLong(), anyInt()); + } + + @Test + void testExecute_CustomMatchingOldUser_Skipped() { + // Given: 自定义匹配老用户 (需要提供faceSampleIds) + context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L)); + // CustomMatching默认isNew=false + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + verify(taskStatusBiz, never()).setFaceCutStatus(anyLong(), anyInt()); + } + + @Test + void testExecute_NullPointerException_Degraded() { + // Given + context = FaceMatchingContext.forAutoMatching(1L, true); + + doThrow(new NullPointerException("Null task status")) + .when(taskStatusBiz).setFaceCutStatus(1L, 0); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isDegraded()); + assertTrue(result.getMessage().contains("任务状态设置失败")); + } + + @Test + void testShouldExecute_NewUser_True() { + // Given + context = FaceMatchingContext.forAutoMatching(1L, true); + + // When + StageResult result = stage.execute(context); + + // Then + assertFalse(result.isSkipped()); // 应该执行 + } + + @Test + void testShouldExecute_OldUser_False() { + // Given + context = FaceMatchingContext.forAutoMatching(1L, false); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); // 应该跳过 + } +} diff --git a/src/test/java/com/ycwl/basic/face/pipeline/stages/UpdateFaceResultStageTest.java b/src/test/java/com/ycwl/basic/face/pipeline/stages/UpdateFaceResultStageTest.java new file mode 100644 index 00000000..5ec4ffc6 --- /dev/null +++ b/src/test/java/com/ycwl/basic/face/pipeline/stages/UpdateFaceResultStageTest.java @@ -0,0 +1,248 @@ +package com.ycwl.basic.face.pipeline.stages; + +import com.ycwl.basic.face.pipeline.core.FaceMatchingContext; +import com.ycwl.basic.face.pipeline.core.StageResult; +import com.ycwl.basic.mapper.FaceMapper; +import com.ycwl.basic.model.pc.face.entity.FaceEntity; +import com.ycwl.basic.model.task.resp.SearchFaceRespVo; +import com.ycwl.basic.repository.FaceRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.math.BigDecimal; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * UpdateFaceResultStage 单元测试 + */ +@ExtendWith(MockitoExtension.class) +class UpdateFaceResultStageTest { + + @Mock + private FaceMapper faceMapper; + + @Mock + private FaceRepository faceRepository; + + @InjectMocks + private UpdateFaceResultStage stage; + + private FaceMatchingContext context; + private FaceEntity originalFace; + + @BeforeEach + void setUp() { + context = FaceMatchingContext.forAutoMatching(1L, true); + + originalFace = new FaceEntity(); + originalFace.setId(1L); + originalFace.setScenicId(10L); + originalFace.setMemberId(100L); + originalFace.setFaceUrl("http://example.com/face.jpg"); + + context.setFace(originalFace); + } + + @Test + void testExecute_Success() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.85f); + searchResult.setFirstMatchRate(0.92f); + searchResult.setSampleListIds(Arrays.asList(101L, 102L, 103L)); + searchResult.setSearchResultJson("{\"score\":0.85}"); + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + verify(faceRepository, times(1)).clearFaceCache(1L); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertEquals(1L, updatedFace.getId()); + assertEquals(0.85f, updatedFace.getScore(), 0.0001); + assertEquals("{\"score\":0.85}", updatedFace.getMatchResult()); + assertEquals(BigDecimal.valueOf(0.92f), updatedFace.getFirstMatchRate()); + assertEquals("101,102,103", updatedFace.getMatchSampleIds()); + assertEquals(10L, updatedFace.getScenicId()); + assertEquals(100L, updatedFace.getMemberId()); + assertEquals("http://example.com/face.jpg", updatedFace.getFaceUrl()); + assertNotNull(updatedFace.getCreateAt()); + } + + @Test + void testExecute_SearchResultNull_Skipped() { + // Given + context.setSearchResult(null); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSkipped()); + assertTrue(result.getMessage().contains("searchResult为空")); + verify(faceMapper, never()).update(any()); + verify(faceRepository, never()).clearFaceCache(any()); + } + + @Test + void testExecute_NoFirstMatchRate() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.75f); + searchResult.setFirstMatchRate(null); // 无firstMatchRate + searchResult.setSampleListIds(Arrays.asList(101L)); + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertNull(updatedFace.getFirstMatchRate()); + } + + @Test + void testExecute_NoSampleListIds() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.65f); + searchResult.setSampleListIds(null); // 无样本ID列表 + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertNull(updatedFace.getMatchSampleIds()); + } + + @Test + void testExecute_EmptySampleListIds() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.55f); + searchResult.setSampleListIds(Arrays.asList()); // 空列表 + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertEquals("", updatedFace.getMatchSampleIds()); // 空列表连接为空字符串 + } + + @Test + void testExecute_MapperUpdateFailed() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.85f); + searchResult.setSampleListIds(Arrays.asList(101L)); + context.setSearchResult(searchResult); + + doThrow(new RuntimeException("Database error")).when(faceMapper).update(any()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); + assertTrue(result.getMessage().contains("保存人脸识别结果失败")); + assertNotNull(result.getException()); + verify(faceRepository, never()).clearFaceCache(any()); // 更新失败时不应清缓存 + } + + @Test + void testExecute_CacheClearFailed() { + // Given + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.85f); + searchResult.setSampleListIds(Arrays.asList(101L)); + context.setSearchResult(searchResult); + + // faceMapper.update()正常执行(不需要mock void方法) + doThrow(new RuntimeException("Cache clear error")).when(faceRepository).clearFaceCache(any()); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isFailed()); // 缓存清理失败导致整体失败 + verify(faceMapper, times(1)).update(any()); // Mapper仍然被调用了 + } + + @Test + void testExecute_HighScore() { + // Given: 高分匹配 + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.98f); + searchResult.setFirstMatchRate(0.99f); + searchResult.setSampleListIds(Arrays.asList(101L)); + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertEquals(0.98f, updatedFace.getScore(), 0.0001); + } + + @Test + void testExecute_MultipleSamples() { + // Given: 多个样本 + SearchFaceRespVo searchResult = new SearchFaceRespVo(); + searchResult.setScore(0.78f); + searchResult.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L, 105L)); + context.setSearchResult(searchResult); + + ArgumentCaptor faceCaptor = ArgumentCaptor.forClass(FaceEntity.class); + + // When + StageResult result = stage.execute(context); + + // Then + assertTrue(result.isSuccess()); + verify(faceMapper, times(1)).update(faceCaptor.capture()); + + FaceEntity updatedFace = faceCaptor.getValue(); + assertEquals("101,102,103,104,105", updatedFace.getMatchSampleIds()); + } +}