diff --git a/src/main/java/com/ycwl/basic/puzzle/edge/task/PuzzleEdgeRenderTaskService.java b/src/main/java/com/ycwl/basic/puzzle/edge/task/PuzzleEdgeRenderTaskService.java index 6670928d..94b12c3f 100644 --- a/src/main/java/com/ycwl/basic/puzzle/edge/task/PuzzleEdgeRenderTaskService.java +++ b/src/main/java/com/ycwl/basic/puzzle/edge/task/PuzzleEdgeRenderTaskService.java @@ -31,10 +31,15 @@ import org.springframework.scheduling.annotation.Scheduled; import java.util.Date; import java.util.ArrayList; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; /** @@ -58,6 +63,59 @@ public class PuzzleEdgeRenderTaskService { private static final long TASK_CACHE_EXPIRE_HOURS = 6L; private static final long TASK_CACHE_MAX_SIZE = 20000L; + private static final long WAIT_FUTURE_EXPIRE_MILLIS = TimeUnit.MINUTES.toMillis(10); + + /** + * 任务等待结果 + */ + public static class TaskWaitResult { + private final boolean success; + private final String errorMessage; + private final String imageUrl; + + private TaskWaitResult(boolean success, String errorMessage, String imageUrl) { + this.success = success; + this.errorMessage = errorMessage; + this.imageUrl = imageUrl; + } + + public static TaskWaitResult success(String imageUrl) { + return new TaskWaitResult(true, null, imageUrl); + } + + public static TaskWaitResult fail(String errorMessage) { + return new TaskWaitResult(false, errorMessage, null); + } + + public boolean isSuccess() { + return success; + } + + public String getErrorMessage() { + return errorMessage; + } + + public String getImageUrl() { + return imageUrl; + } + } + + /** + * 等待 future 的包装,包含创建时间用于过期清理 + */ + private static class WaitFutureEntry { + final CompletableFuture future; + final long createTimeMillis; + + WaitFutureEntry(CompletableFuture future) { + this.future = future; + this.createTimeMillis = System.currentTimeMillis(); + } + + boolean isExpired(long nowMillis) { + return nowMillis - createTimeMillis > WAIT_FUTURE_EXPIRE_MILLIS; + } + } /** * 任务内存池(单实例、允许丢失):仅用作 Worker 拉取与状态落地的中间态 @@ -70,6 +128,11 @@ public class PuzzleEdgeRenderTaskService { private final AtomicLong taskIdSequence = new AtomicLong(System.currentTimeMillis()); private final Object taskLock = new Object(); + /** + * 任务等待 future 池:用于伪同步等待任务完成 + */ + private final ConcurrentHashMap waitFutures = new ConcurrentHashMap<>(); + private final PuzzleGenerationRecordMapper recordMapper; private final PuzzleRepository puzzleRepository; private final PrinterService printerService; @@ -155,6 +218,9 @@ public class PuzzleEdgeRenderTaskService { renderDurationMs ); + // 通知等待方任务完成 + completeWaitFuture(taskId, TaskWaitResult.success(resultImageUrl)); + PuzzleTemplateEntity template = puzzleRepository.getTemplateById(task.getTemplateId()); if (template != null && template.getAutoAddPrint() != null && template.getAutoAddPrint() == 1) { try { @@ -191,6 +257,9 @@ public class PuzzleEdgeRenderTaskService { throw new IllegalStateException("任务状态更新失败"); } recordMapper.updateFail(task.getRecordId(), errorMessage); + + // 通知等待方任务失败 + completeWaitFuture(taskId, TaskWaitResult.fail(errorMessage)); } /** @@ -203,6 +272,7 @@ public class PuzzleEdgeRenderTaskService { public void timeoutFailAndRetry() { List retryRecordIds = new ArrayList<>(); Map failRecordMessages = new HashMap<>(); + Map failTaskMessages = new HashMap<>(); // taskId -> errorMessage synchronized (taskLock) { long now = System.currentTimeMillis(); @@ -234,6 +304,8 @@ public class PuzzleEdgeRenderTaskService { if (task.getRecordId() != null) { failRecordMessages.put(task.getRecordId(), errorMessage); } + // 记录需要通知的任务 + failTaskMessages.put(task.getId(), errorMessage); continue; } @@ -258,6 +330,29 @@ public class PuzzleEdgeRenderTaskService { for (Map.Entry entry : failRecordMessages.entrySet()) { recordMapper.updateFail(entry.getKey(), entry.getValue()); } + + // 通知等待方任务最终失败 + for (Map.Entry entry : failTaskMessages.entrySet()) { + completeWaitFuture(entry.getKey(), TaskWaitResult.fail(entry.getValue())); + } + + // 清理过期的等待 future + cleanupExpiredWaitFutures(); + } + + /** + * 清理过期的等待 future,防止内存泄漏 + */ + private void cleanupExpiredWaitFutures() { + long now = System.currentTimeMillis(); + Iterator> iterator = waitFutures.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getValue().isExpired(now)) { + entry.getValue().future.complete(TaskWaitResult.fail("等待超时(内部清理)")); + iterator.remove(); + } + } } private void incrementRecordRetryCount(Long recordId) { @@ -375,6 +470,98 @@ public class PuzzleEdgeRenderTaskService { return taskId; } + /** + * 注册任务等待,返回用于等待的 CompletableFuture + * 调用方应在 createRenderTask 之后立即调用此方法 + * + * @param taskId 任务ID + * @return CompletableFuture,可用于同步等待或异步处理 + */ + public CompletableFuture registerWait(Long taskId) { + if (taskId == null) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.complete(TaskWaitResult.fail("taskId不能为空")); + return failedFuture; + } + + CompletableFuture future = new CompletableFuture<>(); + waitFutures.put(taskId, new WaitFutureEntry(future)); + return future; + } + + /** + * 伪同步等待任务完成 + * 阻塞当前线程直到任务成功、失败或超时 + * + * @param taskId 任务ID + * @param timeoutMs 超时时间(毫秒) + * @return 任务结果,包含成功/失败状态和相关信息 + */ + public TaskWaitResult waitForTask(Long taskId, long timeoutMs) { + if (taskId == null) { + return TaskWaitResult.fail("taskId不能为空"); + } + + // 检查任务是否已完成 + PuzzleEdgeRenderTaskEntity task = taskCache.getIfPresent(taskId); + if (task != null) { + if (task.getStatus() != null && task.getStatus() == STATUS_SUCCESS) { + return buildSuccessResult(task); + } + if (task.getStatus() != null && task.getStatus() == STATUS_FAIL) { + return TaskWaitResult.fail(task.getErrorMessage()); + } + } + + // 获取或创建等待 future + WaitFutureEntry entry = waitFutures.computeIfAbsent(taskId, k -> new WaitFutureEntry(new CompletableFuture<>())); + + try { + return entry.future.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + waitFutures.remove(taskId); + return TaskWaitResult.fail("等待任务超时"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + waitFutures.remove(taskId); + return TaskWaitResult.fail("等待被中断"); + } catch (ExecutionException e) { + waitFutures.remove(taskId); + return TaskWaitResult.fail("等待出错: " + e.getCause().getMessage()); + } + } + + /** + * 创建任务并同步等待结果(便捷方法) + */ + public TaskWaitResult createAndWait(PuzzleGenerationRecordEntity record, + PuzzleTemplateEntity template, + List sortedElements, + Map finalDynamicData, + String outputFormat, + Integer quality, + long timeoutMs) { + Long taskId = createRenderTask(record, template, sortedElements, finalDynamicData, outputFormat, quality); + registerWait(taskId); + return waitForTask(taskId, timeoutMs); + } + + private TaskWaitResult buildSuccessResult(PuzzleEdgeRenderTaskEntity task) { + IStorageAdapter storage = StorageFactory.use(); + String imageUrl = storage.getUrl(task.getOriginalObjectKey()); + return TaskWaitResult.success(imageUrl); + } + + /** + * 完成任务等待(内部调用) + */ + private void completeWaitFuture(Long taskId, TaskWaitResult result) { + WaitFutureEntry entry = waitFutures.remove(taskId); + if (entry != null && entry.future != null) { + entry.future.complete(result); + } + } + private PuzzleEdgeRenderTaskDTO toTaskDTOOrFail(PuzzleEdgeRenderTaskEntity task, Long workerId) { try { PuzzleEdgeRenderTaskDTO dto = new PuzzleEdgeRenderTaskDTO();