fix(pipeline): 增强人脸匹配流水线的健壮性

- 在FilterByTimeRangeStage中增加空值检查和配置验证
- 在LoadMatchedSamplesStage中增加sampleListIds空值检查
- 添加完整的集成测试覆盖Pipeline工厂和Context构建
- 为FilterByDevicePhotoLimitStage添加全面的单元测试
- 为FilterByTimeRangeStage添加边界条件和异常处理测试
- 为LoadMatchedSamplesStage添加异常路径测试
This commit is contained in:
2025-12-03 18:17:34 +08:00
parent 96e75a458f
commit b3fa10e8fd
6 changed files with 800 additions and 1 deletions

View File

@@ -69,9 +69,23 @@ public class FilterByTimeRangeStage extends AbstractFaceMatchingStage<FaceMatchi
protected StageResult<FaceMatchingContext> doExecute(FaceMatchingContext context) {
List<FaceSampleEntity> faceSamples = context.getFaceSamples();
List<Long> sampleListIds = context.getSampleListIds();
Integer tourMinutes = context.getScenicConfig().getInteger("tour_time");
Long faceId = context.getFaceId();
// 防御性检查:faceSamples为空
if (faceSamples == null || faceSamples.isEmpty()) {
log.debug("faceSamples为空,跳过时间范围筛选,faceId={}", faceId);
return StageResult.skipped("faceSamples为空");
}
// 防御性检查:tour_time配置
Integer tourMinutes = context.getScenicConfig() != null
? context.getScenicConfig().getInteger("tour_time")
: null;
if (tourMinutes == null || tourMinutes <= 0) {
log.debug("景区未配置tour_time或配置为0,跳过时间范围筛选,faceId={}", faceId);
return StageResult.skipped("未配置tour_time");
}
try {
// 1. 构建样本ID到实体的映射
Map<Long, FaceSampleEntity> sampleMap = faceSamples.stream()

View File

@@ -64,6 +64,12 @@ public class LoadMatchedSamplesStage extends AbstractFaceMatchingStage<FaceMatch
List<Long> sampleListIds = context.getSampleListIds();
Long faceId = context.getFaceId();
// 防御性检查:如果sampleListIds为空,直接跳过
if (sampleListIds == null || sampleListIds.isEmpty()) {
log.debug("sampleListIds为空,跳过加载匹配样本,faceId={}", faceId);
return StageResult.skipped("sampleListIds为空");
}
try {
// 批量加载样本实体
List<FaceSampleEntity> faceSamples = faceSampleMapper.listByIds(sampleListIds);

View File

@@ -0,0 +1,200 @@
package com.ycwl.basic.face.pipeline.integration;
import com.ycwl.basic.face.pipeline.core.FaceMatchingContext;
import com.ycwl.basic.face.pipeline.core.Pipeline;
import com.ycwl.basic.face.pipeline.enums.FaceMatchingScene;
import com.ycwl.basic.face.pipeline.factory.FaceMatchingPipelineFactory;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*;
/**
* Pipeline集成测试
* 测试Pipeline的完整流程和Stage协作
*/
@SpringBootTest
@ActiveProfiles("test")
class FaceMatchingPipelineIntegrationTest {
@Autowired
private FaceMatchingPipelineFactory pipelineFactory;
/**
* 测试Pipeline工厂能够成功创建Pipeline
*/
@Test
void testCreatePipelines() {
// When: 创建各种场景的Pipeline
Pipeline<FaceMatchingContext> autoMatchingNew = pipelineFactory.createAutoMatchingPipeline(true);
Pipeline<FaceMatchingContext> autoMatchingOld = pipelineFactory.createAutoMatchingPipeline(false);
Pipeline<FaceMatchingContext> customMatching = pipelineFactory.createCustomMatchingPipeline();
Pipeline<FaceMatchingContext> recognitionOnly = pipelineFactory.createRecognitionOnlyPipeline();
// Then: 验证Pipeline创建成功
assertNotNull(autoMatchingNew);
assertNotNull(autoMatchingOld);
assertNotNull(customMatching);
assertNotNull(recognitionOnly);
// 验证Stage数量符合预期
assertEquals(13, autoMatchingNew.getStageCount());
assertEquals(13, autoMatchingOld.getStageCount());
assertEquals(15, customMatching.getStageCount());
assertEquals(3, recognitionOnly.getStageCount());
}
/**
* 测试通过场景和isNew参数创建Pipeline
*/
@Test
void testCreatePipelineByScene() {
// When
Pipeline<FaceMatchingContext> autoNew = pipelineFactory.createPipeline(FaceMatchingScene.AUTO_MATCHING, true);
Pipeline<FaceMatchingContext> autoOld = pipelineFactory.createPipeline(FaceMatchingScene.AUTO_MATCHING, false);
Pipeline<FaceMatchingContext> custom = pipelineFactory.createPipeline(FaceMatchingScene.CUSTOM_MATCHING, false);
Pipeline<FaceMatchingContext> recognition = pipelineFactory.createPipeline(FaceMatchingScene.RECOGNITION_ONLY, false);
// Then
assertNotNull(autoNew);
assertNotNull(autoOld);
assertNotNull(custom);
assertNotNull(recognition);
}
/**
* 测试通过Context创建Pipeline
*/
@Test
void testCreatePipelineByContext() {
// Given
FaceMatchingContext autoContext = FaceMatchingContext.forAutoMatching(1L, true);
FaceMatchingContext customContext = FaceMatchingContext.forCustomMatching(2L, Arrays.asList(101L, 102L));
FaceMatchingContext recognitionContext = FaceMatchingContext.forRecognitionOnly(3L);
// When
Pipeline<FaceMatchingContext> autoPipeline = pipelineFactory.createPipeline(autoContext);
Pipeline<FaceMatchingContext> customPipeline = pipelineFactory.createPipeline(customContext);
Pipeline<FaceMatchingContext> recognitionPipeline = pipelineFactory.createPipeline(recognitionContext);
// Then
assertNotNull(autoPipeline);
assertNotNull(customPipeline);
assertNotNull(recognitionPipeline);
}
/**
* 测试Pipeline名称
*/
@Test
void testPipelineNames() {
// When
Pipeline<FaceMatchingContext> autoNew = pipelineFactory.createAutoMatchingPipeline(true);
Pipeline<FaceMatchingContext> autoOld = pipelineFactory.createAutoMatchingPipeline(false);
Pipeline<FaceMatchingContext> custom = pipelineFactory.createCustomMatchingPipeline();
Pipeline<FaceMatchingContext> recognition = pipelineFactory.createRecognitionOnlyPipeline();
// Then
assertTrue(autoNew.getName().contains("AutoMatching"));
assertTrue(autoNew.getName().contains("New"));
assertTrue(autoOld.getName().contains("AutoMatching"));
assertTrue(autoOld.getName().contains("Old"));
assertTrue(custom.getName().contains("CustomMatching"));
assertTrue(recognition.getName().contains("RecognitionOnly"));
}
/**
* 测试Context构建
*/
@Test
void testContextCreation() {
// When: 使用不同的工厂方法创建Context
FaceMatchingContext autoContext = FaceMatchingContext.forAutoMatching(100L, true);
FaceMatchingContext customContext = FaceMatchingContext.forCustomMatching(200L, Arrays.asList(1L, 2L, 3L));
FaceMatchingContext recognitionContext = FaceMatchingContext.forRecognitionOnly(300L);
// Then: 验证Context属性
assertEquals(100L, autoContext.getFaceId());
assertTrue(autoContext.isNew());
assertEquals(FaceMatchingScene.AUTO_MATCHING, autoContext.getScene());
assertEquals(200L, customContext.getFaceId());
assertFalse(customContext.isNew());
assertEquals(FaceMatchingScene.CUSTOM_MATCHING, customContext.getScene());
assertEquals(3, customContext.getFaceSampleIds().size());
assertEquals(300L, recognitionContext.getFaceId());
assertFalse(recognitionContext.isNew());
assertEquals(FaceMatchingScene.RECOGNITION_ONLY, recognitionContext.getScene());
}
/**
* 测试Context的Stage开关配置
*/
@Test
void testContextStageConfiguration() {
// Given
FaceMatchingContext context = FaceMatchingContext.forAutoMatching(1L, true);
// When: 配置Stage开关
context.enableStage("stage1");
context.disableStage("stage2");
context.setStageState("stage3", true);
// Then
assertTrue(context.isStageEnabled("stage1"));
assertFalse(context.isStageEnabled("stage2"));
assertTrue(context.isStageEnabled("stage3"));
assertFalse(context.isStageEnabled("non_exist_stage")); // 默认false
}
/**
* 测试Context Builder
*/
@Test
void testContextBuilder() {
// When
FaceMatchingContext context = FaceMatchingContext.builder()
.faceId(999L)
.isNew(true)
.scene(FaceMatchingScene.AUTO_MATCHING)
.build();
// Then
assertEquals(999L, context.getFaceId());
assertTrue(context.isNew());
assertEquals(FaceMatchingScene.AUTO_MATCHING, context.getScene());
}
/**
* 测试Builder参数校验
*/
@Test
void testContextBuilderValidation() {
// When & Then: faceId为null应该抛异常
assertThrows(IllegalArgumentException.class, () -> {
FaceMatchingContext.builder()
.scene(FaceMatchingScene.AUTO_MATCHING)
.build();
});
// When & Then: scene为null应该抛异常
assertThrows(IllegalArgumentException.class, () -> {
FaceMatchingContext.builder()
.faceId(1L)
.build();
});
// When & Then: CUSTOM_MATCHING场景必须提供faceSampleIds
assertThrows(IllegalArgumentException.class, () -> {
FaceMatchingContext.builder()
.faceId(1L)
.scene(FaceMatchingScene.CUSTOM_MATCHING)
.build();
});
}
}

View File

@@ -0,0 +1,213 @@
package com.ycwl.basic.face.pipeline.stages;
import com.ycwl.basic.face.pipeline.core.FaceMatchingContext;
import com.ycwl.basic.face.pipeline.core.StageResult;
import com.ycwl.basic.integration.common.manager.DeviceConfigManager;
import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity;
import com.ycwl.basic.repository.DeviceRepository;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.*;
/**
* FilterByDevicePhotoLimitStage 单元测试
*/
@ExtendWith(MockitoExtension.class)
class FilterByDevicePhotoLimitStageTest {
@InjectMocks
private FilterByDevicePhotoLimitStage stage;
@Mock
private DeviceRepository deviceRepository;
@Mock
private DeviceConfigManager deviceConfigManager;
private FaceMatchingContext context;
@BeforeEach
void setUp() {
context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L));
}
@Test
void testExecute_Success_FilterByLimit() {
// Given: 设备1有5个样本,限制为2张
when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager);
when(deviceConfigManager.getInteger("limit_photo")).thenReturn(2);
Date baseTime = new Date();
FaceSampleEntity sample1 = createSample(101L, 1L, new Date(baseTime.getTime() - 5000));
FaceSampleEntity sample2 = createSample(102L, 1L, new Date(baseTime.getTime() - 4000));
FaceSampleEntity sample3 = createSample(103L, 1L, new Date(baseTime.getTime() - 3000));
FaceSampleEntity sample4 = createSample(104L, 1L, new Date(baseTime.getTime() - 2000));
FaceSampleEntity sample5 = createSample(105L, 1L, new Date(baseTime.getTime() - 1000));
context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4, sample5));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L, 105L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
List<Long> filteredIds = context.getSampleListIds();
assertEquals(2, filteredIds.size()); // 限制为2张
}
@Test
void testExecute_EmptyFaceSamples_Skip() {
// Given
context.setFaceSamples(Collections.emptyList());
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_NullFaceSamples_Skip() {
// Given
context.setFaceSamples(null);
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_NoLimitPhoto_KeepAll() {
// Given: 设备配置limit_photo为null,不限制
when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager);
when(deviceConfigManager.getInteger("limit_photo")).thenReturn(null);
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
FaceSampleEntity sample2 = createSample(102L, 1L, new Date());
FaceSampleEntity sample3 = createSample(103L, 1L, new Date());
context.setFaceSamples(Arrays.asList(sample1, sample2, sample3));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertEquals(3, context.getSampleListIds().size()); // 全部保留
}
@Test
void testExecute_LimitPhotoZero_KeepAll() {
// Given: limit_photo=0,不限制
when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager);
when(deviceConfigManager.getInteger("limit_photo")).thenReturn(0);
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
FaceSampleEntity sample2 = createSample(102L, 1L, new Date());
context.setFaceSamples(Arrays.asList(sample1, sample2));
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertEquals(2, context.getSampleListIds().size());
}
@Test
void testExecute_MultipleDevices() {
// Given: 设备1限制2张,设备2限制3张
when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager);
when(deviceConfigManager.getInteger("limit_photo")).thenReturn(2);
DeviceConfigManager deviceConfig2 = mock(DeviceConfigManager.class);
when(deviceRepository.getDeviceConfigManager(2L)).thenReturn(deviceConfig2);
when(deviceConfig2.getInteger("limit_photo")).thenReturn(3);
Date baseTime = new Date();
FaceSampleEntity sample1 = createSample(101L, 1L, new Date(baseTime.getTime() - 3000));
FaceSampleEntity sample2 = createSample(102L, 1L, new Date(baseTime.getTime() - 2000));
FaceSampleEntity sample3 = createSample(103L, 1L, new Date(baseTime.getTime() - 1000));
FaceSampleEntity sample4 = createSample(104L, 2L, new Date(baseTime.getTime() - 5000));
FaceSampleEntity sample5 = createSample(105L, 2L, new Date(baseTime.getTime() - 4000));
FaceSampleEntity sample6 = createSample(106L, 2L, new Date(baseTime.getTime() - 3000));
FaceSampleEntity sample7 = createSample(107L, 2L, new Date(baseTime.getTime() - 2000));
context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4, sample5, sample6, sample7));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L, 105L, 106L, 107L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
List<Long> filteredIds = context.getSampleListIds();
assertTrue(filteredIds.size() <= 5); // 设备1最多2张 + 设备2最多3张
}
@Test
void testExecute_SampleWithoutDeviceId_Passthrough() {
// Given: 有样本没有deviceId,应该直接保留
when(deviceRepository.getDeviceConfigManager(1L)).thenReturn(deviceConfigManager);
when(deviceConfigManager.getInteger("limit_photo")).thenReturn(1);
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
FaceSampleEntity sample2 = createSample(102L, 1L, new Date());
FaceSampleEntity sample3 = createSample(103L, null, new Date()); // 无deviceId
context.setFaceSamples(Arrays.asList(sample1, sample2, sample3));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
List<Long> filteredIds = context.getSampleListIds();
assertTrue(filteredIds.contains(103L)); // 无deviceId的样本被保留
assertTrue(filteredIds.size() >= 2); // 至少包含103和设备1的1张样本
}
@Test
void testExecute_Exception_Degraded() {
// Given
when(deviceRepository.getDeviceConfigManager(anyLong())).thenThrow(new RuntimeException("DB error"));
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
context.setFaceSamples(Arrays.asList(sample1));
context.setSampleListIds(Arrays.asList(101L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isDegraded());
assertTrue(result.getMessage().contains("设备照片数量限制筛选失败"));
}
private FaceSampleEntity createSample(Long id, Long deviceId, Date createAt) {
FaceSampleEntity sample = new FaceSampleEntity();
sample.setId(id);
sample.setDeviceId(deviceId);
sample.setCreateAt(createAt);
return sample;
}
}

