You've already forked DataMate
feat(auto-annotation): integrate YOLO auto-labeling and enhance data management (#223)
* feat(auto-annotation): initial setup * chore: remove package-lock.json * chore: 清理本地测试脚本与 Maven 设置 * chore: change package-lock.json
This commit is contained in:
@@ -231,6 +231,29 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PagedDatasetFileResponse'
|
||||
|
||||
/data-management/datasets/{datasetId}/files/directories:
|
||||
post:
|
||||
tags: [ DatasetFile ]
|
||||
operationId: createDirectory
|
||||
summary: 在数据集下创建子目录
|
||||
description: 在指定数据集下的某个前缀路径中创建一个新的子目录
|
||||
parameters:
|
||||
- name: datasetId
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 数据集ID
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateDirectoryRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: 创建成功
|
||||
|
||||
/data-management/datasets/{datasetId}/files/{fileId}:
|
||||
get:
|
||||
tags: [DatasetFile]
|
||||
@@ -635,8 +658,23 @@ components:
|
||||
type: integer
|
||||
format: int64
|
||||
description: 总文件大小(字节)
|
||||
prefix:
|
||||
type: string
|
||||
description: 目标子目录前缀,例如 "images/",为空表示数据集根目录
|
||||
required: [ totalFileNum ]
|
||||
|
||||
CreateDirectoryRequest:
|
||||
type: object
|
||||
description: 创建数据集子目录请求
|
||||
properties:
|
||||
parentPrefix:
|
||||
type: string
|
||||
description: 父级前缀路径,例如 "images/",为空表示数据集根目录
|
||||
directoryName:
|
||||
type: string
|
||||
description: 新建目录名称
|
||||
required: [ directoryName ]
|
||||
|
||||
UploadFileRequest:
|
||||
type: object
|
||||
description: 分片上传请求
|
||||
|
||||
@@ -24,6 +24,7 @@ import com.datamate.datamanagement.infrastructure.persistence.repository.Dataset
|
||||
import com.datamate.datamanagement.interfaces.converter.DatasetConverter;
|
||||
import com.datamate.datamanagement.interfaces.dto.AddFilesRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.CopyFilesRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.CreateDirectoryRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.UploadFileRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.UploadFilesPreRequest;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
@@ -149,13 +150,48 @@ public class DatasetFileApplicationService {
|
||||
}
|
||||
datasetFile.setFileName(path.getFileName().toString());
|
||||
datasetFile.setUploadTime(localDateTime);
|
||||
|
||||
// 目录与普通文件区分处理
|
||||
if (Files.isDirectory(path)) {
|
||||
datasetFile.setId("directory-" + datasetFile.getFileName());
|
||||
} else if (Objects.isNull(datasetFilesMap.get(path.toString()))) {
|
||||
datasetFile.setDirectory(true);
|
||||
|
||||
// 统计目录下文件数量和总大小
|
||||
try {
|
||||
long fileCount;
|
||||
long totalSize;
|
||||
|
||||
try (Stream<Path> walk = Files.walk(path)) {
|
||||
fileCount = walk.filter(Files::isRegularFile).count();
|
||||
}
|
||||
|
||||
try (Stream<Path> walk = Files.walk(path)) {
|
||||
totalSize = walk
|
||||
.filter(Files::isRegularFile)
|
||||
.mapToLong(p -> {
|
||||
try {
|
||||
return Files.size(p);
|
||||
} catch (IOException e) {
|
||||
log.error("get file size error", e);
|
||||
return 0L;
|
||||
}
|
||||
})
|
||||
.sum();
|
||||
}
|
||||
|
||||
datasetFile.setFileCount(fileCount);
|
||||
datasetFile.setFileSize(totalSize);
|
||||
} catch (IOException e) {
|
||||
log.error("stat directory info error", e);
|
||||
}
|
||||
} else {
|
||||
DatasetFile exist = datasetFilesMap.get(path.toString());
|
||||
if (exist == null) {
|
||||
datasetFile.setId("file-" + datasetFile.getFileName());
|
||||
datasetFile.setFileSize(path.toFile().length());
|
||||
} else {
|
||||
datasetFile = datasetFilesMap.get(path.toString());
|
||||
datasetFile = exist;
|
||||
}
|
||||
}
|
||||
return datasetFile;
|
||||
}
|
||||
@@ -291,13 +327,27 @@ public class DatasetFileApplicationService {
|
||||
if (Objects.isNull(datasetRepository.getById(datasetId))) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
}
|
||||
|
||||
// 构建上传路径,如果有 prefix 则追加到路径中
|
||||
String prefix = Optional.ofNullable(chunkUploadRequest.getPrefix()).orElse("").trim();
|
||||
prefix = prefix.replace("\\", "/");
|
||||
while (prefix.startsWith("/")) {
|
||||
prefix = prefix.substring(1);
|
||||
}
|
||||
|
||||
String uploadPath = datasetBasePath + File.separator + datasetId;
|
||||
if (!prefix.isEmpty()) {
|
||||
uploadPath = uploadPath + File.separator + prefix.replace("/", File.separator);
|
||||
}
|
||||
|
||||
ChunkUploadPreRequest request = ChunkUploadPreRequest.builder().build();
|
||||
request.setUploadPath(datasetBasePath + File.separator + datasetId);
|
||||
request.setUploadPath(uploadPath);
|
||||
request.setTotalFileNum(chunkUploadRequest.getTotalFileNum());
|
||||
request.setServiceId(DatasetConstant.SERVICE_ID);
|
||||
DatasetFileUploadCheckInfo checkInfo = new DatasetFileUploadCheckInfo();
|
||||
checkInfo.setDatasetId(datasetId);
|
||||
checkInfo.setHasArchive(chunkUploadRequest.isHasArchive());
|
||||
checkInfo.setPrefix(prefix);
|
||||
try {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
String checkInfoJson = objectMapper.writeValueAsString(checkInfo);
|
||||
@@ -368,6 +418,211 @@ public class DatasetFileApplicationService {
|
||||
datasetRepository.updateById(dataset);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在数据集下创建子目录
|
||||
*/
|
||||
@Transactional
|
||||
public void createDirectory(String datasetId, CreateDirectoryRequest req) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
if (dataset == null) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
}
|
||||
String datasetPath = dataset.getPath();
|
||||
String parentPrefix = Optional.ofNullable(req.getParentPrefix()).orElse("").trim();
|
||||
parentPrefix = parentPrefix.replace("\\", "/");
|
||||
while (parentPrefix.startsWith("/")) {
|
||||
parentPrefix = parentPrefix.substring(1);
|
||||
}
|
||||
|
||||
String directoryName = Optional.ofNullable(req.getDirectoryName()).orElse("").trim();
|
||||
if (directoryName.isEmpty()) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
if (directoryName.contains("..") || directoryName.contains("/") || directoryName.contains("\\")) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
|
||||
Path basePath = Paths.get(datasetPath);
|
||||
Path targetPath = parentPrefix.isEmpty()
|
||||
? basePath.resolve(directoryName)
|
||||
: basePath.resolve(parentPrefix).resolve(directoryName);
|
||||
|
||||
Path normalized = targetPath.normalize();
|
||||
if (!normalized.startsWith(basePath)) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
|
||||
try {
|
||||
Files.createDirectories(normalized);
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to create directory {} for dataset {}", normalized, datasetId, e);
|
||||
throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 下载目录为 ZIP 文件
|
||||
*/
|
||||
@Transactional(readOnly = true)
|
||||
public void downloadDirectory(String datasetId, String prefix, HttpServletResponse response) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
if (dataset == null) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
}
|
||||
|
||||
String datasetPath = dataset.getPath();
|
||||
prefix = Optional.ofNullable(prefix).orElse("").trim();
|
||||
prefix = prefix.replace("\\", "/");
|
||||
while (prefix.startsWith("/")) {
|
||||
prefix = prefix.substring(1);
|
||||
}
|
||||
while (prefix.endsWith("/")) {
|
||||
prefix = prefix.substring(0, prefix.length() - 1);
|
||||
}
|
||||
|
||||
Path basePath = Paths.get(datasetPath);
|
||||
Path targetPath = prefix.isEmpty() ? basePath : basePath.resolve(prefix);
|
||||
Path normalized = targetPath.normalize();
|
||||
|
||||
if (!normalized.startsWith(basePath)) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
|
||||
if (!Files.exists(normalized) || !Files.isDirectory(normalized)) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DIRECTORY_NOT_FOUND);
|
||||
}
|
||||
|
||||
String zipFileName = prefix.isEmpty() ? dataset.getName() : prefix.replace("/", "_");
|
||||
zipFileName = zipFileName + "_" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmmss")) + ".zip";
|
||||
|
||||
try {
|
||||
response.setContentType("application/zip");
|
||||
response.setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + zipFileName + "\"");
|
||||
|
||||
try (ZipArchiveOutputStream zipOut = new ZipArchiveOutputStream(response.getOutputStream())) {
|
||||
zipDirectory(normalized, normalized, zipOut);
|
||||
zipOut.finish();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to download directory {} for dataset {}", normalized, datasetId, e);
|
||||
throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 递归压缩目录
|
||||
*/
|
||||
private void zipDirectory(Path sourceDir, Path basePath, ZipArchiveOutputStream zipOut) throws IOException {
|
||||
try (Stream<Path> paths = Files.walk(sourceDir)) {
|
||||
paths.filter(path -> !Files.isDirectory(path))
|
||||
.forEach(path -> {
|
||||
try {
|
||||
Path relativePath = basePath.relativize(path);
|
||||
ZipArchiveEntry zipEntry = new ZipArchiveEntry(relativePath.toString());
|
||||
zipOut.putArchiveEntry(zipEntry);
|
||||
try (InputStream fis = Files.newInputStream(path)) {
|
||||
IOUtils.copy(fis, zipOut);
|
||||
}
|
||||
zipOut.closeArchiveEntry();
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to add file to zip: {}", path, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除目录及其所有内容
|
||||
*/
|
||||
@Transactional
|
||||
public void deleteDirectory(String datasetId, String prefix) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
if (dataset == null) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
}
|
||||
|
||||
prefix = Optional.ofNullable(prefix).orElse("").trim();
|
||||
prefix = prefix.replace("\\", "/");
|
||||
while (prefix.startsWith("/")) {
|
||||
prefix = prefix.substring(1);
|
||||
}
|
||||
while (prefix.endsWith("/")) {
|
||||
prefix = prefix.substring(0, prefix.length() - 1);
|
||||
}
|
||||
|
||||
if (prefix.isEmpty()) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
|
||||
String datasetPath = dataset.getPath();
|
||||
Path basePath = Paths.get(datasetPath);
|
||||
Path targetPath = basePath.resolve(prefix);
|
||||
Path normalized = targetPath.normalize();
|
||||
|
||||
if (!normalized.startsWith(basePath)) {
|
||||
throw BusinessException.of(CommonErrorCode.PARAM_ERROR);
|
||||
}
|
||||
|
||||
if (!Files.exists(normalized) || !Files.isDirectory(normalized)) {
|
||||
throw BusinessException.of(DataManagementErrorCode.DIRECTORY_NOT_FOUND);
|
||||
}
|
||||
|
||||
// 删除数据库中该目录下的所有文件记录(基于数据集内相对路径判断)
|
||||
String datasetPathNorm = datasetPath.replace("\\", "/");
|
||||
String logicalPrefix = prefix; // 已经去掉首尾斜杠
|
||||
List<DatasetFile> filesToDelete = datasetFileRepository.findAllByDatasetId(datasetId).stream()
|
||||
.filter(file -> {
|
||||
if (file.getFilePath() == null) {
|
||||
return false;
|
||||
}
|
||||
String filePath = file.getFilePath().replace("\\", "/");
|
||||
if (!filePath.startsWith(datasetPathNorm)) {
|
||||
return false;
|
||||
}
|
||||
String relative = filePath.substring(datasetPathNorm.length());
|
||||
while (relative.startsWith("/")) {
|
||||
relative = relative.substring(1);
|
||||
}
|
||||
return relative.equals(logicalPrefix) || relative.startsWith(logicalPrefix + "/");
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
for (DatasetFile file : filesToDelete) {
|
||||
datasetFileRepository.removeById(file.getId());
|
||||
}
|
||||
|
||||
// 删除文件系统中的目录
|
||||
try {
|
||||
deleteDirectoryRecursively(normalized);
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to delete directory {} for dataset {}", normalized, datasetId, e);
|
||||
throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
|
||||
}
|
||||
|
||||
// 更新数据集
|
||||
dataset.setFiles(filesToDelete);
|
||||
for (DatasetFile file : filesToDelete) {
|
||||
dataset.removeFile(file);
|
||||
}
|
||||
datasetRepository.updateById(dataset);
|
||||
}
|
||||
|
||||
/**
|
||||
* 递归删除目录
|
||||
*/
|
||||
private void deleteDirectoryRecursively(Path directory) throws IOException {
|
||||
try (Stream<Path> paths = Files.walk(directory)) {
|
||||
paths.sorted(Comparator.reverseOrder())
|
||||
.forEach(path -> {
|
||||
try {
|
||||
Files.delete(path);
|
||||
} catch (IOException e) {
|
||||
log.error("Failed to delete: {}", path, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 为数据集文件设置文件id
|
||||
*
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.datamate.datamanagement.domain.model.dataset;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
@@ -38,6 +39,14 @@ public class DatasetFile {
|
||||
private LocalDateTime createdAt;
|
||||
private LocalDateTime updatedAt;
|
||||
|
||||
/** 标记是否为目录(非持久化字段) */
|
||||
@TableField(exist = false)
|
||||
private Boolean directory;
|
||||
|
||||
/** 目录包含的文件数量(非持久化字段) */
|
||||
@TableField(exist = false)
|
||||
private Long fileCount;
|
||||
|
||||
/**
|
||||
* 解析标签
|
||||
*
|
||||
|
||||
@@ -18,4 +18,7 @@ public class DatasetFileUploadCheckInfo {
|
||||
|
||||
/** 是否为压缩包上传 */
|
||||
private boolean hasArchive;
|
||||
|
||||
/** 目标子目录前缀,例如 "images/",为空表示数据集根目录 */
|
||||
private String prefix;
|
||||
}
|
||||
|
||||
@@ -34,9 +34,13 @@ public enum DataManagementErrorCode implements ErrorCode {
|
||||
*/
|
||||
DATASET_TAG_ALREADY_EXISTS("data_management.0005", "数据集标签已存在"),
|
||||
/**
|
||||
* 数据集标签已存在
|
||||
* 数据集文件已存在
|
||||
*/
|
||||
DATASET_FILE_ALREADY_EXISTS("data_management.0006", "数据集文件已存在");
|
||||
DATASET_FILE_ALREADY_EXISTS("data_management.0006", "数据集文件已存在"),
|
||||
/**
|
||||
* 目录不存在
|
||||
*/
|
||||
DIRECTORY_NOT_FOUND("data_management.0007", "目录不存在");
|
||||
|
||||
private final String code;
|
||||
private final String message;
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.datamate.datamanagement.interfaces.dto;
|
||||
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
/**
|
||||
* 创建数据集子目录请求
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
public class CreateDirectoryRequest {
|
||||
|
||||
/** 父级前缀路径,例如 "images/",为空表示数据集根目录 */
|
||||
private String parentPrefix;
|
||||
|
||||
/** 新建目录名称 */
|
||||
@NotBlank
|
||||
private String directoryName;
|
||||
}
|
||||
@@ -33,4 +33,8 @@ public class DatasetFileResponse {
|
||||
private LocalDateTime lastAccessTime;
|
||||
/** 上传者 */
|
||||
private String uploadedBy;
|
||||
/** 是否为目录 */
|
||||
private Boolean directory;
|
||||
/** 目录文件数量 */
|
||||
private Long fileCount;
|
||||
}
|
||||
|
||||
@@ -19,4 +19,7 @@ public class UploadFilesPreRequest {
|
||||
|
||||
/** 总文件大小 */
|
||||
private long totalSize;
|
||||
|
||||
/** 目标子目录前缀,例如 "images/",为空表示数据集根目录 */
|
||||
private String prefix;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import com.datamate.datamanagement.domain.model.dataset.DatasetFile;
|
||||
import com.datamate.datamanagement.interfaces.converter.DatasetConverter;
|
||||
import com.datamate.datamanagement.interfaces.dto.AddFilesRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.CopyFilesRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.CreateDirectoryRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.DatasetFileResponse;
|
||||
import com.datamate.datamanagement.interfaces.dto.UploadFileRequest;
|
||||
import com.datamate.datamanagement.interfaces.dto.UploadFilesPreRequest;
|
||||
@@ -162,4 +163,35 @@ public class DatasetFileController {
|
||||
List<DatasetFile> datasetFiles = datasetFileApplicationService.addFilesToDataset(datasetId, req);
|
||||
return DatasetConverter.INSTANCE.convertToResponseList(datasetFiles);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在数据集下创建子目录
|
||||
*/
|
||||
@PostMapping("/directories")
|
||||
public ResponseEntity<Void> createDirectory(@PathVariable("datasetId") String datasetId,
|
||||
@RequestBody @Valid CreateDirectoryRequest req) {
|
||||
datasetFileApplicationService.createDirectory(datasetId, req);
|
||||
return ResponseEntity.ok().build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 下载目录(压缩为 ZIP)
|
||||
*/
|
||||
@IgnoreResponseWrap
|
||||
@GetMapping(value = "/directories/download", produces = "application/zip")
|
||||
public void downloadDirectory(@PathVariable("datasetId") String datasetId,
|
||||
@RequestParam(value = "prefix", required = false, defaultValue = "") String prefix,
|
||||
HttpServletResponse response) {
|
||||
datasetFileApplicationService.downloadDirectory(datasetId, prefix, response);
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除目录及其所有内容
|
||||
*/
|
||||
@DeleteMapping("/directories")
|
||||
public ResponseEntity<Void> deleteDirectory(@PathVariable("datasetId") String datasetId,
|
||||
@RequestParam(value = "prefix", required = false, defaultValue = "") String prefix) {
|
||||
datasetFileApplicationService.deleteDirectory(datasetId, prefix);
|
||||
return ResponseEntity.ok().build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import lombok.Getter;
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum CommonErrorCode implements ErrorCode{
|
||||
PARAM_ERROR("common.0001", "参数错误"),
|
||||
PRE_UPLOAD_REQUEST_NOT_EXIST("common.0101", "预上传请求不存在");
|
||||
private final String code;
|
||||
private final String message;
|
||||
|
||||
@@ -199,15 +199,11 @@ function CardView<T extends BaseCardDataType>(props: CardViewProps<T>) {
|
||||
? ""
|
||||
: "bg-gradient-to-br from-sky-300 to-blue-500 text-white"
|
||||
}`}
|
||||
style={{
|
||||
...(item?.iconColor
|
||||
style={
|
||||
item?.iconColor
|
||||
? { backgroundColor: item.iconColor }
|
||||
: {}),
|
||||
backgroundImage:
|
||||
"linear-gradient(180deg, rgba(255,255,255,0.35), rgba(255,255,255,0.05))",
|
||||
boxShadow:
|
||||
"inset 0 0 0 1px rgba(255,255,255,0.25)",
|
||||
}}
|
||||
: {}
|
||||
}
|
||||
>
|
||||
<div className="w-[2.1rem] h-[2.1rem] text-gray-50">{item?.icon}</div>
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useCallback, useEffect } from "react";
|
||||
import { Button, Input, Table } from "antd";
|
||||
import { Button, Input, Table, message } from "antd";
|
||||
import { RightOutlined } from "@ant-design/icons";
|
||||
import { mapDataset } from "@/pages/DataManagement/dataset.const";
|
||||
import {
|
||||
@@ -20,6 +20,7 @@ interface DatasetFileTransferProps
|
||||
selectedFilesMap: { [key: string]: DatasetFile };
|
||||
onSelectedFilesChange: (filesMap: { [key: string]: DatasetFile }) => void;
|
||||
onDatasetSelect?: (dataset: Dataset | null) => void;
|
||||
datasetTypeFilter?: DatasetType;
|
||||
}
|
||||
|
||||
const fileCols = [
|
||||
@@ -50,6 +51,7 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
selectedFilesMap,
|
||||
onSelectedFilesChange,
|
||||
onDatasetSelect,
|
||||
datasetTypeFilter = DatasetType.TEXT,
|
||||
...props
|
||||
}) => {
|
||||
const [datasets, setDatasets] = React.useState<Dataset[]>([]);
|
||||
@@ -75,6 +77,7 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
const [datasetSelections, setDatasetSelections] = React.useState<Dataset[]>(
|
||||
[]
|
||||
);
|
||||
const [selectingAll, setSelectingAll] = React.useState<boolean>(false);
|
||||
|
||||
const fetchDatasets = async () => {
|
||||
const { data } = await queryDatasetsUsingGet({
|
||||
@@ -82,7 +85,7 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
page: datasetPagination.current,
|
||||
size: datasetPagination.pageSize,
|
||||
keyword: datasetSearch,
|
||||
type: DatasetType.TEXT,
|
||||
type: datasetTypeFilter,
|
||||
});
|
||||
setDatasets(data.content.map(mapDataset) || []);
|
||||
setDatasetPagination((prev) => ({
|
||||
@@ -116,7 +119,8 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
setFiles(
|
||||
(data.content || []).map((item: DatasetFile) => ({
|
||||
...item,
|
||||
key: item.id,
|
||||
id: item.id,
|
||||
key: String(item.id), // rowKey 使用字符串,确保与 selectedRowKeys 类型一致
|
||||
datasetName: selectedDataset.name,
|
||||
}))
|
||||
);
|
||||
@@ -134,7 +138,8 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
// 当数据集变化时,重置文件分页并拉取第一页文件,避免额外的循环请求
|
||||
if (selectedDataset) {
|
||||
setFilesPagination({ current: 1, pageSize: 10, total: 0 });
|
||||
fetchFiles({ page: 1, pageSize: 10 }).catch(() => {});
|
||||
// 后端 page 参数为 0-based,这里传 0 获取第一页
|
||||
fetchFiles({ page: 0, pageSize: 10 }).catch(() => {});
|
||||
} else {
|
||||
setFiles([]);
|
||||
setFilesPagination({ current: 1, pageSize: 10, total: 0 });
|
||||
@@ -147,6 +152,73 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
onDatasetSelect?.(selectedDataset);
|
||||
}, [selectedDataset, onDatasetSelect]);
|
||||
|
||||
const handleSelectAllInDataset = useCallback(async () => {
|
||||
if (!selectedDataset) {
|
||||
message.warning("请先选择一个数据集");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setSelectingAll(true);
|
||||
|
||||
const pageSize = 1000; // 分批拉取,避免后端单页限制
|
||||
let page = 0; // 后端 page 参数为 0-based,从 0 开始
|
||||
let total = 0;
|
||||
const allFiles: DatasetFile[] = [];
|
||||
|
||||
while (true) {
|
||||
const { data } = await queryDatasetFilesUsingGet(selectedDataset.id, {
|
||||
page,
|
||||
size: pageSize,
|
||||
});
|
||||
|
||||
const content: DatasetFile[] = (data.content || []).map(
|
||||
(item: DatasetFile) => ({
|
||||
...item,
|
||||
key: item.id,
|
||||
datasetName: selectedDataset.name,
|
||||
}),
|
||||
);
|
||||
|
||||
if (!content.length) {
|
||||
break;
|
||||
}
|
||||
|
||||
allFiles.push(...content);
|
||||
// 优先用后端的 totalElements,否则用当前累积数
|
||||
total = typeof data.totalElements === "number" ? data.totalElements : allFiles.length;
|
||||
|
||||
// 如果这一页数量小于 pageSize,说明已经拿完;否则继续下一页
|
||||
if (content.length < pageSize) {
|
||||
break;
|
||||
}
|
||||
|
||||
page += 1;
|
||||
}
|
||||
|
||||
const newMap: { [key: string]: DatasetFile } = { ...selectedFilesMap };
|
||||
allFiles.forEach((file) => {
|
||||
if (file && file.id != null) {
|
||||
newMap[String(file.id)] = file;
|
||||
}
|
||||
});
|
||||
|
||||
onSelectedFilesChange(newMap);
|
||||
|
||||
const count = total || allFiles.length;
|
||||
if (count > 0) {
|
||||
message.success(`已选中当前数据集的全部 ${count} 个文件`);
|
||||
} else {
|
||||
message.info("当前数据集下没有可选文件");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to select all files in dataset", error);
|
||||
message.error("全选整个数据集失败,请稍后重试");
|
||||
} finally {
|
||||
setSelectingAll(false);
|
||||
}
|
||||
}, [selectedDataset, selectedFilesMap, onSelectedFilesChange]);
|
||||
|
||||
const toggleSelectFile = (record: DatasetFile) => {
|
||||
if (!selectedFilesMap[record.id]) {
|
||||
onSelectedFilesChange({
|
||||
@@ -245,7 +317,18 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
</div>
|
||||
<RightOutlined />
|
||||
<div className="border-card flex flex-col col-span-12">
|
||||
<div className="border-bottom p-2 font-bold">选择文件</div>
|
||||
<div className="border-bottom p-2 font-bold flex justify-between items-center">
|
||||
<span>选择文件</span>
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
onClick={handleSelectAllInDataset}
|
||||
disabled={!selectedDataset}
|
||||
loading={selectingAll}
|
||||
>
|
||||
全选当前数据集
|
||||
</Button>
|
||||
</div>
|
||||
<div className="p-2">
|
||||
<Input
|
||||
placeholder="搜索文件名称..."
|
||||
@@ -255,7 +338,7 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
</div>
|
||||
<Table
|
||||
scroll={{ y: 400 }}
|
||||
rowKey="id"
|
||||
rowKey={(record) => String(record.id)}
|
||||
size="small"
|
||||
dataSource={files}
|
||||
columns={fileCols.slice(1, fileCols.length)}
|
||||
@@ -268,7 +351,8 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
current: page,
|
||||
pageSize: nextPageSize,
|
||||
}));
|
||||
fetchFiles({ page, pageSize: nextPageSize }).catch(() => {});
|
||||
// 前端分页是 1-based,后端是 0-based,所以这里传 page - 1
|
||||
fetchFiles({ page: page - 1, pageSize: nextPageSize }).catch(() => {});
|
||||
},
|
||||
}}
|
||||
onRow={(record: DatasetFile) => ({
|
||||
@@ -277,31 +361,22 @@ const DatasetFileTransfer: React.FC<DatasetFileTransferProps> = ({
|
||||
rowSelection={{
|
||||
type: "checkbox",
|
||||
selectedRowKeys: Object.keys(selectedFilesMap),
|
||||
preserveSelectedRowKeys: true,
|
||||
|
||||
// 单选
|
||||
onSelect: (record: DatasetFile) => {
|
||||
toggleSelectFile(record);
|
||||
},
|
||||
|
||||
// 全选
|
||||
// 全选 - 改为全选整个数据集而不是当前页
|
||||
onSelectAll: (selected, selectedRows: DatasetFile[]) => {
|
||||
if (selected) {
|
||||
// ✔ 全选 -> 将 files 列表全部加入 selectedFilesMap
|
||||
const newMap: Record<string, DatasetFile> = { ...selectedFilesMap };
|
||||
selectedRows.forEach((f) => {
|
||||
newMap[f.id] = f;
|
||||
});
|
||||
onSelectedFilesChange(newMap);
|
||||
// 点击表头“全选”时,改为一键全选当前数据集的全部文件
|
||||
// 而不是只选中当前页
|
||||
handleSelectAllInDataset();
|
||||
} else {
|
||||
// ✘ 取消全选 -> 清空 map
|
||||
const newMap = { ...selectedFilesMap };
|
||||
Object.keys(newMap).forEach((id) => {
|
||||
if (files.some((f) => String(f.id) === id)) {
|
||||
// 仅移除当前页对应文件
|
||||
delete newMap[id];
|
||||
}
|
||||
});
|
||||
onSelectedFilesChange(newMap);
|
||||
// 取消表头“全选”时,清空当前已选文件
|
||||
onSelectedFilesChange({});
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ export function useFileSliceUpload(
|
||||
size: 0,
|
||||
updateEvent: detail.updateEvent,
|
||||
hasArchive: detail.hasArchive,
|
||||
prefix: detail.prefix,
|
||||
};
|
||||
taskListRef.current = [task, ...taskListRef.current];
|
||||
|
||||
@@ -55,7 +56,14 @@ export function useFileSliceUpload(
|
||||
if (task.isCancel && task.cancelFn) {
|
||||
task.cancelFn();
|
||||
}
|
||||
if (task.updateEvent) window.dispatchEvent(new Event(task.updateEvent));
|
||||
if (task.updateEvent) {
|
||||
// 携带前缀信息,便于刷新后仍停留在当前目录
|
||||
window.dispatchEvent(
|
||||
new CustomEvent(task.updateEvent, {
|
||||
detail: { prefix: (task as any).prefix },
|
||||
})
|
||||
);
|
||||
}
|
||||
if (showTaskCenter) {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("show:task-popover", { detail: { show: false } })
|
||||
@@ -109,12 +117,15 @@ export function useFileSliceUpload(
|
||||
}
|
||||
|
||||
async function uploadFile({ task, files, totalSize }) {
|
||||
console.log('[useSliceUpload] Calling preUpload with prefix:', task.prefix);
|
||||
const { data: reqId } = await preUpload(task.key, {
|
||||
totalFileNum: files.length,
|
||||
totalSize,
|
||||
datasetId: task.key,
|
||||
hasArchive: task.hasArchive,
|
||||
prefix: task.prefix,
|
||||
});
|
||||
console.log('[useSliceUpload] PreUpload response reqId:', reqId);
|
||||
|
||||
const newTask: TaskItem = {
|
||||
...task,
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { Card, Button, Table, message, Modal, Tag, Progress, Space, Tooltip } from "antd";
|
||||
import {
|
||||
PlusOutlined,
|
||||
DeleteOutlined,
|
||||
DownloadOutlined,
|
||||
ReloadOutlined,
|
||||
EyeOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import type { ColumnType } from "antd/es/table";
|
||||
import type { AutoAnnotationTask, AutoAnnotationStatus } from "../annotation.model";
|
||||
import {
|
||||
queryAutoAnnotationTasksUsingGet,
|
||||
deleteAutoAnnotationTaskByIdUsingDelete,
|
||||
downloadAutoAnnotationResultUsingGet,
|
||||
} from "../annotation.api";
|
||||
import CreateAutoAnnotationDialog from "./components/CreateAutoAnnotationDialog";
|
||||
|
||||
const STATUS_COLORS: Record<AutoAnnotationStatus, string> = {
|
||||
pending: "default",
|
||||
running: "processing",
|
||||
completed: "success",
|
||||
failed: "error",
|
||||
cancelled: "default",
|
||||
};
|
||||
|
||||
const STATUS_LABELS: Record<AutoAnnotationStatus, string> = {
|
||||
pending: "等待中",
|
||||
running: "处理中",
|
||||
completed: "已完成",
|
||||
failed: "失败",
|
||||
cancelled: "已取消",
|
||||
};
|
||||
|
||||
const MODEL_SIZE_LABELS: Record<string, string> = {
|
||||
n: "YOLOv8n (最快)",
|
||||
s: "YOLOv8s",
|
||||
m: "YOLOv8m",
|
||||
l: "YOLOv8l (推荐)",
|
||||
x: "YOLOv8x (最精确)",
|
||||
};
|
||||
|
||||
export default function AutoAnnotation() {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [tasks, setTasks] = useState<AutoAnnotationTask[]>([]);
|
||||
const [showCreateDialog, setShowCreateDialog] = useState(false);
|
||||
const [selectedRowKeys, setSelectedRowKeys] = useState<string[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTasks();
|
||||
const interval = setInterval(() => {
|
||||
fetchTasks(true);
|
||||
}, 3000);
|
||||
return () => clearInterval(interval);
|
||||
}, []);
|
||||
|
||||
const fetchTasks = async (silent = false) => {
|
||||
if (!silent) setLoading(true);
|
||||
try {
|
||||
const response = await queryAutoAnnotationTasksUsingGet();
|
||||
setTasks(response.data || response || []);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch auto annotation tasks:", error);
|
||||
if (!silent) message.error("获取任务列表失败");
|
||||
} finally {
|
||||
if (!silent) setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = (task: AutoAnnotationTask) => {
|
||||
Modal.confirm({
|
||||
title: `确认删除自动标注任务「${task.name}」吗?`,
|
||||
content: "删除任务后,已生成的标注结果不会被删除。",
|
||||
okText: "删除",
|
||||
okType: "danger",
|
||||
cancelText: "取消",
|
||||
onOk: async () => {
|
||||
try {
|
||||
await deleteAutoAnnotationTaskByIdUsingDelete(task.id);
|
||||
message.success("任务删除成功");
|
||||
fetchTasks();
|
||||
setSelectedRowKeys((keys) => keys.filter((k) => k !== task.id));
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
message.error("删除失败,请稍后重试");
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const handleDownload = async (task: AutoAnnotationTask) => {
|
||||
try {
|
||||
message.loading("正在准备下载...", 0);
|
||||
await downloadAutoAnnotationResultUsingGet(task.id);
|
||||
message.destroy();
|
||||
message.success("下载已开始");
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
message.destroy();
|
||||
message.error("下载失败");
|
||||
}
|
||||
};
|
||||
|
||||
const handleViewResult = (task: AutoAnnotationTask) => {
|
||||
if (task.outputPath) {
|
||||
Modal.info({
|
||||
title: "标注结果路径",
|
||||
content: (
|
||||
<div>
|
||||
<p>输出路径:{task.outputPath}</p>
|
||||
<p>检测对象数:{task.detectedObjects}</p>
|
||||
<p>
|
||||
处理图片数:{task.processedImages} / {task.totalImages}
|
||||
</p>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const columns: ColumnType<AutoAnnotationTask>[] = [
|
||||
{ title: "任务名称", dataIndex: "name", key: "name", width: 200 },
|
||||
{
|
||||
title: "数据集",
|
||||
dataIndex: "datasetName",
|
||||
key: "datasetName",
|
||||
width: 220,
|
||||
render: (_: any, record: AutoAnnotationTask) => {
|
||||
const list =
|
||||
record.sourceDatasets && record.sourceDatasets.length > 0
|
||||
? record.sourceDatasets
|
||||
: record.datasetName
|
||||
? [record.datasetName]
|
||||
: [];
|
||||
|
||||
if (list.length === 0) return "-";
|
||||
|
||||
const text = list.join(",");
|
||||
return (
|
||||
<Tooltip title={text}>
|
||||
<span>{text}</span>
|
||||
</Tooltip>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "模型",
|
||||
dataIndex: ["config", "modelSize"],
|
||||
key: "modelSize",
|
||||
width: 120,
|
||||
render: (size: string) => MODEL_SIZE_LABELS[size] || size,
|
||||
},
|
||||
{
|
||||
title: "置信度",
|
||||
dataIndex: ["config", "confThreshold"],
|
||||
key: "confThreshold",
|
||||
width: 100,
|
||||
render: (threshold: number) => `${(threshold * 100).toFixed(0)}%`,
|
||||
},
|
||||
{
|
||||
title: "目标类别",
|
||||
dataIndex: ["config", "targetClasses"],
|
||||
key: "targetClasses",
|
||||
width: 120,
|
||||
render: (classes: number[]) => (
|
||||
<Tooltip
|
||||
title={classes.length > 0 ? classes.join(", ") : "全部类别"}
|
||||
>
|
||||
<span>
|
||||
{classes.length > 0
|
||||
? `${classes.length} 个类别`
|
||||
: "全部类别"}
|
||||
</span>
|
||||
</Tooltip>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: "状态",
|
||||
dataIndex: "status",
|
||||
key: "status",
|
||||
width: 100,
|
||||
render: (status: AutoAnnotationStatus) => (
|
||||
<Tag color={STATUS_COLORS[status]}>{STATUS_LABELS[status]}</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: "进度",
|
||||
dataIndex: "progress",
|
||||
key: "progress",
|
||||
width: 150,
|
||||
render: (progress: number, record: AutoAnnotationTask) => (
|
||||
<div>
|
||||
<Progress percent={progress} size="small" />
|
||||
<div style={{ fontSize: "12px", color: "#999" }}>
|
||||
{record.processedImages} / {record.totalImages}
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: "检测对象数",
|
||||
dataIndex: "detectedObjects",
|
||||
key: "detectedObjects",
|
||||
width: 100,
|
||||
render: (count: number) => count.toLocaleString(),
|
||||
},
|
||||
{
|
||||
title: "创建时间",
|
||||
dataIndex: "createdAt",
|
||||
key: "createdAt",
|
||||
width: 150,
|
||||
render: (time: string) => new Date(time).toLocaleString(),
|
||||
},
|
||||
{
|
||||
title: "操作",
|
||||
key: "actions",
|
||||
width: 180,
|
||||
fixed: "right",
|
||||
render: (_: any, record: AutoAnnotationTask) => (
|
||||
<Space size="small">
|
||||
{record.status === "completed" && (
|
||||
<>
|
||||
<Tooltip title="查看结果">
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
icon={<EyeOutlined />}
|
||||
onClick={() => handleViewResult(record)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip title="下载结果">
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
icon={<DownloadOutlined />}
|
||||
onClick={() => handleDownload(record)}
|
||||
/>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
<Tooltip title="删除">
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
danger
|
||||
icon={<DeleteOutlined />}
|
||||
onClick={() => handleDelete(record)}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
return (
  <div>
    {/* Task list card; header actions open the create dialog or re-fetch. */}
    <Card
      title="自动标注任务"
      extra={
        <Space>
          <Button
            type="primary"
            icon={<PlusOutlined />}
            onClick={() => setShowCreateDialog(true)}
          >
            创建任务
          </Button>
          <Button
            icon={<ReloadOutlined />}
            loading={loading}
            onClick={() => fetchTasks()}
          >
            刷新
          </Button>
        </Space>
      }
    >
      {/* Row selection is kept in component state for batch operations. */}
      <Table
        rowKey="id"
        loading={loading}
        columns={columns}
        dataSource={tasks}
        rowSelection={{
          selectedRowKeys,
          onChange: (keys) => setSelectedRowKeys(keys as string[]),
        }}
        pagination={{ pageSize: 10 }}
        scroll={{ x: 1000 }}
      />
    </Card>

    {/* Creation dialog; on success close it and reload the task list. */}
    <CreateAutoAnnotationDialog
      visible={showCreateDialog}
      onCancel={() => setShowCreateDialog(false)}
      onSuccess={() => {
        setShowCreateDialog(false);
        fetchTasks();
      }}
    />
  </div>
);
}
|
||||
@@ -0,0 +1,286 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { Modal, Form, Input, Select, Slider, message, Checkbox } from "antd";
|
||||
import { createAutoAnnotationTaskUsingPost } from "../../annotation.api";
|
||||
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
||||
import { mapDataset } from "@/pages/DataManagement/dataset.const";
|
||||
import { DatasetType, type DatasetFile, type Dataset } from "@/pages/DataManagement/dataset.model";
|
||||
import DatasetFileTransfer from "@/components/business/DatasetFileTransfer";
|
||||
|
||||
const { Option } = Select;
|
||||
|
||||
interface CreateAutoAnnotationDialogProps {
|
||||
visible: boolean;
|
||||
onCancel: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
// The 80 COCO object-detection classes in canonical order. `id` is the
// YOLO/COCO class index sent to the backend, `name` is the English class
// name, and `label` is the Chinese display text shown in the class picker.
// NOTE(review): ids 30 (skis) and 31 (snowboard) share the label "滑雪板" —
// confirm whether 31 should be "单板滑雪" instead.
// NOTE(review): this list is duplicated in CreateAnnotationTaskDialog —
// consider extracting it to a shared constants module.
const COCO_CLASSES = [
  { id: 0, name: "person", label: "人" },
  { id: 1, name: "bicycle", label: "自行车" },
  { id: 2, name: "car", label: "汽车" },
  { id: 3, name: "motorcycle", label: "摩托车" },
  { id: 4, name: "airplane", label: "飞机" },
  { id: 5, name: "bus", label: "公交车" },
  { id: 6, name: "train", label: "火车" },
  { id: 7, name: "truck", label: "卡车" },
  { id: 8, name: "boat", label: "船" },
  { id: 9, name: "traffic light", label: "红绿灯" },
  { id: 10, name: "fire hydrant", label: "消防栓" },
  { id: 11, name: "stop sign", label: "停止标志" },
  { id: 12, name: "parking meter", label: "停车计时器" },
  { id: 13, name: "bench", label: "长椅" },
  { id: 14, name: "bird", label: "鸟" },
  { id: 15, name: "cat", label: "猫" },
  { id: 16, name: "dog", label: "狗" },
  { id: 17, name: "horse", label: "马" },
  { id: 18, name: "sheep", label: "羊" },
  { id: 19, name: "cow", label: "牛" },
  { id: 20, name: "elephant", label: "大象" },
  { id: 21, name: "bear", label: "熊" },
  { id: 22, name: "zebra", label: "斑马" },
  { id: 23, name: "giraffe", label: "长颈鹿" },
  { id: 24, name: "backpack", label: "背包" },
  { id: 25, name: "umbrella", label: "雨伞" },
  { id: 26, name: "handbag", label: "手提包" },
  { id: 27, name: "tie", label: "领带" },
  { id: 28, name: "suitcase", label: "行李箱" },
  { id: 29, name: "frisbee", label: "飞盘" },
  { id: 30, name: "skis", label: "滑雪板" },
  { id: 31, name: "snowboard", label: "滑雪板" },
  { id: 32, name: "sports ball", label: "球类" },
  { id: 33, name: "kite", label: "风筝" },
  { id: 34, name: "baseball bat", label: "棒球棒" },
  { id: 35, name: "baseball glove", label: "棒球手套" },
  { id: 36, name: "skateboard", label: "滑板" },
  { id: 37, name: "surfboard", label: "冲浪板" },
  { id: 38, name: "tennis racket", label: "网球拍" },
  { id: 39, name: "bottle", label: "瓶子" },
  { id: 40, name: "wine glass", label: "酒杯" },
  { id: 41, name: "cup", label: "杯子" },
  { id: 42, name: "fork", label: "叉子" },
  { id: 43, name: "knife", label: "刀" },
  { id: 44, name: "spoon", label: "勺子" },
  { id: 45, name: "bowl", label: "碗" },
  { id: 46, name: "banana", label: "香蕉" },
  { id: 47, name: "apple", label: "苹果" },
  { id: 48, name: "sandwich", label: "三明治" },
  { id: 49, name: "orange", label: "橙子" },
  { id: 50, name: "broccoli", label: "西兰花" },
  { id: 51, name: "carrot", label: "胡萝卜" },
  { id: 52, name: "hot dog", label: "热狗" },
  { id: 53, name: "pizza", label: "披萨" },
  { id: 54, name: "donut", label: "甜甜圈" },
  { id: 55, name: "cake", label: "蛋糕" },
  { id: 56, name: "chair", label: "椅子" },
  { id: 57, name: "couch", label: "沙发" },
  { id: 58, name: "potted plant", label: "盆栽" },
  { id: 59, name: "bed", label: "床" },
  { id: 60, name: "dining table", label: "餐桌" },
  { id: 61, name: "toilet", label: "马桶" },
  { id: 62, name: "tv", label: "电视" },
  { id: 63, name: "laptop", label: "笔记本电脑" },
  { id: 64, name: "mouse", label: "鼠标" },
  { id: 65, name: "remote", label: "遥控器" },
  { id: 66, name: "keyboard", label: "键盘" },
  { id: 67, name: "cell phone", label: "手机" },
  { id: 68, name: "microwave", label: "微波炉" },
  { id: 69, name: "oven", label: "烤箱" },
  { id: 70, name: "toaster", label: "烤面包机" },
  { id: 71, name: "sink", label: "水槽" },
  { id: 72, name: "refrigerator", label: "冰箱" },
  { id: 73, name: "book", label: "书" },
  { id: 74, name: "clock", label: "钟表" },
  { id: 75, name: "vase", label: "花瓶" },
  { id: 76, name: "scissors", label: "剪刀" },
  { id: 77, name: "teddy bear", label: "玩具熊" },
  { id: 78, name: "hair drier", label: "吹风机" },
  { id: 79, name: "toothbrush", label: "牙刷" },
];
|
||||
|
||||
export default function CreateAutoAnnotationDialog({
|
||||
visible,
|
||||
onCancel,
|
||||
onSuccess,
|
||||
}: CreateAutoAnnotationDialogProps) {
|
||||
const [form] = Form.useForm();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [datasets, setDatasets] = useState<any[]>([]);
|
||||
const [selectAllClasses, setSelectAllClasses] = useState(true);
|
||||
const [selectedFilesMap, setSelectedFilesMap] = useState<Record<string, DatasetFile>>({});
|
||||
const [selectedDataset, setSelectedDataset] = useState<Dataset | null>(null);
|
||||
const [imageFileCount, setImageFileCount] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
fetchDatasets();
|
||||
form.resetFields();
|
||||
form.setFieldsValue({
|
||||
modelSize: "l",
|
||||
confThreshold: 0.7,
|
||||
targetClasses: [],
|
||||
});
|
||||
}
|
||||
}, [visible, form]);
|
||||
|
||||
const fetchDatasets = async () => {
|
||||
try {
|
||||
const { data } = await queryDatasetsUsingGet({
|
||||
page: 0,
|
||||
pageSize: 1000,
|
||||
});
|
||||
const imageDatasets = (data.content || [])
|
||||
.map(mapDataset)
|
||||
.filter((ds: any) => ds.datasetType === DatasetType.IMAGE);
|
||||
setDatasets(imageDatasets);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch datasets:", error);
|
||||
message.error("获取数据集列表失败");
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const imageExtensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"];
|
||||
const count = Object.values(selectedFilesMap).filter((file) => {
|
||||
const ext = file.fileName?.toLowerCase().match(/\.[^.]+$/)?.[0] || "";
|
||||
return imageExtensions.includes(ext);
|
||||
}).length;
|
||||
setImageFileCount(count);
|
||||
}, [selectedFilesMap]);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
const values = await form.validateFields();
|
||||
|
||||
if (imageFileCount === 0) {
|
||||
message.error("请至少选择一个图像文件");
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
|
||||
const imageExtensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"];
|
||||
const imageFileIds = Object.values(selectedFilesMap)
|
||||
.filter((file) => {
|
||||
const ext = file.fileName?.toLowerCase().match(/\.[^.]+$/)?.[0] || "";
|
||||
return imageExtensions.includes(ext);
|
||||
})
|
||||
.map((file) => file.id);
|
||||
|
||||
const payload = {
|
||||
name: values.name,
|
||||
datasetId: values.datasetId,
|
||||
fileIds: imageFileIds,
|
||||
config: {
|
||||
modelSize: values.modelSize,
|
||||
confThreshold: values.confThreshold,
|
||||
targetClasses: selectAllClasses ? [] : values.targetClasses || [],
|
||||
outputDatasetName: values.outputDatasetName || undefined,
|
||||
},
|
||||
};
|
||||
|
||||
await createAutoAnnotationTaskUsingPost(payload);
|
||||
message.success("自动标注任务创建成功");
|
||||
onSuccess();
|
||||
} catch (error: any) {
|
||||
if (error.errorFields) return;
|
||||
console.error("Failed to create auto annotation task:", error);
|
||||
message.error(error.message || "创建任务失败");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleClassSelectionChange = (checked: boolean) => {
|
||||
setSelectAllClasses(checked);
|
||||
if (checked) {
|
||||
form.setFieldsValue({ targetClasses: [] });
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title="创建自动标注任务"
|
||||
open={visible}
|
||||
onCancel={onCancel}
|
||||
onOk={handleSubmit}
|
||||
confirmLoading={loading}
|
||||
width={600}
|
||||
destroyOnClose
|
||||
>
|
||||
<Form form={form} layout="vertical" preserve={false}>
|
||||
<Form.Item
|
||||
name="name"
|
||||
label="任务名称"
|
||||
rules={[
|
||||
{ required: true, message: "请输入任务名称" },
|
||||
{ max: 100, message: "任务名称不能超过100个字符" },
|
||||
]}
|
||||
>
|
||||
<Input placeholder="请输入任务名称" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="选择数据集和图像文件" required>
|
||||
<DatasetFileTransfer
|
||||
open
|
||||
selectedFilesMap={selectedFilesMap}
|
||||
onSelectedFilesChange={setSelectedFilesMap}
|
||||
onDatasetSelect={(dataset) => {
|
||||
setSelectedDataset(dataset);
|
||||
form.setFieldsValue({ datasetId: dataset?.id ?? "" });
|
||||
}}
|
||||
datasetTypeFilter={DatasetType.IMAGE}
|
||||
/>
|
||||
{selectedDataset && (
|
||||
<div className="mt-2 p-2 bg-blue-50 rounded border border-blue-200 text-xs">
|
||||
当前数据集:<span className="font-medium">{selectedDataset.name}</span> - 已选择
|
||||
<span className="font-medium text-blue-600"> {imageFileCount} </span>个图像文件
|
||||
</div>
|
||||
)}
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item hidden name="datasetId" rules={[{ required: true, message: "请选择数据集" }]}>
|
||||
<Input type="hidden" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name="modelSize" label="模型规模" rules={[{ required: true, message: "请选择模型规模" }]}>
|
||||
<Select>
|
||||
<Option value="n">YOLOv8n (最快)</Option>
|
||||
<Option value="s">YOLOv8s</Option>
|
||||
<Option value="m">YOLOv8m</Option>
|
||||
<Option value="l">YOLOv8l (推荐)</Option>
|
||||
<Option value="x">YOLOv8x (最精确)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="confThreshold"
|
||||
label="置信度阈值"
|
||||
rules={[{ required: true, message: "请选择置信度阈值" }]}
|
||||
>
|
||||
<Slider min={0.1} max={0.9} step={0.05} tooltip={{ formatter: (v) => `${(v || 0) * 100}%` }} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="目标类别">
|
||||
<Checkbox checked={selectAllClasses} onChange={(e) => handleClassSelectionChange(e.target.checked)}>
|
||||
选中所有类别
|
||||
</Checkbox>
|
||||
{!selectAllClasses && (
|
||||
<Form.Item name="targetClasses" noStyle>
|
||||
<Select mode="multiple" placeholder="选择目标类别" style={{ marginTop: 8 }}>
|
||||
{COCO_CLASSES.map((cls) => (
|
||||
<Option key={cls.id} value={cls.id}>
|
||||
{cls.label} ({cls.name})
|
||||
</Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
)}
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name="outputDatasetName" label="输出数据集名称 (可选)">
|
||||
<Input placeholder="留空则将结果写入原数据集的标签中" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
// Barrel file: re-export the AutoAnnotation page component as this folder's default export.
export { default } from "./AutoAnnotation";
|
||||
@@ -1,12 +1,102 @@
|
||||
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
||||
import { mapDataset } from "@/pages/DataManagement/dataset.const";
|
||||
import { Button, Form, Input, Modal, Select, message } from "antd";
|
||||
import { Button, Form, Input, Modal, Select, message, Tabs, Slider, Checkbox } from "antd";
|
||||
import TextArea from "antd/es/input/TextArea";
|
||||
import { useEffect, useState } from "react";
|
||||
import { createAnnotationTaskUsingPost, queryAnnotationTemplatesUsingGet } from "../../annotation.api";
|
||||
import { Dataset } from "@/pages/DataManagement/dataset.model";
|
||||
import {
|
||||
createAnnotationTaskUsingPost,
|
||||
queryAnnotationTemplatesUsingGet,
|
||||
createAutoAnnotationTaskUsingPost,
|
||||
} from "../../annotation.api";
|
||||
import DatasetFileTransfer from "@/components/business/DatasetFileTransfer";
|
||||
import { DatasetType, type Dataset, type DatasetFile } from "@/pages/DataManagement/dataset.model";
|
||||
import type { AnnotationTemplate } from "../../annotation.model";
|
||||
|
||||
const { Option } = Select;
|
||||
|
||||
const COCO_CLASSES = [
|
||||
{ id: 0, name: "person", label: "人" },
|
||||
{ id: 1, name: "bicycle", label: "自行车" },
|
||||
{ id: 2, name: "car", label: "汽车" },
|
||||
{ id: 3, name: "motorcycle", label: "摩托车" },
|
||||
{ id: 4, name: "airplane", label: "飞机" },
|
||||
{ id: 5, name: "bus", label: "公交车" },
|
||||
{ id: 6, name: "train", label: "火车" },
|
||||
{ id: 7, name: "truck", label: "卡车" },
|
||||
{ id: 8, name: "boat", label: "船" },
|
||||
{ id: 9, name: "traffic light", label: "红绿灯" },
|
||||
{ id: 10, name: "fire hydrant", label: "消防栓" },
|
||||
{ id: 11, name: "stop sign", label: "停止标志" },
|
||||
{ id: 12, name: "parking meter", label: "停车计时器" },
|
||||
{ id: 13, name: "bench", label: "长椅" },
|
||||
{ id: 14, name: "bird", label: "鸟" },
|
||||
{ id: 15, name: "cat", label: "猫" },
|
||||
{ id: 16, name: "dog", label: "狗" },
|
||||
{ id: 17, name: "horse", label: "马" },
|
||||
{ id: 18, name: "sheep", label: "羊" },
|
||||
{ id: 19, name: "cow", label: "牛" },
|
||||
{ id: 20, name: "elephant", label: "大象" },
|
||||
{ id: 21, name: "bear", label: "熊" },
|
||||
{ id: 22, name: "zebra", label: "斑马" },
|
||||
{ id: 23, name: "giraffe", label: "长颈鹿" },
|
||||
{ id: 24, name: "backpack", label: "背包" },
|
||||
{ id: 25, name: "umbrella", label: "雨伞" },
|
||||
{ id: 26, name: "handbag", label: "手提包" },
|
||||
{ id: 27, name: "tie", label: "领带" },
|
||||
{ id: 28, name: "suitcase", label: "行李箱" },
|
||||
{ id: 29, name: "frisbee", label: "飞盘" },
|
||||
{ id: 30, name: "skis", label: "滑雪板" },
|
||||
{ id: 31, name: "snowboard", label: "滑雪板" },
|
||||
{ id: 32, name: "sports ball", label: "球类" },
|
||||
{ id: 33, name: "kite", label: "风筝" },
|
||||
{ id: 34, name: "baseball bat", label: "棒球棒" },
|
||||
{ id: 35, name: "baseball glove", label: "棒球手套" },
|
||||
{ id: 36, name: "skateboard", label: "滑板" },
|
||||
{ id: 37, name: "surfboard", label: "冲浪板" },
|
||||
{ id: 38, name: "tennis racket", label: "网球拍" },
|
||||
{ id: 39, name: "bottle", label: "瓶子" },
|
||||
{ id: 40, name: "wine glass", label: "酒杯" },
|
||||
{ id: 41, name: "cup", label: "杯子" },
|
||||
{ id: 42, name: "fork", label: "叉子" },
|
||||
{ id: 43, name: "knife", label: "刀" },
|
||||
{ id: 44, name: "spoon", label: "勺子" },
|
||||
{ id: 45, name: "bowl", label: "碗" },
|
||||
{ id: 46, name: "banana", label: "香蕉" },
|
||||
{ id: 47, name: "apple", label: "苹果" },
|
||||
{ id: 48, name: "sandwich", label: "三明治" },
|
||||
{ id: 49, name: "orange", label: "橙子" },
|
||||
{ id: 50, name: "broccoli", label: "西兰花" },
|
||||
{ id: 51, name: "carrot", label: "胡萝卜" },
|
||||
{ id: 52, name: "hot dog", label: "热狗" },
|
||||
{ id: 53, name: "pizza", label: "披萨" },
|
||||
{ id: 54, name: "donut", label: "甜甜圈" },
|
||||
{ id: 55, name: "cake", label: "蛋糕" },
|
||||
{ id: 56, name: "chair", label: "椅子" },
|
||||
{ id: 57, name: "couch", label: "沙发" },
|
||||
{ id: 58, name: "potted plant", label: "盆栽" },
|
||||
{ id: 59, name: "bed", label: "床" },
|
||||
{ id: 60, name: "dining table", label: "餐桌" },
|
||||
{ id: 61, name: "toilet", label: "马桶" },
|
||||
{ id: 62, name: "tv", label: "电视" },
|
||||
{ id: 63, name: "laptop", label: "笔记本电脑" },
|
||||
{ id: 64, name: "mouse", label: "鼠标" },
|
||||
{ id: 65, name: "remote", label: "遥控器" },
|
||||
{ id: 66, name: "keyboard", label: "键盘" },
|
||||
{ id: 67, name: "cell phone", label: "手机" },
|
||||
{ id: 68, name: "microwave", label: "微波炉" },
|
||||
{ id: 69, name: "oven", label: "烤箱" },
|
||||
{ id: 70, name: "toaster", label: "烤面包机" },
|
||||
{ id: 71, name: "sink", label: "水槽" },
|
||||
{ id: 72, name: "refrigerator", label: "冰箱" },
|
||||
{ id: 73, name: "book", label: "书" },
|
||||
{ id: 74, name: "clock", label: "钟表" },
|
||||
{ id: 75, name: "vase", label: "花瓶" },
|
||||
{ id: 76, name: "scissors", label: "剪刀" },
|
||||
{ id: 77, name: "teddy bear", label: "玩具熊" },
|
||||
{ id: 78, name: "hair drier", label: "吹风机" },
|
||||
{ id: 79, name: "toothbrush", label: "牙刷" },
|
||||
];
|
||||
|
||||
export default function CreateAnnotationTask({
|
||||
open,
|
||||
onClose,
|
||||
@@ -16,11 +106,18 @@ export default function CreateAnnotationTask({
|
||||
onClose: () => void;
|
||||
onRefresh: () => void;
|
||||
}) {
|
||||
const [form] = Form.useForm();
|
||||
const [manualForm] = Form.useForm();
|
||||
const [autoForm] = Form.useForm();
|
||||
const [datasets, setDatasets] = useState<Dataset[]>([]);
|
||||
const [templates, setTemplates] = useState<AnnotationTemplate[]>([]);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
const [nameManuallyEdited, setNameManuallyEdited] = useState(false);
|
||||
const [activeMode, setActiveMode] = useState<"manual" | "auto">("manual");
|
||||
|
||||
const [selectAllClasses, setSelectAllClasses] = useState(true);
|
||||
const [selectedFilesMap, setSelectedFilesMap] = useState<Record<string, DatasetFile>>({});
|
||||
const [selectedDataset, setSelectedDataset] = useState<Dataset | null>(null);
|
||||
const [imageFileCount, setImageFileCount] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) return;
|
||||
@@ -59,14 +156,29 @@ export default function CreateAnnotationTask({
|
||||
// Reset form and manual-edit flag when modal opens
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
form.resetFields();
|
||||
manualForm.resetFields();
|
||||
autoForm.resetFields();
|
||||
setNameManuallyEdited(false);
|
||||
setActiveMode("manual");
|
||||
setSelectAllClasses(true);
|
||||
setSelectedFilesMap({});
|
||||
setSelectedDataset(null);
|
||||
setImageFileCount(0);
|
||||
}
|
||||
}, [open, form]);
|
||||
}, [open, manualForm, autoForm]);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
useEffect(() => {
|
||||
const imageExtensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"];
|
||||
const count = Object.values(selectedFilesMap).filter((file) => {
|
||||
const ext = file.fileName?.toLowerCase().match(/\.[^.]+$/)?.[0] || "";
|
||||
return imageExtensions.includes(ext);
|
||||
}).length;
|
||||
setImageFileCount(count);
|
||||
}, [selectedFilesMap]);
|
||||
|
||||
const handleManualSubmit = async () => {
|
||||
try {
|
||||
const values = await form.validateFields();
|
||||
const values = await manualForm.validateFields();
|
||||
setSubmitting(true);
|
||||
// Send templateId instead of labelingConfig
|
||||
const requestData = {
|
||||
@@ -88,6 +200,58 @@ export default function CreateAnnotationTask({
|
||||
}
|
||||
};
|
||||
|
||||
const handleAutoSubmit = async () => {
|
||||
try {
|
||||
const values = await autoForm.validateFields();
|
||||
|
||||
if (imageFileCount === 0) {
|
||||
message.error("请至少选择一个图像文件");
|
||||
return;
|
||||
}
|
||||
|
||||
setSubmitting(true);
|
||||
|
||||
const imageExtensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"];
|
||||
const imageFileIds = Object.values(selectedFilesMap)
|
||||
.filter((file) => {
|
||||
const ext = file.fileName?.toLowerCase().match(/\.[^.]+$/)?.[0] || "";
|
||||
return imageExtensions.includes(ext);
|
||||
})
|
||||
.map((file) => file.id);
|
||||
|
||||
const payload = {
|
||||
name: values.name,
|
||||
datasetId: values.datasetId,
|
||||
fileIds: imageFileIds,
|
||||
config: {
|
||||
modelSize: values.modelSize,
|
||||
confThreshold: values.confThreshold,
|
||||
targetClasses: selectAllClasses ? [] : values.targetClasses || [],
|
||||
outputDatasetName: values.outputDatasetName || undefined,
|
||||
},
|
||||
};
|
||||
|
||||
await createAutoAnnotationTaskUsingPost(payload);
|
||||
message.success("自动标注任务创建成功");
|
||||
// 触发上层刷新自动标注任务列表
|
||||
(onRefresh as any)?.("auto");
|
||||
onClose();
|
||||
} catch (error: any) {
|
||||
if (error.errorFields) return;
|
||||
console.error("Failed to create auto annotation task:", error);
|
||||
message.error(error.message || "创建自动标注任务失败");
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleClassSelectionChange = (checked: boolean) => {
|
||||
setSelectAllClasses(checked);
|
||||
if (checked) {
|
||||
autoForm.setFieldsValue({ targetClasses: [] });
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open={open}
|
||||
@@ -98,14 +262,26 @@ export default function CreateAnnotationTask({
|
||||
<Button onClick={onClose} disabled={submitting}>
|
||||
取消
|
||||
</Button>
|
||||
<Button type="primary" onClick={handleSubmit} loading={submitting}>
|
||||
<Button
|
||||
type="primary"
|
||||
onClick={activeMode === "manual" ? handleManualSubmit : handleAutoSubmit}
|
||||
loading={submitting}
|
||||
>
|
||||
确定
|
||||
</Button>
|
||||
</>
|
||||
}
|
||||
width={800}
|
||||
>
|
||||
<Form form={form} layout="vertical">
|
||||
<Tabs
|
||||
activeKey={activeMode}
|
||||
onChange={(key) => setActiveMode(key as "manual" | "auto")}
|
||||
items={[
|
||||
{
|
||||
key: "manual",
|
||||
label: "手动标注",
|
||||
children: (
|
||||
<Form form={manualForm} layout="vertical">
|
||||
{/* 数据集 与 标注工程名称 并排显示(数据集在左) */}
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<Form.Item
|
||||
@@ -134,7 +310,11 @@ export default function CreateAnnotationTask({
|
||||
if (!nameManuallyEdited) {
|
||||
const ds = datasets.find((d) => d.id === value);
|
||||
if (ds) {
|
||||
form.setFieldsValue({ name: ds.name });
|
||||
let defaultName = ds.name || "";
|
||||
if (defaultName.length < 3) {
|
||||
defaultName = `${defaultName}-标注`;
|
||||
}
|
||||
manualForm.setFieldsValue({ name: defaultName });
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -144,7 +324,22 @@ export default function CreateAnnotationTask({
|
||||
<Form.Item
|
||||
label="标注工程名称"
|
||||
name="name"
|
||||
rules={[{ required: true, message: "请输入任务名称" }]}
|
||||
rules={[
|
||||
{
|
||||
validator: (_rule, value) => {
|
||||
const trimmed = (value || "").trim();
|
||||
if (!trimmed) {
|
||||
return Promise.reject(new Error("请输入任务名称"));
|
||||
}
|
||||
if (trimmed.length < 3) {
|
||||
return Promise.reject(
|
||||
new Error("任务名称至少需要 3 个字符(不含首尾空格,Label Studio 限制)"),
|
||||
);
|
||||
}
|
||||
return Promise.resolve();
|
||||
},
|
||||
},
|
||||
]}
|
||||
>
|
||||
<Input
|
||||
placeholder="输入标注工程名称"
|
||||
@@ -187,6 +382,108 @@ export default function CreateAnnotationTask({
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
),
|
||||
},
|
||||
{
|
||||
key: "auto",
|
||||
label: "自动标注",
|
||||
children: (
|
||||
<Form form={autoForm} layout="vertical" preserve={false}>
|
||||
<Form.Item
|
||||
name="name"
|
||||
label="任务名称"
|
||||
rules={[
|
||||
{ required: true, message: "请输入任务名称" },
|
||||
{ max: 100, message: "任务名称不能超过100个字符" },
|
||||
]}
|
||||
>
|
||||
<Input placeholder="请输入任务名称" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="选择数据集和图像文件" required>
|
||||
<DatasetFileTransfer
|
||||
open
|
||||
selectedFilesMap={selectedFilesMap}
|
||||
onSelectedFilesChange={setSelectedFilesMap}
|
||||
onDatasetSelect={(dataset) => {
|
||||
setSelectedDataset(dataset as Dataset | null);
|
||||
autoForm.setFieldsValue({ datasetId: dataset?.id ?? "" });
|
||||
}}
|
||||
datasetTypeFilter={DatasetType.IMAGE}
|
||||
/>
|
||||
{selectedDataset && (
|
||||
<div className="mt-2 p-2 bg-blue-50 rounded border border-blue-200 text-xs">
|
||||
当前数据集:<span className="font-medium">{selectedDataset.name}</span> - 已选择
|
||||
<span className="font-medium text-blue-600"> {imageFileCount} </span>个图像文件
|
||||
</div>
|
||||
)}
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
hidden
|
||||
name="datasetId"
|
||||
rules={[{ required: true, message: "请选择数据集" }]}
|
||||
>
|
||||
<Input type="hidden" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="modelSize"
|
||||
label="模型规模"
|
||||
rules={[{ required: true, message: "请选择模型规模" }]}
|
||||
initialValue="l"
|
||||
>
|
||||
<Select>
|
||||
<Option value="n">YOLOv8n (最快)</Option>
|
||||
<Option value="s">YOLOv8s</Option>
|
||||
<Option value="m">YOLOv8m</Option>
|
||||
<Option value="l">YOLOv8l (推荐)</Option>
|
||||
<Option value="x">YOLOv8x (最精确)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="confThreshold"
|
||||
label="置信度阈值"
|
||||
rules={[{ required: true, message: "请选择置信度阈值" }]}
|
||||
initialValue={0.7}
|
||||
>
|
||||
<Slider
|
||||
min={0.1}
|
||||
max={0.9}
|
||||
step={0.05}
|
||||
tooltip={{ formatter: (v) => `${(v || 0) * 100}%` }}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="目标类别">
|
||||
<Checkbox
|
||||
checked={selectAllClasses}
|
||||
onChange={(e) => handleClassSelectionChange(e.target.checked)}
|
||||
>
|
||||
选中所有类别
|
||||
</Checkbox>
|
||||
{!selectAllClasses && (
|
||||
<Form.Item name="targetClasses" noStyle>
|
||||
<Select mode="multiple" placeholder="选择目标类别" style={{ marginTop: 8 }}>
|
||||
{COCO_CLASSES.map((cls) => (
|
||||
<Option key={cls.id} value={cls.id}>
|
||||
{cls.label} ({cls.name})
|
||||
</Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
)}
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name="outputDatasetName" label="输出数据集名称 (可选)">
|
||||
<Input placeholder="留空则将结果写入原数据集的标签中" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { Card, Button, Table, message, Modal, Tabs } from "antd";
|
||||
import { Card, Button, Table, message, Modal, Tabs, Tag, Progress, Tooltip } from "antd";
|
||||
import {
|
||||
PlusOutlined,
|
||||
EditOutlined,
|
||||
@@ -11,9 +11,11 @@ import CardView from "@/components/CardView";
|
||||
import type { AnnotationTask } from "../annotation.model";
|
||||
import useFetchData from "@/hooks/useFetchData";
|
||||
import {
|
||||
deleteAnnotationTaskByIdUsingDelete, loginAnnotationUsingGet,
|
||||
deleteAnnotationTaskByIdUsingDelete,
|
||||
queryAnnotationTasksUsingGet,
|
||||
syncAnnotationTaskUsingPost,
|
||||
queryAutoAnnotationTasksUsingGet,
|
||||
deleteAutoAnnotationTaskByIdUsingDelete,
|
||||
} from "../annotation.api";
|
||||
import { mapAnnotationTask } from "../annotation.const";
|
||||
import CreateAnnotationTask from "../Create/components/CreateAnnotationTaskDialog";
|
||||
@@ -21,11 +23,28 @@ import { ColumnType } from "antd/es/table";
|
||||
import { TemplateList } from "../Template";
|
||||
// Note: DevelopmentInProgress intentionally not used here
|
||||
|
||||
const AUTO_STATUS_LABELS: Record<string, string> = {
|
||||
pending: "等待中",
|
||||
running: "处理中",
|
||||
completed: "已完成",
|
||||
failed: "失败",
|
||||
cancelled: "已取消",
|
||||
};
|
||||
|
||||
const AUTO_MODEL_SIZE_LABELS: Record<string, string> = {
|
||||
n: "YOLOv8n (最快)",
|
||||
s: "YOLOv8s",
|
||||
m: "YOLOv8m",
|
||||
l: "YOLOv8l (推荐)",
|
||||
x: "YOLOv8x (最精确)",
|
||||
};
|
||||
|
||||
export default function DataAnnotation() {
|
||||
// return <DevelopmentInProgress showTime="2025.10.30" />;
|
||||
const [activeTab, setActiveTab] = useState("tasks");
|
||||
const [viewMode, setViewMode] = useState<"list" | "card">("list");
|
||||
const [showCreateDialog, setShowCreateDialog] = useState(false);
|
||||
const [autoTasks, setAutoTasks] = useState<any[]>([]);
|
||||
|
||||
const {
|
||||
loading,
|
||||
@@ -41,6 +60,22 @@ export default function DataAnnotation() {
|
||||
const [selectedRowKeys, setSelectedRowKeys] = useState<(string | number)[]>([]);
|
||||
const [selectedRows, setSelectedRows] = useState<any[]>([]);
|
||||
|
||||
// 拉取自动标注任务(供轮询和创建成功后立即刷新复用)
|
||||
const refreshAutoTasks = async (silent = false) => {
|
||||
try {
|
||||
const response = await queryAutoAnnotationTasksUsingGet();
|
||||
const tasks = (response as any)?.data || response || [];
|
||||
if (Array.isArray(tasks)) {
|
||||
setAutoTasks(tasks);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch auto annotation tasks:", error);
|
||||
if (!silent) {
|
||||
message.error("获取自动标注任务失败");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// prefetch config on mount so clicking annotate is fast and we know whether base URL exists
|
||||
// useEffect ensures this runs once
|
||||
useEffect(() => {
|
||||
@@ -58,6 +93,16 @@ export default function DataAnnotation() {
|
||||
};
|
||||
}, []);
|
||||
|
||||
// 自动标注任务轮询(用于在同一表格中展示处理进度)
|
||||
useEffect(() => {
|
||||
refreshAutoTasks();
|
||||
const timer = setInterval(() => refreshAutoTasks(true), 3000);
|
||||
|
||||
return () => {
|
||||
clearInterval(timer);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleAnnotate = (task: AnnotationTask) => {
|
||||
// Open Label Studio project page in a new tab
|
||||
(async () => {
|
||||
@@ -76,7 +121,6 @@ export default function DataAnnotation() {
|
||||
|
||||
if (labelingProjId) {
|
||||
// only open external Label Studio when we have a configured base url
|
||||
await loginAnnotationUsingGet(labelingProjId)
|
||||
if (base) {
|
||||
const target = `${base}/projects/${labelingProjId}/data`;
|
||||
window.open(target, "_blank");
|
||||
@@ -126,6 +170,30 @@ export default function DataAnnotation() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleDeleteAuto = (task: any) => {
|
||||
Modal.confirm({
|
||||
title: `确认删除自动标注任务「${task.name}」吗?`,
|
||||
content: <div>删除任务后,已生成的标注结果不会被删除。</div>,
|
||||
okText: "删除",
|
||||
okType: "danger",
|
||||
cancelText: "取消",
|
||||
onOk: async () => {
|
||||
try {
|
||||
await deleteAutoAnnotationTaskByIdUsingDelete(task.id);
|
||||
message.success("自动标注任务删除成功");
|
||||
// 重新拉取自动标注任务
|
||||
setAutoTasks((prev) => prev.filter((t: any) => t.id !== task.id));
|
||||
// 清理选中
|
||||
setSelectedRowKeys((keys) => keys.filter((k) => k !== task.id));
|
||||
setSelectedRows((rows) => rows.filter((r) => r.id !== task.id));
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
message.error("删除失败,请稍后重试");
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const handleSync = (task: AnnotationTask, batchSize: number = 50) => {
|
||||
Modal.confirm({
|
||||
title: `确认同步标注任务「${task.name}」吗?`,
|
||||
@@ -156,8 +224,13 @@ export default function DataAnnotation() {
|
||||
|
||||
const handleBatchSync = (batchSize: number = 50) => {
|
||||
if (!selectedRows || selectedRows.length === 0) return;
|
||||
const manualRows = selectedRows.filter((r) => r._kind !== "auto");
|
||||
if (manualRows.length === 0) {
|
||||
message.warning("请选择手动标注任务进行同步");
|
||||
return;
|
||||
}
|
||||
Modal.confirm({
|
||||
title: `确认同步所选 ${selectedRows.length} 个标注任务吗?`,
|
||||
title: `确认同步所选 ${manualRows.length} 个标注任务吗?`,
|
||||
content: (
|
||||
<div>
|
||||
<div>标注工程中文件列表将与数据集保持一致,差异项将会被修正。</div>
|
||||
@@ -169,7 +242,7 @@ export default function DataAnnotation() {
|
||||
onOk: async () => {
|
||||
try {
|
||||
await Promise.all(
|
||||
selectedRows.map((r) => syncAnnotationTaskUsingPost({ id: r.id, batchSize }))
|
||||
manualRows.map((r) => syncAnnotationTaskUsingPost({ id: r.id, batchSize }))
|
||||
);
|
||||
message.success("批量同步请求已发送");
|
||||
fetchData();
|
||||
@@ -185,6 +258,8 @@ export default function DataAnnotation() {
|
||||
|
||||
const handleBatchDelete = () => {
|
||||
if (!selectedRows || selectedRows.length === 0) return;
|
||||
const manualRows = selectedRows.filter((r) => r._kind !== "auto");
|
||||
const autoRows = selectedRows.filter((r) => r._kind === "auto");
|
||||
Modal.confirm({
|
||||
title: `确认删除所选 ${selectedRows.length} 个标注任务吗?`,
|
||||
content: (
|
||||
@@ -199,7 +274,10 @@ export default function DataAnnotation() {
|
||||
onOk: async () => {
|
||||
try {
|
||||
await Promise.all(
|
||||
selectedRows.map((r) => deleteAnnotationTaskByIdUsingDelete(r.id))
|
||||
[
|
||||
...manualRows.map((r) => deleteAnnotationTaskByIdUsingDelete(r.id)),
|
||||
...autoRows.map((r) => deleteAutoAnnotationTaskByIdUsingDelete(r.id)),
|
||||
]
|
||||
);
|
||||
message.success("批量删除已完成");
|
||||
fetchData();
|
||||
@@ -238,6 +316,38 @@ export default function DataAnnotation() {
|
||||
onClick: handleDelete,
|
||||
},
|
||||
];
|
||||
// 合并手动标注任务与自动标注任务
|
||||
const mergedTableData = [
|
||||
// 手动标注任务
|
||||
...tableData.map((task) => ({
|
||||
...task,
|
||||
_kind: "manual" as const,
|
||||
})),
|
||||
// 自动标注任务
|
||||
...autoTasks.map((task: any) => {
|
||||
const sourceList = Array.isArray(task.sourceDatasets)
|
||||
? task.sourceDatasets
|
||||
: task.datasetName
|
||||
? [task.datasetName]
|
||||
: [];
|
||||
const datasetName = sourceList.length > 0 ? sourceList.join(",") : "-";
|
||||
|
||||
return {
|
||||
id: task.id,
|
||||
name: task.name,
|
||||
datasetName,
|
||||
createdAt: task.createdAt || "-",
|
||||
updatedAt: task.updatedAt || "-",
|
||||
_kind: "auto" as const,
|
||||
autoStatus: task.status,
|
||||
autoProgress: task.progress,
|
||||
autoProcessedImages: task.processedImages,
|
||||
autoTotalImages: task.totalImages,
|
||||
autoDetectedObjects: task.detectedObjects,
|
||||
autoConfig: task.config || {},
|
||||
};
|
||||
}),
|
||||
];
|
||||
|
||||
const columns: ColumnType<any>[] = [
|
||||
{
|
||||
@@ -246,6 +356,13 @@ export default function DataAnnotation() {
|
||||
key: "name",
|
||||
fixed: "left" as const,
|
||||
},
|
||||
{
|
||||
title: "类型",
|
||||
key: "kind",
|
||||
width: 100,
|
||||
render: (_: any, record: any) =>
|
||||
record._kind === "auto" ? "自动标注" : "手动标注",
|
||||
},
|
||||
{
|
||||
title: "任务ID",
|
||||
dataIndex: "id",
|
||||
@@ -257,6 +374,88 @@ export default function DataAnnotation() {
|
||||
key: "datasetName",
|
||||
width: 180,
|
||||
},
|
||||
{
|
||||
title: "模型",
|
||||
key: "modelSize",
|
||||
width: 160,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const size = record.autoConfig?.modelSize;
|
||||
return AUTO_MODEL_SIZE_LABELS[size] || size || "-";
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "置信度",
|
||||
key: "confThreshold",
|
||||
width: 120,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const threshold = record.autoConfig?.confThreshold;
|
||||
if (typeof threshold !== "number") return "-";
|
||||
return `${(threshold * 100).toFixed(0)}%`;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "目标类别",
|
||||
key: "targetClasses",
|
||||
width: 160,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const classes: number[] = record.autoConfig?.targetClasses || [];
|
||||
if (!classes.length) return "全部类别";
|
||||
const text = classes.join(", ");
|
||||
return (
|
||||
<Tooltip title={text}>
|
||||
<span>{`${classes.length} 个类别`}</span>
|
||||
</Tooltip>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "自动标注状态",
|
||||
key: "autoStatus",
|
||||
width: 130,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const status = record.autoStatus as string;
|
||||
const label = AUTO_STATUS_LABELS[status] || status || "-";
|
||||
return <Tag>{label}</Tag>;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "自动标注进度",
|
||||
key: "autoProgress",
|
||||
width: 200,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const progress = typeof record.autoProgress === "number" ? record.autoProgress : 0;
|
||||
const processed = record.autoProcessedImages ?? 0;
|
||||
const total = record.autoTotalImages ?? 0;
|
||||
return (
|
||||
<div>
|
||||
<Progress percent={progress} size="small" />
|
||||
<div style={{ fontSize: 12, color: "#999" }}>
|
||||
{processed} / {total}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "检测对象数",
|
||||
key: "detectedObjects",
|
||||
width: 120,
|
||||
render: (_: any, record: any) => {
|
||||
if (record._kind !== "auto") return "-";
|
||||
const count = record.autoDetectedObjects;
|
||||
if (typeof count !== "number") return "-";
|
||||
try {
|
||||
return count.toLocaleString();
|
||||
} catch {
|
||||
return String(count);
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "创建时间",
|
||||
dataIndex: "createdAt",
|
||||
@@ -277,7 +476,8 @@ export default function DataAnnotation() {
|
||||
dataIndex: "actions",
|
||||
render: (_: any, task: any) => (
|
||||
<div className="flex items-center justify-center space-x-1">
|
||||
{operations.map((operation) => (
|
||||
{task._kind === "manual" &&
|
||||
operations.map((operation) => (
|
||||
<Button
|
||||
key={operation.key}
|
||||
type="text"
|
||||
@@ -286,6 +486,14 @@ export default function DataAnnotation() {
|
||||
title={operation.label}
|
||||
/>
|
||||
))}
|
||||
{task._kind === "auto" && (
|
||||
<Button
|
||||
type="text"
|
||||
icon={<DeleteOutlined style={{ color: "#f5222d" }} />}
|
||||
onClick={() => handleDeleteAuto(task)}
|
||||
title="删除自动标注任务"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
@@ -357,7 +565,7 @@ export default function DataAnnotation() {
|
||||
rowKey="id"
|
||||
loading={loading}
|
||||
columns={columns}
|
||||
dataSource={tableData}
|
||||
dataSource={mergedTableData}
|
||||
pagination={pagination}
|
||||
rowSelection={{
|
||||
selectedRowKeys,
|
||||
@@ -381,7 +589,14 @@ export default function DataAnnotation() {
|
||||
<CreateAnnotationTask
|
||||
open={showCreateDialog}
|
||||
onClose={() => setShowCreateDialog(false)}
|
||||
onRefresh={fetchData}
|
||||
onRefresh={(mode?: any) => {
|
||||
// 手动标注创建成功后刷新标注任务列表
|
||||
fetchData();
|
||||
// 自动标注创建成功后立即刷新自动标注任务列表
|
||||
if (mode === "auto") {
|
||||
refreshAutoTasks(true);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { get, post, put, del } from "@/utils/request";
|
||||
import { get, post, put, del, download } from "@/utils/request";
|
||||
|
||||
// 标注任务管理相关接口
|
||||
export function queryAnnotationTasksUsingGet(params?: any) {
|
||||
@@ -18,10 +18,6 @@ export function deleteAnnotationTaskByIdUsingDelete(mappingId: string) {
|
||||
return del(`/api/annotation/project/${mappingId}`);
|
||||
}
|
||||
|
||||
export function loginAnnotationUsingGet(mappingId: string) {
|
||||
return get("/api/annotation/project/${mappingId}/login");
|
||||
}
|
||||
|
||||
// 标签配置管理
|
||||
export function getTagConfigUsingGet() {
|
||||
return get("/api/annotation/tags/config");
|
||||
@@ -48,3 +44,24 @@ export function deleteAnnotationTemplateByIdUsingDelete(
|
||||
) {
|
||||
return del(`/api/annotation/template/${templateId}`);
|
||||
}
|
||||
|
||||
// 自动标注任务管理
|
||||
export function queryAutoAnnotationTasksUsingGet(params?: any) {
|
||||
return get("/api/annotation/auto", params);
|
||||
}
|
||||
|
||||
export function createAutoAnnotationTaskUsingPost(data: any) {
|
||||
return post("/api/annotation/auto", data);
|
||||
}
|
||||
|
||||
export function deleteAutoAnnotationTaskByIdUsingDelete(taskId: string) {
|
||||
return del(`/api/annotation/auto/${taskId}`);
|
||||
}
|
||||
|
||||
export function getAutoAnnotationTaskStatusUsingGet(taskId: string) {
|
||||
return get(`/api/annotation/auto/${taskId}/status`);
|
||||
}
|
||||
|
||||
export function downloadAutoAnnotationResultUsingGet(taskId: string) {
|
||||
return download(`/api/annotation/auto/${taskId}/download`);
|
||||
}
|
||||
|
||||
@@ -71,12 +71,21 @@ export default function DatasetDetail() {
|
||||
|
||||
useEffect(() => {
|
||||
fetchDataset();
|
||||
filesOperation.fetchFiles();
|
||||
filesOperation.fetchFiles('', 1, 10); // 从根目录开始,第一页
|
||||
}, []);
|
||||
|
||||
const handleRefresh = async (showMessage = true) => {
|
||||
const handleRefresh = async (showMessage = true, prefixOverride?: string) => {
|
||||
fetchDataset();
|
||||
filesOperation.fetchFiles();
|
||||
// 刷新当前目录,保持在当前页
|
||||
const targetPrefix =
|
||||
prefixOverride !== undefined
|
||||
? prefixOverride
|
||||
: filesOperation.pagination.prefix;
|
||||
filesOperation.fetchFiles(
|
||||
targetPrefix,
|
||||
filesOperation.pagination.current,
|
||||
filesOperation.pagination.pageSize
|
||||
);
|
||||
if (showMessage) message.success({ content: "数据刷新成功" });
|
||||
};
|
||||
|
||||
@@ -92,12 +101,17 @@ export default function DatasetDetail() {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const refreshData = () => {
|
||||
handleRefresh(false);
|
||||
const refreshData = (e: Event) => {
|
||||
const custom = e as CustomEvent<{ prefix?: string }>;
|
||||
const prefixOverride = custom.detail?.prefix;
|
||||
handleRefresh(false, prefixOverride);
|
||||
};
|
||||
window.addEventListener("update:dataset", refreshData);
|
||||
window.addEventListener("update:dataset", refreshData as EventListener);
|
||||
return () => {
|
||||
window.removeEventListener("update:dataset", refreshData);
|
||||
window.removeEventListener(
|
||||
"update:dataset",
|
||||
refreshData as EventListener
|
||||
);
|
||||
};
|
||||
}, []);
|
||||
|
||||
@@ -232,6 +246,7 @@ export default function DatasetDetail() {
|
||||
data={dataset}
|
||||
open={showUploadDialog}
|
||||
onClose={() => setShowUploadDialog(false)}
|
||||
prefix={filesOperation.pagination.prefix}
|
||||
updateEvent="update:dataset"
|
||||
/>
|
||||
<EditDataset
|
||||
|
||||
@@ -13,17 +13,20 @@ export default function ImportConfiguration({
|
||||
open,
|
||||
onClose,
|
||||
updateEvent = "update:dataset",
|
||||
prefix,
|
||||
}: {
|
||||
data: Dataset | null;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
updateEvent?: string;
|
||||
prefix?: string;
|
||||
}) {
|
||||
const [form] = Form.useForm();
|
||||
const [collectionOptions, setCollectionOptions] = useState([]);
|
||||
const [importConfig, setImportConfig] = useState<any>({
|
||||
source: DataSource.UPLOAD,
|
||||
});
|
||||
const [currentPrefix, setCurrentPrefix] = useState<string>("");
|
||||
|
||||
const [fileList, setFileList] = useState<UploadFile[]>([]);
|
||||
const fileSliceList = useMemo(() => {
|
||||
@@ -45,6 +48,7 @@ export default function ImportConfiguration({
|
||||
fileList.forEach((file) => {
|
||||
formData.append("file", file);
|
||||
});
|
||||
console.log('[ImportConfiguration] Uploading with currentPrefix:', currentPrefix);
|
||||
window.dispatchEvent(
|
||||
new CustomEvent("upload:dataset", {
|
||||
detail: {
|
||||
@@ -52,6 +56,7 @@ export default function ImportConfiguration({
|
||||
files: fileSliceList,
|
||||
updateEvent,
|
||||
hasArchive: importConfig.hasArchive,
|
||||
prefix: currentPrefix,
|
||||
},
|
||||
})
|
||||
);
|
||||
@@ -82,14 +87,17 @@ export default function ImportConfiguration({
|
||||
};
|
||||
|
||||
const resetState = () => {
|
||||
console.log('[ImportConfiguration] resetState called, preserving currentPrefix:', currentPrefix);
|
||||
form.resetFields();
|
||||
setFileList([]);
|
||||
form.setFieldsValue({ files: null });
|
||||
setImportConfig({ source: importConfig.source ? importConfig.source : DataSource.UPLOAD });
|
||||
console.log('[ImportConfiguration] resetState done, currentPrefix still:', currentPrefix);
|
||||
};
|
||||
|
||||
const handleImportData = async () => {
|
||||
if (!data) return;
|
||||
console.log('[ImportConfiguration] handleImportData called, currentPrefix:', currentPrefix);
|
||||
if (importConfig.source === DataSource.UPLOAD) {
|
||||
await handleUpload(data);
|
||||
} else if (importConfig.source === DataSource.COLLECTION) {
|
||||
@@ -102,10 +110,19 @@ export default function ImportConfiguration({
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
setCurrentPrefix(prefix || "");
|
||||
console.log('[ImportConfiguration] Modal opened with prefix:', prefix);
|
||||
resetState();
|
||||
fetchCollectionTasks();
|
||||
}
|
||||
}, [open, importConfig.source]);
|
||||
}, [open]);
|
||||
|
||||
// Separate effect for fetching collection tasks when source changes
|
||||
useEffect(() => {
|
||||
if (open && importConfig.source === DataSource.COLLECTION) {
|
||||
fetchCollectionTasks();
|
||||
}
|
||||
}, [importConfig.source]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Button, Descriptions, DescriptionsProps, Modal, Table } from "antd";
|
||||
import { App, Button, Descriptions, DescriptionsProps, Modal, Table, Input } from "antd";
|
||||
import { formatBytes, formatDateTime } from "@/utils/unit";
|
||||
import { Download, Trash2, Folder, File } from "lucide-react";
|
||||
import { datasetTypeMap } from "../../dataset.const";
|
||||
|
||||
export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
const { modal, message } = App.useApp();
|
||||
const {
|
||||
fileList,
|
||||
pagination,
|
||||
@@ -17,6 +18,9 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
handleDownloadFile,
|
||||
handleBatchDeleteFiles,
|
||||
handleBatchExport,
|
||||
handleCreateDirectory,
|
||||
handleDownloadDirectory,
|
||||
handleDeleteDirectory,
|
||||
} = filesOperation;
|
||||
|
||||
// 文件列表多选配置
|
||||
@@ -123,8 +127,9 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
type="link"
|
||||
onClick={(e) => {
|
||||
const currentPath = filesOperation.pagination.prefix || '';
|
||||
const newPath = `${currentPath}${record.fileName}`;
|
||||
filesOperation.fetchFiles(newPath);
|
||||
// 文件夹路径必须以斜杠结尾
|
||||
const newPath = `${currentPath}${record.fileName}/`;
|
||||
filesOperation.fetchFiles(newPath, 1, filesOperation.pagination.pageSize);
|
||||
}}
|
||||
>
|
||||
{content}
|
||||
@@ -150,11 +155,24 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
render: (text: number, record: any) => {
|
||||
const isDirectory = record.id.startsWith('directory-');
|
||||
if (isDirectory) {
|
||||
return "-";
|
||||
return formatBytes(record.fileSize || 0);
|
||||
}
|
||||
return formatBytes(text)
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "包含文件数",
|
||||
dataIndex: "fileCount",
|
||||
key: "fileCount",
|
||||
width: 120,
|
||||
render: (text: number, record: any) => {
|
||||
const isDirectory = record.id.startsWith('directory-');
|
||||
if (!isDirectory) {
|
||||
return "-";
|
||||
}
|
||||
return record.fileCount ?? 0;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "上传时间",
|
||||
dataIndex: "uploadTime",
|
||||
@@ -169,9 +187,43 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
fixed: "right",
|
||||
render: (_, record) => {
|
||||
const isDirectory = record.id.startsWith('directory-');
|
||||
|
||||
if (isDirectory) {
|
||||
return <div className="flex"/>;
|
||||
const currentPath = filesOperation.pagination.prefix || '';
|
||||
const fullPath = `${currentPath}${record.fileName}/`;
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<Button
|
||||
size="small"
|
||||
type="link"
|
||||
onClick={() => handleDownloadDirectory(fullPath, record.fileName)}
|
||||
>
|
||||
下载
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
type="link"
|
||||
onClick={() => {
|
||||
modal.confirm({
|
||||
title: '确认删除文件夹?',
|
||||
content: `删除文件夹 "${record.fileName}" 将同时删除其中的所有文件和子文件夹,此操作不可恢复。`,
|
||||
okText: '删除',
|
||||
okType: 'danger',
|
||||
cancelText: '取消',
|
||||
onOk: async () => {
|
||||
await handleDeleteDirectory(fullPath, record.fileName);
|
||||
fetchDataset();
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
删除
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<Button
|
||||
@@ -210,7 +262,39 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
/>
|
||||
|
||||
{/* 文件列表 */}
|
||||
<h2 className="text-base font-semibold mt-8">文件列表</h2>
|
||||
<div className="flex items-center justify-between mt-8 mb-2">
|
||||
<h2 className="text-base font-semibold">文件列表</h2>
|
||||
<Button
|
||||
type="primary"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
let dirName = "";
|
||||
modal.confirm({
|
||||
title: "新建文件夹",
|
||||
content: (
|
||||
<Input
|
||||
autoFocus
|
||||
placeholder="请输入文件夹名称"
|
||||
onChange={(e) => {
|
||||
dirName = e.target.value?.trim();
|
||||
}}
|
||||
/>
|
||||
),
|
||||
okText: "确定",
|
||||
cancelText: "取消",
|
||||
onOk: async () => {
|
||||
if (!dirName) {
|
||||
message.warning("请输入文件夹名称");
|
||||
return Promise.reject();
|
||||
}
|
||||
await handleCreateDirectory(dirName);
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
新建文件夹
|
||||
</Button>
|
||||
</div>
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="flex items-center gap-2 p-3 bg-blue-50 rounded-lg border border-blue-200">
|
||||
<span className="text-sm text-blue-700 font-medium">
|
||||
@@ -240,10 +324,14 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
onClick={() => {
|
||||
// 获取上一级目录
|
||||
const currentPath = filesOperation.pagination.prefix || '';
|
||||
const pathParts = currentPath.split('/').filter(Boolean);
|
||||
pathParts.pop(); // 移除最后一个目录
|
||||
// 移除末尾的斜杠,然后按斜杠分割
|
||||
const trimmedPath = currentPath.replace(/\/$/, '');
|
||||
const pathParts = trimmedPath.split('/');
|
||||
// 移除最后一个目录名
|
||||
pathParts.pop();
|
||||
// 重新组合路径,如果还有内容则加斜杠,否则为空
|
||||
const parentPath = pathParts.length > 0 ? `${pathParts.join('/')}/` : '';
|
||||
filesOperation.fetchFiles(parentPath);
|
||||
filesOperation.fetchFiles(parentPath, 1, filesOperation.pagination.pageSize);
|
||||
}}
|
||||
className="p-0"
|
||||
>
|
||||
@@ -281,12 +369,7 @@ export default function Overview({ dataset, filesOperation, fetchDataset }) {
|
||||
...pagination,
|
||||
showTotal: (total) => `共 ${total} 条`,
|
||||
onChange: (page, pageSize) => {
|
||||
filesOperation.setPagination(prev => ({
|
||||
...prev,
|
||||
current: page,
|
||||
pageSize: pageSize
|
||||
}));
|
||||
filesOperation.fetchFiles(pagination.prefix, page, pageSize);
|
||||
filesOperation.fetchFiles(filesOperation.pagination.prefix, page, pageSize);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -9,6 +9,9 @@ import {
|
||||
downloadFileByIdUsingGet,
|
||||
exportDatasetUsingPost,
|
||||
queryDatasetFilesUsingGet,
|
||||
createDatasetDirectoryUsingPost,
|
||||
downloadDirectoryUsingGet,
|
||||
deleteDirectoryUsingDelete,
|
||||
} from "../dataset.api";
|
||||
import { useParams } from "react-router";
|
||||
|
||||
@@ -31,18 +34,16 @@ export function useFilesOperation(dataset: Dataset) {
|
||||
const [previewContent, setPreviewContent] = useState("");
|
||||
const [previewFileName, setPreviewFileName] = useState("");
|
||||
|
||||
const fetchFiles = async (prefix: string = '', current, pageSize) => {
|
||||
const params: any = {
|
||||
page: current ? current : pagination.current,
|
||||
size: pageSize ? pageSize : pagination.pageSize,
|
||||
isWithDirectory: true,
|
||||
};
|
||||
const fetchFiles = async (prefix?: string, current?, pageSize?) => {
|
||||
// 如果明确传了 prefix(包括空字符串),使用传入的值;否则使用当前 pagination.prefix
|
||||
const targetPrefix = prefix !== undefined ? prefix : (pagination.prefix || '');
|
||||
|
||||
if (prefix !== undefined) {
|
||||
params.prefix = prefix;
|
||||
} else if (pagination.prefix) {
|
||||
params.prefix = pagination.prefix;
|
||||
}
|
||||
const params: any = {
|
||||
page: current !== undefined ? current : pagination.current,
|
||||
size: pageSize !== undefined ? pageSize : pagination.pageSize,
|
||||
isWithDirectory: true,
|
||||
prefix: targetPrefix,
|
||||
};
|
||||
|
||||
const { data } = await queryDatasetFilesUsingGet(id!, params);
|
||||
setFileList(data.content || []);
|
||||
@@ -50,7 +51,9 @@ export function useFilesOperation(dataset: Dataset) {
|
||||
// Update pagination with current prefix
|
||||
setPagination(prev => ({
|
||||
...prev,
|
||||
prefix: prefix !== undefined ? prefix : prev.prefix,
|
||||
current: params.page,
|
||||
pageSize: params.size,
|
||||
prefix: targetPrefix,
|
||||
total: data.totalElements || 0,
|
||||
}));
|
||||
};
|
||||
@@ -145,5 +148,39 @@ export function useFilesOperation(dataset: Dataset) {
|
||||
handleShowFile,
|
||||
handleDeleteFile,
|
||||
handleBatchExport,
|
||||
handleCreateDirectory: async (directoryName: string) => {
|
||||
const currentPrefix = pagination.prefix || "";
|
||||
try {
|
||||
await createDatasetDirectoryUsingPost(dataset.id, {
|
||||
parentPrefix: currentPrefix,
|
||||
directoryName,
|
||||
});
|
||||
// 创建成功后刷新当前目录,重置到第一页
|
||||
await fetchFiles(currentPrefix, 1, pagination.pageSize);
|
||||
message.success({ content: `文件夹 ${directoryName} 创建成功` });
|
||||
} catch (error) {
|
||||
message.error({ content: `文件夹 ${directoryName} 创建失败` });
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
handleDownloadDirectory: async (directoryPath: string, directoryName: string) => {
|
||||
try {
|
||||
await downloadDirectoryUsingGet(dataset.id, directoryPath);
|
||||
message.success({ content: `文件夹 ${directoryName} 下载成功` });
|
||||
} catch (error) {
|
||||
message.error({ content: `文件夹 ${directoryName} 下载失败` });
|
||||
}
|
||||
},
|
||||
handleDeleteDirectory: async (directoryPath: string, directoryName: string) => {
|
||||
try {
|
||||
await deleteDirectoryUsingDelete(dataset.id, directoryPath);
|
||||
// 删除成功后刷新当前目录
|
||||
const currentPrefix = pagination.prefix || "";
|
||||
await fetchFiles(currentPrefix, 1, pagination.pageSize);
|
||||
message.success({ content: `文件夹 ${directoryName} 已删除` });
|
||||
} catch (error) {
|
||||
message.error({ content: `文件夹 ${directoryName} 删除失败` });
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -391,6 +391,7 @@ export default function DatasetManagementPage() {
|
||||
setCurrentDataset(null);
|
||||
setShowUploadDialog(false);
|
||||
}}
|
||||
prefix=""
|
||||
updateEvent="update:datasets"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -54,6 +54,35 @@ export function uploadDatasetFileUsingPost(id: string | number, data: any) {
|
||||
return post(`/api/data-management/datasets/${id}/files`, data);
|
||||
}
|
||||
|
||||
// 新建数据集文件夹
|
||||
export function createDatasetDirectoryUsingPost(
|
||||
id: string | number,
|
||||
data: { parentPrefix?: string; directoryName: string }
|
||||
) {
|
||||
return post(`/api/data-management/datasets/${id}/files/directories`, data);
|
||||
}
|
||||
|
||||
// 下载文件夹(打包为zip)
|
||||
export function downloadDirectoryUsingGet(
|
||||
id: string | number,
|
||||
directoryPath: string
|
||||
) {
|
||||
const dirName = directoryPath.split('/').filter(Boolean).pop() || 'folder';
|
||||
return download(
|
||||
`/api/data-management/datasets/${id}/files/directories/download?prefix=${encodeURIComponent(directoryPath)}`,
|
||||
null,
|
||||
`${dirName}.zip`
|
||||
);
|
||||
}
|
||||
|
||||
// 删除文件夹(递归删除)
|
||||
export function deleteDirectoryUsingDelete(
|
||||
id: string | number,
|
||||
directoryPath: string
|
||||
) {
|
||||
return del(`/api/data-management/datasets/${id}/files/directories?prefix=${encodeURIComponent(directoryPath)}`);
|
||||
}
|
||||
|
||||
export function downloadFileByIdUsingGet(
|
||||
id: string | number,
|
||||
fileId: string | number,
|
||||
|
||||
@@ -19,8 +19,10 @@ export default function TaskUpload() {
|
||||
|
||||
useEffect(() => {
|
||||
const uploadHandler = (e: any) => {
|
||||
console.log('[TaskUpload] Received upload event detail:', e.detail);
|
||||
const { files } = e.detail;
|
||||
const task = createTask(e.detail);
|
||||
console.log('[TaskUpload] Created task with prefix:', task.prefix);
|
||||
handleUpload({ task, files });
|
||||
};
|
||||
window.addEventListener("upload:dataset", uploadHandler);
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""
|
||||
Tables of Annotation Management Module
|
||||
"""
|
||||
"""Tables of Annotation Management Module"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date, ForeignKey
|
||||
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
@@ -58,3 +56,40 @@ class LabelingProject(Base):
|
||||
def is_deleted(self) -> bool:
|
||||
"""检查是否已被软删除"""
|
||||
return self.deleted_at is not None
|
||||
|
||||
|
||||
class AutoAnnotationTask(Base):
|
||||
"""自动标注任务模型,对应表 t_dm_auto_annotation_tasks"""
|
||||
|
||||
__tablename__ = "t_dm_auto_annotation_tasks"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
|
||||
name = Column(String(255), nullable=False, comment="任务名称")
|
||||
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
|
||||
dataset_name = Column(String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)")
|
||||
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
|
||||
file_ids = Column(JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像")
|
||||
status = Column(String(50), nullable=False, default="pending", comment="任务状态: pending/running/completed/failed")
|
||||
progress = Column(Integer, default=0, comment="任务进度 0-100")
|
||||
total_images = Column(Integer, default=0, comment="总图片数")
|
||||
processed_images = Column(Integer, default=0, comment="已处理图片数")
|
||||
detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
|
||||
output_path = Column(String(500), nullable=True, comment="输出路径")
|
||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
|
||||
updated_at = Column(
|
||||
TIMESTAMP,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
comment="更新时间",
|
||||
)
|
||||
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
|
||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - repr 简单返回
|
||||
return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"
|
||||
|
||||
@property
|
||||
def is_deleted(self) -> bool:
|
||||
"""检查是否已被软删除"""
|
||||
return self.deleted_at is not None
|
||||
@@ -4,6 +4,7 @@ from .config import router as about_router
|
||||
from .project import router as project_router
|
||||
from .task import router as task_router
|
||||
from .template import router as template_router
|
||||
from .auto import router as auto_router
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/annotation",
|
||||
@@ -14,3 +15,4 @@ router.include_router(about_router)
|
||||
router.include_router(project_router)
|
||||
router.include_router(task_router)
|
||||
router.include_router(template_router)
|
||||
router.include_router(auto_router)
|
||||
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""FastAPI routes for Auto Annotation tasks.
|
||||
|
||||
These routes back the frontend AutoAnnotation module:
|
||||
- GET /api/annotation/auto
|
||||
- POST /api/annotation/auto
|
||||
- DELETE /api/annotation/auto/{task_id}
|
||||
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auto",
|
||||
tags=["annotation/auto"],
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||
async def list_auto_annotation_tasks(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取自动标注任务列表。
|
||||
|
||||
前端当前不传分页参数,这里直接返回所有未删除任务。
|
||||
"""
|
||||
|
||||
tasks = await service.list_tasks(db)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=tasks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create an auto-annotation task.

    Only the task record is created (status ``pending``); the actual
    inference is executed later by the scheduler/worker.
    """
    logger.info(
        "Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
        request.name,
        request.dataset_id,
        request.config.model_dump(by_alias=True),
        request.file_ids,
    )

    # Best effort: resolve the dataset name and file count for the redundant
    # task columns. Failures here must not block task creation.
    dataset_name = None
    total_images = 0
    try:
        dm_client = DatasetManagementService(db)
        # Service.get_dataset returns a DatasetResponse with name and file count.
        dataset = await dm_client.get_dataset(request.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            if request.file_ids:
                # An explicit file selection wins over the dataset-wide count.
                total_images = len(request.file_ids)
            else:
                # BUG FIX: pydantic models expose snake_case attributes, so the
                # old getattr(dataset, "fileCount", 0) always fell through to 0.
                # Try snake_case first, keep camelCase as a defensive fallback.
                # NOTE(review): confirm DatasetResponse's actual field name.
                total_images = (
                    getattr(dataset, "file_count", None)
                    or getattr(dataset, "fileCount", 0)
                    or 0
                )
    except Exception as e:  # pragma: no cover - best-effort enrichment only
        logger.warning("Failed to fetch dataset name for auto task: %s", e)

    task = await service.create_task(
        db,
        request,
        dataset_name=dataset_name,
        total_images=total_images,
    )

    return StandardResponse(
        code=200,
        message="success",
        data=task,
    )
|
||||
|
||||
|
||||
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Return the status of one auto-annotation task by id.

    The frontend mostly polls the list endpoint; this is a by-id complement.
    """
    task = await service.get_task(db, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    return StandardResponse(
        code=200,
        message="success",
        data=task,
    )
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Soft-delete an auto-annotation task (only sets ``deleted_at``)."""
    deleted = await service.soft_delete_task(db, task_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Task not found")

    return StandardResponse(
        code=200,
        message="success",
        data=True,
    )
|
||||
|
||||
|
||||
@router.get("/{task_id}/download")
async def download_auto_annotation_result(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Download the result of an auto-annotation task as a ZIP archive.

    Raises 404 when the task or its output directory is missing, and 400
    when the task has no output path recorded.
    """
    import os
    import tempfile
    import zipfile

    # Reuse the service layer to fetch task metadata.
    task = await service.get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if not task.output_path:
        raise HTTPException(status_code=400, detail="Task has no output path")

    output_dir = task.output_path
    if not os.path.isdir(output_dir):
        raise HTTPException(status_code=404, detail="Output directory not found")

    # Build the ZIP in a temporary file so large results are not held in memory.
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(tmp_fd)

    with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(output_dir):
            for fname in files:
                file_path = os.path.join(root, fname)
                arcname = os.path.relpath(file_path, output_dir)
                zf.write(file_path, arcname)

    file_size = os.path.getsize(tmp_path)
    if file_size == 0:
        # BUG FIX: remove the temp file before bailing out (was leaked).
        os.remove(tmp_path)
        raise HTTPException(status_code=500, detail="Generated ZIP is empty")

    def iterfile():
        # Stream in chunks; BUG FIX: delete the temp ZIP once streaming ends
        # (previously the file was never cleaned up).
        try:
            with open(tmp_path, "rb") as f:
                while True:
                    chunk = f.read(8192)
                    if not chunk:
                        break
                    yield chunk
        finally:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    filename = f"{task.name}_annotations.zip"
    headers = {
        # BUG FIX: the computed `filename` was previously unused here.
        "Content-Disposition": f'attachment; filename="{filename}"',
        "Content-Length": str(file_size),
    }

    return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Schemas for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class AutoAnnotationConfig(BaseModel):
    """Auto-annotation task configuration (mirrors the frontend payload)."""

    # Aliases match the camelCase JSON keys sent by the frontend; field
    # descriptions are user-facing API docs and stay as-is.
    model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
    conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
    target_classes: List[int] = Field(
        default_factory=list,
        alias="targetClasses",
        description="目标类别ID列表,空表示全部类别",
    )
    output_dataset_name: Optional[str] = Field(
        default=None,
        alias="outputDatasetName",
        description="自动标注结果要写入的新数据集名称(可选)",
    )

    # Accept both snake_case attribute names and camelCase aliases on input.
    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class CreateAutoAnnotationTaskRequest(BaseModel):
    """Request body for creating an auto-annotation task; mirrors the
    structure sent by the frontend CreateAutoAnnotationDialog."""

    name: str = Field(..., min_length=1, max_length=255, description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    config: AutoAnnotationConfig = Field(..., description="任务配置")
    # None / empty means "process every image in the dataset".
    file_ids: Optional[List[str]] = Field(None, alias="fileIds", description="要处理的文件ID列表,为空则处理数据集中所有图像")

    # Accept both snake_case attribute names and camelCase aliases on input.
    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskResponse(BaseModel):
    """Auto-annotation task response model (shared by list and detail views)."""

    id: str = Field(..., description="任务ID")
    name: str = Field(..., description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
    # Populated by the service layer after validation; not a DB column read
    # via from_attributes.
    source_datasets: Optional[List[str]] = Field(
        default=None,
        alias="sourceDatasets",
        description="本任务实际处理涉及到的所有数据集名称列表",
    )
    config: Dict[str, Any] = Field(..., description="任务配置")
    status: str = Field(..., description="任务状态")
    progress: int = Field(..., description="任务进度 0-100")
    total_images: int = Field(..., alias="totalImages", description="总图片数")
    processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
    detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
    output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
    error_message: Optional[str] = Field(None, alias="errorMessage", description="错误信息")
    created_at: datetime = Field(..., alias="createdAt", description="创建时间")
    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
    completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")

    # from_attributes allows model_validate() directly on ORM rows.
    model_config = ConfigDict(populate_by_name=True, from_attributes=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskListResponse(BaseModel):
    """Paged task-list response. The frontend currently consumes a plain
    array; this reserved pagination envelope is unused for now."""

    content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
    total: int = Field(..., description="总数")

    model_config = ConfigDict(populate_by_name=True)
|
||||
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Service layer for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
    """Auto-annotation task service.

    Manages only task metadata; the actual YOLO inference is executed by the
    runtime worker, which polls this table and updates progress.
    """

    async def create_task(
        self,
        db: AsyncSession,
        request: CreateAutoAnnotationTaskRequest,
        dataset_name: Optional[str] = None,
        total_images: int = 0,
    ) -> AutoAnnotationTaskResponse:
        """Insert a task record in ``pending`` state.

        Does not run inference itself; a scheduler/worker reads this table
        later and updates progress.
        """
        now = datetime.now()

        task = AutoAnnotationTask(
            id=str(uuid4()),
            name=request.name,
            dataset_id=request.dataset_id,
            dataset_name=dataset_name,
            config=request.config.model_dump(by_alias=True),
            file_ids=request.file_ids,  # the user's selected file ids, if any
            status="pending",
            progress=0,
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            created_at=now,
            updated_at=now,
        )

        db.add(task)
        await db.commit()
        await db.refresh(task)

        # Attach sourceDatasets (usually the single origin dataset).
        return await self._to_response(db, task)

    async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
        """Return all non-soft-deleted tasks, newest first."""
        result = await db.execute(
            select(AutoAnnotationTask)
            .where(AutoAnnotationTask.deleted_at.is_(None))
            .order_by(AutoAnnotationTask.created_at.desc())
        )
        tasks: List[AutoAnnotationTask] = list(result.scalars().all())
        return [await self._to_response(db, task) for task in tasks]

    async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
        """Fetch one non-deleted task by id, or None when absent."""
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return None
        return await self._to_response(db, task)

    async def _to_response(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> AutoAnnotationTaskResponse:
        """Build a response model with ``source_datasets`` populated.

        Shared by create/list/get — this logic was previously duplicated at
        every call site with slightly diverging fallbacks.
        """
        resp = AutoAnnotationTaskResponse.model_validate(task)
        try:
            resp.source_datasets = await self._compute_source_datasets(db, task)
        except Exception:
            # Degrade to the single redundant dataset name/id on any error.
            resp.source_datasets = self._fallback_source_datasets(task)
        return resp

    @staticmethod
    def _fallback_source_datasets(task: AutoAnnotationTask) -> List[str]:
        """Degraded source list: redundant dataset_name, else dataset_id, else empty."""
        name = getattr(task, "dataset_name", None)
        if name:
            return [name]
        dataset_id = getattr(task, "dataset_id", "")
        return [dataset_id] if dataset_id else []

    async def _compute_source_datasets(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> List[str]:
        """Infer every dataset name the task actually touches.

        - With file_ids: resolve dataset ids via t_dm_dataset_files and join
          t_dm_datasets for the names;
        - Without file_ids (or when the lookup is empty): fall back to the
          task's redundant dataset_name/dataset_id fields.
        """
        file_ids = task.file_ids or []
        if file_ids:
            stmt = (
                select(Dataset.name)
                .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
                .where(DatasetFiles.id.in_(file_ids))
                .distinct()
            )
            result = await db.execute(stmt)
            names = [row[0] for row in result.fetchall() if row[0]]
            if names:
                return names

        # Fall back to showing the single origin dataset.
        return self._fallback_source_datasets(task)

    async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
        """Soft-delete: set ``deleted_at``; returns False when the task is absent."""
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return False

        task.deleted_at = datetime.now()
        await db.commit()
        return True
|
||||
@@ -1,8 +1,8 @@
|
||||
#!/bin/bash
set -e

# Serve files from the local document root only when the directory exists
# AND local serving is explicitly enabled.
# BUG FIX: the old test was `[-d $VAR ]` (missing space after `[`, unquoted
# expansion) and `&& $LOCAL_FILES_SERVING_ENABLED` executed the variable's
# value as a command instead of testing it.
if [ -d "${LOCAL_FILES_DOCUMENT_ROOT}" ] && [ "${LOCAL_FILES_SERVING_ENABLED}" = "true" ]; then
  echo "Using local document root: ${LOCAL_FILES_DOCUMENT_ROOT}"
fi

# Start the application
||||
17
runtime/ops/__init__.py
Normal file
17
runtime/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
"""Datamate built-in operators package.

This package contains built-in operators for filtering, slicing, annotation, etc.
It is mounted into the runtime container under ``datamate.ops`` so that
``from datamate.ops.annotation...`` imports work correctly.
"""

# Subpackages advertised for ``from datamate.ops import *`` consumers.
__all__ = [
    "annotation",
    "filter",
    "formatter",
    "llms",
    "mapper",
    "slicer",
    "user",
]
|
||||
6
runtime/ops/annotation/__init__.py
Normal file
6
runtime/ops/annotation/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
"""Annotation-related operators (e.g. YOLO detection)."""

# Operator subpackages exposed by this namespace.
__all__ = [
    "image_object_detection_bounding_box",
]
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Image object detection (YOLOv8) operator package.
|
||||
|
||||
This package exposes the ImageObjectDetectionBoundingBox annotator so that
|
||||
the auto-annotation worker can import it via different module paths.
|
||||
"""
|
||||
|
||||
from .process import ImageObjectDetectionBoundingBox
|
||||
|
||||
__all__ = ["ImageObjectDetectionBoundingBox"]
|
||||
@@ -0,0 +1,3 @@
|
||||
name: image_object_detection_bounding_box
|
||||
version: 0.1.0
|
||||
description: "YOLOv8-based object detection operator for auto annotation"
|
||||
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Description: Image object detection operator (YOLOv8 bounding boxes)
Create: 2025/12/17
"""
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
except ImportError:
|
||||
logger.warning("ultralytics not installed. Please install it using: pip install ultralytics")
|
||||
YOLO = None
|
||||
|
||||
from datamate.core.base_op import Mapper
|
||||
|
||||
|
||||
# COCO 80-class id -> label mapping (standard COCO category order).
COCO_CLASS_MAP = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
    5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
    10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
    14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
    20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
    25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
    30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
    34: "baseball bat", 35: "baseball glove", 36: "skateboard",
    37: "surfboard", 38: "tennis racket", 39: "bottle",
    40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
    45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
    50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
    54: "donut", 55: "cake", 56: "chair", 57: "couch",
    58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
    62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
    66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
    70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
    74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
    78: "hair drier", 79: "toothbrush"
}
|
||||
|
||||
|
||||
class ImageObjectDetectionBoundingBox(Mapper):
    """Object-detection operator: runs YOLOv8 on one image sample, draws the
    detected boxes, and writes an annotated image plus a JSON sidecar.

    Keyword arguments read in ``__init__`` (camelCase keys come from the
    task config):
        modelSize:     YOLOv8 variant key, one of n/s/m/l/x (default "l").
        confThreshold: confidence threshold in [0, 1] (default 0.7).
        targetClasses: COCO class-id allowlist; empty means all classes.
        outputDir:     directory for results; falls back to the image's folder.
    """

    # Maps the short model-size key to the YOLOv8 weight file name.
    MODEL_MAP = {
        "n": "yolov8n.pt",
        "s": "yolov8s.pt",
        "m": "yolov8m.pt",
        "l": "yolov8l.pt",
        "x": "yolov8x.pt",
    }

    def __init__(self, *args, **kwargs):
        super(ImageObjectDetectionBoundingBox, self).__init__(*args, **kwargs)

        # Operator parameters.
        self._model_size = kwargs.get("modelSize", "l")
        self._conf_threshold = kwargs.get("confThreshold", 0.7)
        self._target_classes = kwargs.get("targetClasses", [])
        self._output_dir = kwargs.get("outputDir", None)  # output directory

        # An empty target list means "detect every class" (None sentinel).
        if not self._target_classes:
            self._target_classes = None
        else:
            # Normalise to an integer list.
            self._target_classes = [int(cls_id) for cls_id in self._target_classes]

        # Resolve the model weight path next to this file.
        model_filename = self.MODEL_MAP.get(self._model_size, "yolov8l.pt")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, model_filename)

        # Initialise the model; fail fast when ultralytics is unavailable.
        if YOLO is None:
            raise ImportError("ultralytics is not installed. Please install it.")

        if not os.path.exists(model_path):
            logger.warning(f"Model file {model_path} not found. Downloading from ultralytics...")
            self.model = YOLO(model_filename)  # ultralytics downloads the weights
        else:
            self.model = YOLO(model_path)

        logger.info(f"Loaded YOLOv8 model: {model_filename}, "
                    f"conf_threshold: {self._conf_threshold}, "
                    f"target_classes: {self._target_classes}")

    @staticmethod
    def _get_color_by_class_id(class_id: int):
        """Derive a stable BGR color from the class id (for OpenCV drawing)."""
        np.random.seed(class_id)
        color = np.random.randint(0, 255, size=3).tolist()
        return tuple(color)

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on one sample.

        Returns the sample enriched with ``detection_count``, ``output_image``,
        ``annotations_file`` and ``annotations``. Samples whose image is
        missing or unreadable pass through unchanged (with a warning).
        """
        start = time.time()

        # Locate the image file referenced by the sample.
        image_path = sample.get(self.image_key)
        if not image_path or not os.path.exists(image_path):
            logger.warning(f"Image file not found: {image_path}")
            return sample

        # Load the image.
        img = cv2.imread(image_path)
        if img is None:
            logger.warning(f"Failed to read image: {image_path}")
            return sample

        # Run detection.
        results = self.model(img, conf=self._conf_threshold)
        r = results[0]

        # Prepare the annotation payload.
        h, w = img.shape[:2]
        annotations = {
            "image": os.path.basename(image_path),
            "width": w,
            "height": h,
            "model_size": self._model_size,
            "conf_threshold": self._conf_threshold,
            "selected_class_ids": self._target_classes,
            "detections": []
        }

        # Collect detections and draw them on the image.
        if r.boxes is not None:
            for box in r.boxes:
                cls_id = int(box.cls[0])

                # Skip classes outside the allowlist (None = all classes).
                if self._target_classes is not None and cls_id not in self._target_classes:
                    continue

                conf = float(box.conf[0])
                x1, y1, x2, y2 = map(float, box.xyxy[0])
                label = COCO_CLASS_MAP.get(cls_id, f"class_{cls_id}")

                # Record this detection in both xyxy and xywh forms.
                annotations["detections"].append({
                    "label": label,
                    "class_id": cls_id,
                    "confidence": round(conf, 4),
                    "bbox_xyxy": [x1, y1, x2, y2],
                    "bbox_xywh": [x1, y1, x2 - x1, y2 - y1]
                })

                # Draw the box and label on the image.
                color = self._get_color_by_class_id(cls_id)
                cv2.rectangle(
                    img,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    color,
                    2
                )

                cv2.putText(
                    img,
                    f"{label} {conf:.2f}",
                    (int(x1), max(int(y1) - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    1
                )

        # Pick the output directory: the configured one if it exists,
        # otherwise the source image's own directory.
        if self._output_dir and os.path.exists(self._output_dir):
            output_dir = self._output_dir
        else:
            output_dir = os.path.dirname(image_path)

        # Create output sub-directories (used to organise the files).
        images_dir = os.path.join(output_dir, "images")
        annotations_dir = os.path.join(output_dir, "annotations")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(annotations_dir, exist_ok=True)

        # Keep the original file name (no suffix) so outputs pair 1:1.
        base_name = os.path.basename(image_path)
        name_without_ext = os.path.splitext(base_name)[0]

        # Save the annotated image (original extension preserved).
        output_filename = base_name
        output_path = os.path.join(images_dir, output_filename)
        cv2.imwrite(output_path, img)

        # Save the annotation JSON with the same stem as the image.
        json_filename = f"{name_without_ext}.json"
        json_path = os.path.join(annotations_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotations, f, indent=2, ensure_ascii=False)

        # Enrich the sample with the detection results.
        sample["detection_count"] = len(annotations["detections"])
        sample["output_image"] = output_path
        sample["annotations_file"] = json_path
        sample["annotations"] = annotations

        logger.info(f"Image: {os.path.basename(image_path)}, "
                    f"Detections: {len(annotations['detections'])}, "
                    f"Time: {(time.time() - start):.4f}s")

        return sample
|
||||
166
runtime/ops/annotation/image_semantic_segmentation/process.py
Normal file
166
runtime/ops/annotation/image_semantic_segmentation/process.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_color_by_class_id(class_id: int):
    """Deterministic BGR color derived from a class id.

    Seeds numpy's global RNG with the class id so the same id always
    yields the same color.
    """
    np.random.seed(class_id)
    return tuple(np.random.randint(0, 255, size=3).tolist())
|
||||
|
||||
|
||||
def mask_to_polygons(mask: np.ndarray):
    """Convert a binary mask into a list of COCO-style flat polygons.

    Contours with fewer than 3 points are dropped (they cannot form a
    polygon).
    """
    contours, _ = cv2.findContours(
        mask,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE
    )
    return [c.flatten().tolist() for c in contours if c.shape[0] >= 3]
|
||||
|
||||
|
||||
# Local defaults can be overridden via environment variables so the script is
# not tied to one developer's machine (the old value was a hard-coded
# Windows path; it is kept as the default for backward compatibility).
IMAGE_DIR = os.getenv("SEG_IMAGE_DIR", "C:/Users/meta/Desktop/Datamate/yolo/Photos")
OUT_IMG_DIR = os.getenv("SEG_OUT_IMG_DIR", "outputs_seg/images")
OUT_JSON_DIR = os.getenv("SEG_OUT_JSON_DIR", "outputs_seg/annotations")

# Segmentation-model weights per size key.
MODEL_MAP = {
    "n": "yolov8n-seg.pt",
    "s": "yolov8s-seg.pt",
    "m": "yolov8m-seg.pt",
    "l": "yolov8l-seg.pt",
    "x": "yolov8x-seg.pt",
}
MODEL_KEY = "x"
MODEL_PATH = MODEL_MAP[MODEL_KEY]

CONF_THRES = 0.7
DRAW_BBOX = True  # whether bounding boxes/labels are drawn on top of masks

# COCO 80-class id -> label mapping (standard COCO category order).
COCO_CLASS_MAP = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
    5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
    10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
    14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
    20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
    25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
    30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
    34: "baseball bat", 35: "baseball glove", 36: "skateboard",
    37: "surfboard", 38: "tennis racket", 39: "bottle",
    40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
    45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
    50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
    54: "donut", 55: "cake", 56: "chair", 57: "couch",
    58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
    62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
    66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
    70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
    74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
    78: "hair drier", 79: "toothbrush"
}

# Class ids to segment; None would mean "all classes".
TARGET_CLASS_IDS = [0, 2, 5]

os.makedirs(OUT_IMG_DIR, exist_ok=True)
os.makedirs(OUT_JSON_DIR, exist_ok=True)

# Fail fast on invalid class ids before loading the model.
if TARGET_CLASS_IDS is not None:
    for cid in TARGET_CLASS_IDS:
        if cid not in COCO_CLASS_MAP:
            raise ValueError(f"Invalid class id: {cid}")
|
||||
|
||||
# Load the segmentation model once, then process every image under IMAGE_DIR.
model = YOLO(MODEL_PATH)

image_paths = list(Path(IMAGE_DIR).glob("*.*"))

for img_path in image_paths:
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"[WARN] Failed to read {img_path}")
        continue

    results = model(img, conf=CONF_THRES)
    r = results[0]

    # Per-image annotation payload.
    h, w = img.shape[:2]
    annotations = {
        "image": img_path.name,
        "width": w,
        "height": h,
        "model_key": MODEL_KEY,
        "conf_threshold": CONF_THRES,
        "supported_classes": COCO_CLASS_MAP,
        "selected_class_ids": TARGET_CLASS_IDS,
        "instances": []
    }

    if r.boxes is not None and r.masks is not None:
        for i, box in enumerate(r.boxes):
            cls_id = int(box.cls[0])
            if TARGET_CLASS_IDS is not None and cls_id not in TARGET_CLASS_IDS:
                continue

            conf = float(box.conf[0])
            x1, y1, x2, y2 = map(float, box.xyxy[0])
            label = COCO_CLASS_MAP[cls_id]

            # Binarize and upsample the instance mask to image resolution.
            mask = r.masks.data[i].cpu().numpy()
            mask = (mask > 0.5).astype(np.uint8)
            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

            # Blend a stable per-class color into the masked region.
            color = get_color_by_class_id(cls_id)
            img[mask == 1] = (
                img[mask == 1] * 0.5 + np.array(color) * 0.5
            ).astype(np.uint8)

            # BUG FIX: honour the DRAW_BBOX flag (was a hard-coded `if True:`
            # that made the flag dead).
            if DRAW_BBOX:
                cv2.rectangle(
                    img,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    color,
                    2
                )

                cv2.putText(
                    img,
                    f"{label} {conf:.2f}",
                    (int(x1), max(int(y1) - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    1
                )

            polygons = mask_to_polygons(mask)

            annotations["instances"].append({
                "label": label,
                "class_id": cls_id,
                "confidence": round(conf, 4),
                "bbox_xyxy": [x1, y1, x2, y2],
                "segmentation": polygons
            })

    out_img_path = os.path.join(OUT_IMG_DIR, img_path.name)
    out_json_path = os.path.join(OUT_JSON_DIR, img_path.stem + ".json")

    cv2.imwrite(out_img_path, img)

    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(annotations, f, indent=2, ensure_ascii=False)

    print(f"[OK] {img_path.name}")

print("Segmentation batch finished.")
|
||||
@@ -31,4 +31,5 @@ dependencies = [
|
||||
"sqlalchemy>=2.0.44",
|
||||
"xmltodict>=1.0.2",
|
||||
"zhconv>=1.4.3",
|
||||
"ultralytics>=8.0.0",
|
||||
]
|
||||
|
||||
603
runtime/python-executor/datamate/auto_annotation_worker.py
Normal file
603
runtime/python-executor/datamate/auto_annotation_worker.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Simple background worker for auto-annotation tasks.
|
||||
|
||||
This module runs inside the datamate-runtime container (operator_runtime service).
|
||||
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
|
||||
inference using the ImageObjectDetectionBoundingBox operator, updating
|
||||
progress back to the same table so that the datamate-python backend and
|
||||
frontend can display real-time status.
|
||||
|
||||
设计目标(最小可用版本):
|
||||
- 单实例 worker,串行处理 `pending` 状态的任务。
|
||||
- 对指定数据集下的所有已完成文件逐张执行目标检测。
|
||||
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。
|
||||
- 失败时将任务标记为 `failed` 并记录 `error_message`。
|
||||
|
||||
注意:
|
||||
- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时,
|
||||
已处于 running 的任务不会被重新拉起,需要后续扩展。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
|
||||
from datamate.sql_manager.sql_manager import SQLManager
|
||||
|
||||
# 尝试多种导入路径,适配不同的打包/安装方式
|
||||
ImageObjectDetectionBoundingBox = None # type: ignore
|
||||
try:
|
||||
# 优先使用 datamate.ops 路径(源码 COPY 到 /opt/runtime/datamate/ops 情况)
|
||||
from datamate.ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
|
||||
ImageObjectDetectionBoundingBox,
|
||||
)
|
||||
logger.info(
|
||||
"Imported ImageObjectDetectionBoundingBox from datamate.ops.annotation.image_object_detection_bounding_box",
|
||||
)
|
||||
except Exception as e1: # pragma: no cover - 导入失败时仅记录日志,避免整体崩溃
|
||||
logger.error(
|
||||
"Failed to import ImageObjectDetectionBoundingBox via datamate.ops: {}",
|
||||
e1,
|
||||
)
|
||||
try:
|
||||
# 兼容顶层 ops 包安装的情况(通过 ops.pth 暴露)
|
||||
from ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
|
||||
ImageObjectDetectionBoundingBox,
|
||||
)
|
||||
logger.info(
|
||||
"Imported ImageObjectDetectionBoundingBox from top-level ops.annotation.image_object_detection_bounding_box",
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error(
|
||||
"Failed to import ImageObjectDetectionBoundingBox via top-level ops package: {}",
|
||||
e2,
|
||||
)
|
||||
ImageObjectDetectionBoundingBox = None
|
||||
|
||||
|
||||
# 进一步兜底:直接从本地 runtime/ops 目录加载算子(开发环境常用场景)
|
||||
if ImageObjectDetectionBoundingBox is None:
|
||||
try:
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
ops_root = project_root / "ops"
|
||||
if ops_root.is_dir():
|
||||
# 确保 ops 的父目录在 sys.path 中,这样可以按 "ops.xxx" 导入
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
|
||||
ImageObjectDetectionBoundingBox,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Imported ImageObjectDetectionBoundingBox from local runtime/ops.annotation.image_object_detection_bounding_box",
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Local runtime/ops directory not found when trying to import ImageObjectDetectionBoundingBox: {}",
|
||||
ops_root,
|
||||
)
|
||||
except Exception as e3: # pragma: no cover - 兜底失败仅记录日志
|
||||
logger.error(
|
||||
"Failed to import ImageObjectDetectionBoundingBox from local runtime/ops: {}",
|
||||
e3,
|
||||
)
|
||||
ImageObjectDetectionBoundingBox = None
|
||||
|
||||
|
||||
POLL_INTERVAL_SECONDS = float(os.getenv("AUTO_ANNOTATION_POLL_INTERVAL", "5"))
|
||||
|
||||
DEFAULT_OUTPUT_ROOT = os.getenv(
|
||||
"AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
|
||||
)
|
||||
|
||||
|
||||
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
    """Fetch the oldest pending auto-annotation task, or None when idle.

    Returns:
        A dict of the task row with ``config`` decoded to a dict (empty on
        parse failure) and ``file_ids`` decoded to a list or None.
    """

    query = text(
        """
        SELECT id, name, dataset_id, dataset_name, config, file_ids, status,
               total_images, processed_images, detected_objects, output_path
        FROM t_dm_auto_annotation_tasks
        WHERE status = 'pending' AND deleted_at IS NULL
        ORDER BY created_at ASC
        LIMIT 1
        """
    )

    with SQLManager.create_connect() as conn:
        record = conn.execute(query).fetchone()

    if record is None:
        return None

    task = dict(record._mapping)  # type: ignore[attr-defined]

    # config is stored as a JSON string; fall back to {} on any parse issue.
    try:
        task["config"] = json.loads(task["config"]) if task.get("config") else {}
    except Exception:
        task["config"] = {}

    # file_ids may be NULL, a JSON string, or an already-decoded sequence.
    try:
        raw_ids = task.get("file_ids")
        if not raw_ids:
            task["file_ids"] = None
        else:
            task["file_ids"] = json.loads(raw_ids) if isinstance(raw_ids, str) else raw_ids
    except Exception:
        task["file_ids"] = None

    return task
|
||||
|
||||
|
||||
def _update_task_status(
    task_id: str,
    *,
    status: str,
    progress: Optional[int] = None,
    processed_images: Optional[int] = None,
    detected_objects: Optional[int] = None,
    total_images: Optional[int] = None,
    output_path: Optional[str] = None,
    error_message: Optional[str] = None,
    completed: bool = False,
) -> None:
    """Update status and statistics columns of one task row.

    Only keyword arguments that are not None are written; ``completed=True``
    additionally stamps ``completed_at``.
    """

    assignments: List[str] = ["status = :status", "updated_at = :updated_at"]
    params: Dict[str, Any] = {
        "task_id": task_id,
        "status": status,
        "updated_at": datetime.now(),
    }

    # Integer counters share identical handling: write only when provided.
    for column, value in (
        ("progress", progress),
        ("processed_images", processed_images),
        ("detected_objects", detected_objects),
        ("total_images", total_images),
    ):
        if value is not None:
            assignments.append(f"{column} = :{column}")
            params[column] = int(value)

    if output_path is not None:
        assignments.append("output_path = :output_path")
        params["output_path"] = output_path
    if error_message is not None:
        assignments.append("error_message = :error_message")
        # Cap the stored message defensively.
        params["error_message"] = error_message[:2000]
    if completed:
        assignments.append("completed_at = :completed_at")
        params["completed_at"] = datetime.now()

    # Column names are fixed literals above; only bind values come from input.
    sql = text(
        f"""
        UPDATE t_dm_auto_annotation_tasks
        SET {', '.join(assignments)}
        WHERE id = :task_id
        """
    )

    with SQLManager.create_connect() as conn:
        conn.execute(sql, params)
|
||||
|
||||
|
||||
def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
    """Return (id, file_path, file_name) for every ACTIVE file of a dataset.

    Rows are ordered by creation time ascending.
    """

    query = text(
        """
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE dataset_id = :dataset_id
        AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )

    with SQLManager.create_connect() as conn:
        records = conn.execute(query, {"dataset_id": dataset_id}).fetchall()

    return [(str(row[0]), str(row[1]), str(row[2])) for row in records]
|
||||
|
||||
|
||||
def _load_files_by_ids(file_ids: List[str]) -> List[Tuple[str, str, str]]:
    """Return (id, file_path, file_name) rows for the given ids.

    Supports files spread across multiple datasets; only ACTIVE files are
    returned. An empty id list short-circuits to an empty result.
    """

    if not file_ids:
        return []

    # Bind every id to its own named parameter (:id0, :id1, ...) so the IN
    # clause stays fully parameterized.
    params = {f"id{i}": str(fid) for i, fid in enumerate(file_ids)}
    placeholders = ", ".join(f":{key}" for key in params)

    query = text(
        f"""
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE id IN ({placeholders})
        AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )

    with SQLManager.create_connect() as conn:
        records = conn.execute(query, params).fetchall()

    return [(str(row[0]), str(row[1]), str(row[2])) for row in records]
|
||||
|
||||
|
||||
def _ensure_output_dir(output_dir: str) -> str:
|
||||
"""确保输出目录及其 images/、annotations/ 子目录存在。"""
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_dir, "annotations"), exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
|
||||
def _create_output_dataset(
    source_dataset_id: str,
    source_dataset_name: str,
    output_dataset_name: str,
) -> Tuple[str, str]:
    """Insert a new dataset row for auto-annotation output.

    Returns:
        (new_dataset_id, output_dir) where output_dir is
        ``DEFAULT_OUTPUT_ROOT/<new_dataset_id>``.
    """

    new_dataset_id = str(uuid.uuid4())
    base_path = DEFAULT_OUTPUT_ROOT.rstrip("/") or "/dataset"
    output_dir = os.path.join(base_path, new_dataset_id)

    # Truncate to fit the description column.
    description = (
        f"Auto annotations for dataset {source_dataset_name or source_dataset_id}"[:255]
    )

    insert_sql = text(
        """
        INSERT INTO t_dm_datasets (id, name, description, dataset_type, path, status)
        VALUES (:id, :name, :description, :dataset_type, :path, :status)
        """
    )

    with SQLManager.create_connect() as conn:
        conn.execute(
            insert_sql,
            {
                "id": new_dataset_id,
                "name": output_dataset_name,
                "description": description,
                "dataset_type": "IMAGE",
                "path": output_dir,
                "status": "ACTIVE",
            },
        )

    return new_dataset_id, output_dir
|
||||
|
||||
|
||||
def _collect_regular_files(directory: str) -> Tuple[List[Tuple[str, str, int]], int]:
    """Scan *directory* (non-recursively) for regular files, sorted by name.

    Returns ([(file_name, file_path, file_size)], total_size). A file whose
    size cannot be read is recorded with size 0 rather than skipped.
    """
    entries: List[Tuple[str, str, int]] = []
    total = 0
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)
        if not os.path.isfile(file_path):
            continue
        try:
            file_size = os.path.getsize(file_path)
        except OSError:
            file_size = 0
        entries.append((file_name, file_path, int(file_size)))
        total += int(file_size)
    return entries, total


def _insert_dataset_file_rows(conn, insert_sql, dataset_id: str,
                              files: List[Tuple[str, str, int]]) -> int:
    """Insert one t_dm_dataset_files row per (name, path, size); return the count."""
    for file_name, file_path, file_size in files:
        # Derive file_type from the extension (e.g. "JPG"); None when absent.
        ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
        conn.execute(
            insert_sql,
            {
                "id": str(uuid.uuid4()),
                "dataset_id": dataset_id,
                "file_name": file_name,
                "file_path": file_path,
                "file_type": ext,
                "file_size": int(file_size),
                "status": "ACTIVE",
            },
        )
    return len(files)


def _register_output_dataset(
    task_id: str,
    output_dataset_id: str,
    output_dir: str,
    output_dataset_name: str,
    total_images: int,
) -> None:
    """Register generated images/annotations as files of the output dataset.

    Scans ``<output_dir>/images`` and ``<output_dir>/annotations``, inserts a
    dataset-file row per regular file, then bumps the dataset's aggregate
    file_count/size_bytes. Returns early (warning only) when the images
    directory is missing or empty.

    Args:
        task_id: Auto-annotation task id (logging only).
        output_dataset_id: Dataset created by _create_output_dataset.
        output_dir: Root of the task output directory.
        output_dataset_name: Dataset name (logging only).
        total_images: Unused; kept for backward-compatible call sites.
    """

    images_dir = os.path.join(output_dir, "images")
    if not os.path.isdir(images_dir):
        logger.warning(
            "Auto-annotation images directory not found for task {}: {}",
            task_id,
            images_dir,
        )
        return

    image_files, images_size = _collect_regular_files(images_dir)

    annotation_files: List[Tuple[str, str, int]] = []
    annotations_size = 0
    annotations_dir = os.path.join(output_dir, "annotations")
    if os.path.isdir(annotations_dir):
        annotation_files, annotations_size = _collect_regular_files(annotations_dir)

    total_size = images_size + annotations_size

    # Without any images the run produced nothing worth registering, even if
    # stray annotation files exist.
    if not image_files:
        logger.warning(
            "No image files found in auto-annotation output for task {}: {}",
            task_id,
            images_dir,
        )
        return

    insert_file_sql = text(
        """
        INSERT INTO t_dm_dataset_files (
            id, dataset_id, file_name, file_path, file_type, file_size, status
        ) VALUES (
            :id, :dataset_id, :file_name, :file_path, :file_type, :file_size, :status
        )
        """
    )
    update_dataset_stat_sql = text(
        """
        UPDATE t_dm_datasets
        SET file_count = COALESCE(file_count, 0) + :add_count,
            size_bytes = COALESCE(size_bytes, 0) + :add_size
        WHERE id = :dataset_id
        """
    )

    with SQLManager.create_connect() as conn:
        added_count = _insert_dataset_file_rows(
            conn, insert_file_sql, output_dataset_id, image_files
        )
        added_count += _insert_dataset_file_rows(
            conn, insert_file_sql, output_dataset_id, annotation_files
        )

        if added_count > 0:
            conn.execute(
                update_dataset_stat_sql,
                {
                    "dataset_id": output_dataset_id,
                    "add_count": added_count,
                    "add_size": int(total_size),
                },
            )

    logger.info(
        "Registered auto-annotation output into dataset: dataset_id={}, name={}, added_files={}, added_size_bytes={}, task_id={}, output_dir={}",
        output_dataset_id,
        output_dataset_name,
        len(image_files) + len(annotation_files),
        total_size,
        task_id,
        output_dir,
    )
|
||||
|
||||
|
||||
def _process_single_task(task: Dict[str, Any]) -> None:
    """Run a single auto-annotation task end to end.

    Steps: verify the YOLO operator imported, resolve the file list (explicit
    ids or whole dataset), create a fresh output dataset, run detection per
    image while persisting progress, then register the generated files as a
    dataset.

    Args:
        task: Row dict from _fetch_pending_task; ``config`` and ``file_ids``
            are already JSON-decoded.
    """

    # The operator import chain at module load may have failed; in that case
    # the task is marked failed immediately instead of crashing the worker.
    if ImageObjectDetectionBoundingBox is None:
        logger.error(
            "YOLO operator not available (import failed earlier), skip auto-annotation task: {}",
            task["id"],
        )
        _update_task_status(
            task["id"],
            status="failed",
            error_message="YOLO operator not available in runtime container",
        )
        return

    task_id = str(task["id"])
    dataset_id = str(task["dataset_id"])
    task_name = str(task.get("name") or "")
    source_dataset_name = str(task.get("dataset_name") or "")
    cfg: Dict[str, Any] = task.get("config") or {}
    # None means "process every file in the dataset".
    selected_file_ids: Optional[List[str]] = task.get("file_ids") or None

    model_size = cfg.get("modelSize", "l")
    conf_threshold = float(cfg.get("confThreshold", 0.7))
    target_classes = cfg.get("targetClasses", []) or []
    output_dataset_name = cfg.get("outputDatasetName")

    # Fall back to a derived name when the caller did not supply one.
    if not output_dataset_name:
        base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
        output_dataset_name = f"{base_name}_auto_{task_id[:8]}"

    logger.info(
        "Start processing auto-annotation task: id={}, dataset_id={}, model_size={}, conf_threshold={}, target_classes={}, output_dataset_name={}",
        task_id,
        dataset_id,
        model_size,
        conf_threshold,
        target_classes,
        output_dataset_name,
    )

    _update_task_status(task_id, status="running", progress=0)

    if selected_file_ids:
        all_files = _load_files_by_ids(selected_file_ids)
    else:
        all_files = _load_dataset_files(dataset_id)

    # Drop the file id column; only (path, name) is needed below.
    files = [(path, name) for _, path, name in all_files]

    total_images = len(files)
    # Zero input files is treated as a successful no-op completion.
    if total_images == 0:
        logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
        _update_task_status(
            task_id,
            status="completed",
            progress=100,
            total_images=0,
            processed_images=0,
            detected_objects=0,
            completed=True,
            output_path=None,
        )
        return

    # NOTE(review): the output dataset row is created before detection starts;
    # a later failure leaves an empty dataset behind — confirm this is intended.
    output_dataset_id, output_dir = _create_output_dataset(
        source_dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        output_dataset_name=output_dataset_name,
    )
    output_dir = _ensure_output_dir(output_dir)

    try:
        detector = ImageObjectDetectionBoundingBox(
            modelSize=model_size,
            confThreshold=conf_threshold,
            targetClasses=target_classes,
            outputDir=output_dir,
        )
    except Exception as e:
        logger.error("Failed to init YOLO detector for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            status="failed",
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            error_message=f"Init YOLO detector failed: {e}",
        )
        return

    processed = 0
    detected_total = 0

    # Per-image failures are logged and skipped; counters only advance on
    # success, and progress is persisted after every image.
    for file_path, file_name in files:
        try:
            sample = {
                "image": file_path,
                "filename": file_name,
            }
            result = detector.execute(sample)

            # assumes the operator returns {"annotations": {"detections": [...]}}
            # — TODO confirm against the operator's contract
            annotations = (result or {}).get("annotations", {})
            detections = annotations.get("detections", [])
            detected_total += len(detections)
            processed += 1

            progress = int(processed * 100 / total_images) if total_images > 0 else 100

            _update_task_status(
                task_id,
                status="running",
                progress=progress,
                processed_images=processed,
                detected_objects=detected_total,
                total_images=total_images,
                output_path=output_dir,
            )
        except Exception as e:
            logger.error(
                "Failed to process image for task {}: file_path={}, error={}",
                task_id,
                file_path,
                e,
            )
            continue

    _update_task_status(
        task_id,
        status="completed",
        progress=100,
        processed_images=processed,
        detected_objects=detected_total,
        total_images=total_images,
        output_path=output_dir,
        completed=True,
    )

    logger.info(
        "Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
        task_id,
        total_images,
        processed,
        detected_total,
        output_dir,
    )

    # Registration failures are non-fatal: the task already completed above.
    if output_dataset_name and output_dataset_id:
        try:
            _register_output_dataset(
                task_id=task_id,
                output_dataset_id=output_dataset_id,
                output_dir=output_dir,
                output_dataset_name=output_dataset_name,
                total_images=total_images,
            )
        except Exception as e:  # pragma: no cover - defensive logging
            logger.error(
                "Failed to register auto-annotation output as dataset for task {}: {}",
                task_id,
                e,
            )
|
||||
|
||||
|
||||
def _worker_loop() -> None:
    """Worker main loop; runs forever on a dedicated background thread.

    Polls for a pending task, processes it, and sleeps POLL_INTERVAL_SECONDS
    when idle or after an unexpected error.
    """

    logger.info(
        "Auto-annotation worker started with poll interval {} seconds, output root {}",
        POLL_INTERVAL_SECONDS,
        DEFAULT_OUTPUT_ROOT,
    )

    while True:
        try:
            pending = _fetch_pending_task()
            if pending is None:
                time.sleep(POLL_INTERVAL_SECONDS)
                continue
            _process_single_task(pending)
        except Exception as e:  # pragma: no cover - defensive logging
            # Never let one bad iteration kill the worker thread.
            logger.error("Auto-annotation worker loop error: {}", e)
            time.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_auto_annotation_worker() -> None:
    """Launch the auto-annotation worker loop on a daemon thread.

    The daemon flag lets the process exit without joining the worker.
    """

    worker = threading.Thread(
        target=_worker_loop,
        name="auto-annotation-worker",
        daemon=True,
    )
    worker.start()
    logger.info("Auto-annotation worker thread started: {}", worker.name)
|
||||
@@ -13,6 +13,7 @@ from datamate.common.error_code import ErrorCode
|
||||
from datamate.scheduler import cmd_scheduler
|
||||
from datamate.scheduler import func_scheduler
|
||||
from datamate.wrappers import WRAPPERS
|
||||
from datamate.auto_annotation_worker import start_auto_annotation_worker
|
||||
|
||||
# 日志配置
|
||||
LOG_DIR = "/var/log/datamate/runtime"
|
||||
@@ -49,6 +50,16 @@ class APIException(Exception):
|
||||
return result
|
||||
|
||||
|
||||
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — confirm the pinned FastAPI version before migrating.
@app.on_event("startup")
async def startup_event():
    """Initialize the background auto-annotation worker when FastAPI starts."""

    # A worker startup failure must not prevent the API itself from serving.
    try:
        start_auto_annotation_worker()
    except Exception as e:  # pragma: no cover - defensive logging
        logger.error("Failed to start auto-annotation worker: {}", e)
|
||||
|
||||
|
||||
@app.exception_handler(APIException)
|
||||
async def api_exception_handler(request: Request, exc: APIException):
|
||||
return JSONResponse(
|
||||
|
||||
@@ -37,6 +37,30 @@ CREATE TABLE t_dm_labeling_projects (
|
||||
INDEX idx_labeling_project_id (labeling_project_id)
|
||||
) COMMENT='标注项目表';
|
||||
|
||||
-- Auto-annotation task table: one row per YOLO auto-labeling run. The
-- runtime worker polls rows with status='pending' (deleted_at IS NULL),
-- drives status through pending -> running -> completed/failed, and keeps
-- progress/counter columns updated while processing.
CREATE TABLE t_dm_auto_annotation_tasks (
    id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
    name VARCHAR(255) NOT NULL COMMENT '任务名称',
    dataset_id VARCHAR(36) NOT NULL COMMENT '数据集ID',
    dataset_name VARCHAR(255) COMMENT '数据集名称(冗余字段,方便查询)',
    config JSON NOT NULL COMMENT '任务配置(模型规模、置信度等)',
    file_ids JSON COMMENT '要处理的文件ID列表,为空则处理数据集所有图像',
    status VARCHAR(50) NOT NULL DEFAULT 'pending' COMMENT '任务状态: pending/running/completed/failed',
    progress INT DEFAULT 0 COMMENT '任务进度 0-100',
    total_images INT DEFAULT 0 COMMENT '总图片数',
    processed_images INT DEFAULT 0 COMMENT '已处理图片数',
    detected_objects INT DEFAULT 0 COMMENT '检测到的对象总数',
    output_path VARCHAR(500) COMMENT '输出路径',
    error_message TEXT COMMENT '错误信息',
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    completed_at TIMESTAMP NULL COMMENT '完成时间',
    deleted_at TIMESTAMP NULL COMMENT '删除时间(软删除)',
    -- Indexes match the worker's polling query (status + created_at) and
    -- per-dataset lookups.
    INDEX idx_dataset_id (dataset_id),
    INDEX idx_status (status),
    INDEX idx_created_at (created_at)
) COMMENT='自动标注任务表';
|
||||
|
||||
|
||||
-- 内置标注模板初始化数据
|
||||
-- 这些模板将在系统首次启动时自动创建
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
FROM maven:3-eclipse-temurin-21 AS builder
|
||||
|
||||
COPY backend/ /opt/backend
|
||||
|
||||
RUN cd /opt/backend/services && \
|
||||
|
||||
@@ -24,7 +24,6 @@ WORKDIR /opt/runtime
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -e .[all] --system \
|
||||
&& uv pip install -r /opt/runtime/datamate/ops/pyproject.toml --system \
|
||||
&& uv pip uninstall torch torchvision --system \
|
||||
&& python -m spacy download zh_core_web_sm \
|
||||
&& echo "/usr/local/lib/ops/site-packages" > /usr/local/lib/python3.11/site-packages/ops.pth
|
||||
|
||||
|
||||
Reference in New Issue
Block a user