View File

@@ -0,0 +1,211 @@
package com.ycwl.basic.face.pipeline.stages;
import com.ycwl.basic.face.pipeline.core.FaceMatchingContext;
import com.ycwl.basic.face.pipeline.core.StageResult;
import com.ycwl.basic.integration.common.manager.ScenicConfigManager;
import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* FilterByTimeRangeStage 单元测试
*/
@ExtendWith(MockitoExtension.class)
class FilterByTimeRangeStageTest {
@InjectMocks
private FilterByTimeRangeStage stage;
@Mock
private ScenicConfigManager scenicConfig;
private FaceMatchingContext context;
@BeforeEach
void setUp() {
context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L));
context.setScenicConfig(scenicConfig);
}
@Test
void testExecute_Success_FilterBySampleTime() {
// Given: 设置tour_time为30分钟
when(scenicConfig.getInteger("tour_time")).thenReturn(30);
Date baseTime = new Date();
Date time10MinBefore = new Date(baseTime.getTime() - 10 * 60 * 1000); // 10分钟前
Date time20MinBefore = new Date(baseTime.getTime() - 20 * 60 * 1000); // 20分钟前
Date time40MinBefore = new Date(baseTime.getTime() - 40 * 60 * 1000); // 40分钟前 (超出范围)
FaceSampleEntity sample1 = createSample(101L, baseTime);
FaceSampleEntity sample2 = createSample(102L, time10MinBefore);
FaceSampleEntity sample3 = createSample(103L, time20MinBefore);
FaceSampleEntity sample4 = createSample(104L, time40MinBefore);
context.setFaceSamples(Arrays.asList(sample1, sample2, sample3, sample4));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L, 104L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
List<Long> filteredIds = context.getSampleListIds();
assertEquals(3, filteredIds.size());
assertTrue(filteredIds.contains(101L));
assertTrue(filteredIds.contains(102L));
assertTrue(filteredIds.contains(103L));
assertFalse(filteredIds.contains(104L)); // 40分钟前的被过滤
}
@Test
void testExecute_EmptyFaceSamples_Skip() {
// Given
context.setFaceSamples(Collections.emptyList());
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_NullFaceSamples_Skip() {
// Given
context.setFaceSamples(null);
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_TourTimeZero_Skip() {
// Given
when(scenicConfig.getInteger("tour_time")).thenReturn(0);
FaceSampleEntity sample1 = createSample(101L, new Date());
context.setFaceSamples(Arrays.asList(sample1));
context.setSampleListIds(Arrays.asList(101L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_TourTimeNull_Skip() {
// Given
when(scenicConfig.getInteger("tour_time")).thenReturn(null);
FaceSampleEntity sample1 = createSample(101L, new Date());
context.setFaceSamples(Arrays.asList(sample1));
context.setSampleListIds(Arrays.asList(101L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
}
@Test
void testExecute_NoValidCreateTime_Success() {
// Given
when(scenicConfig.getInteger("tour_time")).thenReturn(30);
FaceSampleEntity sample1 = createSample(101L, null); // 无创建时间
FaceSampleEntity sample2 = createSample(102L, null);
context.setFaceSamples(Arrays.asList(sample1, sample2));
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertTrue(result.getMessage().contains("样本无拍摄时间"));
}
@Test
void testExecute_SampleNotInCache_Filtered() {
// Given: sampleListIds中有样本ID,但faceSamples缓存中没有
when(scenicConfig.getInteger("tour_time")).thenReturn(30);
FaceSampleEntity sample1 = createSample(101L, new Date());
context.setFaceSamples(Arrays.asList(sample1));
context.setSampleListIds(Arrays.asList(101L, 102L, 103L)); // 102和103不在缓存中
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
List<Long> filteredIds = context.getSampleListIds();
assertEquals(1, filteredIds.size());
assertEquals(101L, filteredIds.get(0));
}
@Test
void testExecute_AllSamplesWithinRange() {
// Given: 所有样本都在时间范围内
when(scenicConfig.getInteger("tour_time")).thenReturn(60);
Date baseTime = new Date();
Date time30MinBefore = new Date(baseTime.getTime() - 30 * 60 * 1000);
FaceSampleEntity sample1 = createSample(101L, baseTime);
FaceSampleEntity sample2 = createSample(102L, time30MinBefore);
context.setFaceSamples(Arrays.asList(sample1, sample2));
context.setSampleListIds(Arrays.asList(101L, 102L));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertEquals(2, context.getSampleListIds().size());
}
@Test
void testExecute_Exception_Degraded() {
// Given
when(scenicConfig.getInteger("tour_time")).thenReturn(30);
FaceSampleEntity sample1 = createSample(101L, new Date());
context.setFaceSamples(Arrays.asList(sample1));
context.setSampleListIds(null); // 会导致异常
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isDegraded());
assertTrue(result.getMessage().contains("时间范围筛选失败"));
}
private FaceSampleEntity createSample(Long id, Date createAt) {
FaceSampleEntity sample = new FaceSampleEntity();
sample.setId(id);
sample.setCreateAt(createAt);
return sample;
}
}

View File

@@ -0,0 +1,155 @@
package com.ycwl.basic.face.pipeline.stages;
import com.ycwl.basic.face.pipeline.core.FaceMatchingContext;
import com.ycwl.basic.face.pipeline.core.StageResult;
import com.ycwl.basic.face.pipeline.enums.FaceMatchingScene;
import com.ycwl.basic.mapper.FaceSampleMapper;
import com.ycwl.basic.model.pc.faceSample.entity.FaceSampleEntity;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.*;
/**
* LoadMatchedSamplesStage 单元测试
*/
@ExtendWith(MockitoExtension.class)
class LoadMatchedSamplesStageTest {
@Mock
private FaceSampleMapper faceSampleMapper;
@InjectMocks
private LoadMatchedSamplesStage stage;
private FaceMatchingContext context;
@BeforeEach
void setUp() {
context = FaceMatchingContext.forCustomMatching(1L, Arrays.asList(101L, 102L));
}
@Test
void testExecute_Success() {
// Given
List<Long> sampleIds = Arrays.asList(101L, 102L, 103L);
context.setSampleListIds(sampleIds);
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
FaceSampleEntity sample2 = createSample(102L, 1L, new Date());
FaceSampleEntity sample3 = createSample(103L, 2L, new Date());
when(faceSampleMapper.listByIds(sampleIds))
.thenReturn(Arrays.asList(sample1, sample2, sample3));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertEquals(3, context.getFaceSamples().size());
assertEquals(101L, context.getFaceSamples().get(0).getId());
verify(faceSampleMapper, times(1)).listByIds(sampleIds);
}
@Test
void testExecute_EmptySampleListIds_Skip() {
// Given
context.setSampleListIds(Collections.emptyList());
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
verify(faceSampleMapper, never()).listByIds(anyList());
}
@Test
void testExecute_NullSampleListIds_Skip() {
// Given
context.setSampleListIds(null);
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
verify(faceSampleMapper, never()).listByIds(anyList());
}
@Test
void testExecute_NoSamplesFound_Skipped() {
// Given
List<Long> sampleIds = Arrays.asList(101L, 102L);
context.setSampleListIds(sampleIds);
when(faceSampleMapper.listByIds(sampleIds))
.thenReturn(Collections.emptyList());
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSkipped());
assertNull(context.getFaceSamples());
verify(faceSampleMapper, times(1)).listByIds(sampleIds);
}
@Test
void testExecute_MapperThrowsException_Failure() {
// Given
List<Long> sampleIds = Arrays.asList(101L, 102L);
context.setSampleListIds(sampleIds);
when(faceSampleMapper.listByIds(sampleIds))
.thenThrow(new RuntimeException("Database error"));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isFailed());
assertTrue(result.getMessage().contains("加载匹配样本实体失败"));
verify(faceSampleMapper, times(1)).listByIds(sampleIds);
}
@Test
void testExecute_PartialResults() {
// Given: 请求3个样本,只返回2个
List<Long> sampleIds = Arrays.asList(101L, 102L, 103L);
context.setSampleListIds(sampleIds);
FaceSampleEntity sample1 = createSample(101L, 1L, new Date());
FaceSampleEntity sample2 = createSample(102L, 1L, new Date());
when(faceSampleMapper.listByIds(sampleIds))
.thenReturn(Arrays.asList(sample1, sample2));
// When
StageResult<FaceMatchingContext> result = stage.execute(context);
// Then
assertTrue(result.isSuccess());
assertEquals(2, context.getFaceSamples().size());
}
private FaceSampleEntity createSample(Long id, Long deviceId, Date createAt) {
FaceSampleEntity sample = new FaceSampleEntity();
sample.setId(id);
sample.setDeviceId(deviceId);
sample.setCreateAt(createAt);
return sample;
}
}