You've already forked DataMate
- 移除项目参数依赖,简化 _find_knowledge_set_by_name 方法 - 删除不再使用的 _parse_metadata 和 _metadata_matches_project 方法 - 更新知识库集创建流程中的查找调用方式 - 统一所有知识库集查找操作的参数结构
366 lines
13 KiB
Python
366 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Optional
|
|
|
|
import httpx
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
from app.db.models import Dataset, DatasetFiles, LabelingProject
|
|
from app.module.annotation.service.text_fetcher import fetch_text_content_via_download_api
|
|
|
|
# Module-level logger used by KnowledgeSyncService for sync warnings/info.
logger = get_logger(__name__)
|
|
|
|
|
|
class KnowledgeSyncService:
    """Sync saved annotation results into knowledge management.

    After an annotation is saved, the annotated file (plus the annotation JSON)
    is written as an item of a knowledge set associated with the labeling
    project. The knowledge set id/name are cached in ``project.configuration``
    so later syncs reuse the same set.
    """

    # Keys under ``project.configuration`` caching the associated knowledge set.
    CONFIG_KEY_SET_ID = "knowledge_set_id"
    CONFIG_KEY_SET_NAME = "knowledge_set_name"
    # Page size for knowledge-set name searches (only the first page is read).
    KNOWLEDGE_SET_LIST_SIZE = 50
    # Item statuses that a sync must not overwrite.
    READ_ONLY_ITEM_STATUSES = frozenset({"PUBLISHED", "ARCHIVED", "DEPRECATED"})

    def __init__(self, db: AsyncSession):
        """Create the service.

        Args:
            db: Async SQLAlchemy session used to read dataset info and to
                persist the project's knowledge-set configuration.
        """
        self.db = db
        self.base_url = settings.datamate_backend_base_url.rstrip("/")

    async def sync_annotation_to_knowledge(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> None:
        """Sync one annotation result to knowledge management (best-effort).

        Failures are logged and never raised, so they cannot break the
        annotation save that triggered the sync.

        Args:
            project: Labeling project that owns the annotation.
            file_record: Dataset file that was annotated.
            annotation: Annotation result to store.
            annotation_updated_at: Optional timestamp recorded in item metadata.
        """
        if not project or not file_record:
            logger.warning("标注同步失败:缺少项目或文件信息")
            return

        try:
            await self._sync_annotation(project, file_record, annotation, annotation_updated_at)
        except Exception as exc:
            # Every step (knowledge-set lookup/creation, DB access, HTTP transport
            # errors, item create/update) may raise; keep the sync best-effort.
            logger.warning("标注同步到知识管理失败:%s", exc)

    async def _sync_annotation(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> None:
        """Resolve the knowledge set, then create or update the item for *file_record*.

        May raise; the public entry point catches and logs.
        """
        set_id = await self._ensure_knowledge_set(project)
        if not set_id:
            logger.warning("标注同步失败:无法获取知识集")
            return

        item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id))
        if item and item.get("status") in self.READ_ONLY_ITEM_STATUSES:
            logger.info(
                "知识条目为只读状态,跳过同步:item_id=%s status=%s",
                item.get("id"),
                item.get("status"),
            )
            return

        payload = await self._build_item_payload(
            project=project,
            file_record=file_record,
            annotation=annotation,
            annotation_updated_at=annotation_updated_at,
        )
        if not payload:
            logger.warning("标注同步失败:无法构建知识条目内容")
            return

        if item:
            await self._update_item(set_id, item["id"], payload)
        else:
            await self._create_item(set_id, payload)

    async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
        """Return the id of the project's knowledge set, creating one if needed.

        Resolution order:
        1. The set id cached in ``project.configuration``, if that set still exists.
        2. A set named after the project (exact-name lookup, else created).
        3. A fallback name suffixed with a short project id (lookup, else created).
        A set resolved in steps 2-3 is cached back into the project configuration.

        Returns:
            The knowledge set id, or None if no set could be found or created.
        """
        config = project.configuration if isinstance(project.configuration, dict) else {}
        set_id = config.get(self.CONFIG_KEY_SET_ID)
        if set_id:
            if await self._get_knowledge_set(set_id):
                return set_id
            logger.warning("知识集不存在,准备重建:set_id=%s", set_id)

        base_name = (project.name or "annotation-project").strip() or "annotation-project"
        metadata = self._build_set_metadata(project)

        knowledge_set = await self._find_or_create_knowledge_set(base_name, metadata)
        if not knowledge_set:
            # The project-named set could not be found or created; retry with a
            # name made distinct by the project id.
            fallback_name = self._build_fallback_set_name(base_name, project.id)
            knowledge_set = await self._find_or_create_knowledge_set(fallback_name, metadata)
        if not knowledge_set:
            return None

        await self._update_project_config(
            project,
            {
                self.CONFIG_KEY_SET_ID: knowledge_set.get("id"),
                self.CONFIG_KEY_SET_NAME: knowledge_set.get("name"),
            },
        )
        return knowledge_set.get("id")

    async def _find_or_create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]:
        """Find a knowledge set by exact *name*, creating it if absent.

        If creation fails, the lookup is retried once (the create may have
        failed because a set with that name already exists — presumably from a
        concurrent sync; not verified here).
        """
        existing = await self._find_knowledge_set_by_name(name)
        if existing:
            return existing
        created = await self._create_knowledge_set(name, metadata)
        if created:
            return created
        return await self._find_knowledge_set_by_name(name)

    async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a knowledge set by id; return None on HTTP 404.

        Raises:
            httpx.HTTPStatusError: For non-404 error statuses.
        """
        try:
            return await self._request("GET", f"/data-management/knowledge-sets/{set_id}")
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                return None
            raise

    async def _list_knowledge_sets(self, keyword: Optional[str]) -> list[Dict[str, Any]]:
        """Return the first page of knowledge sets matching *keyword*.

        HTTP error statuses and malformed responses yield an empty list; only
        dict entries of the response's ``content`` list are returned.
        """
        params: Dict[str, Any] = {
            "page": 0,
            "size": self.KNOWLEDGE_SET_LIST_SIZE,
        }
        if keyword:
            params["keyword"] = keyword
        try:
            data = await self._request("GET", "/data-management/knowledge-sets", params=params)
        except httpx.HTTPStatusError as exc:
            logger.warning(
                "查询知识集失败:keyword=%s status=%s",
                keyword,
                exc.response.status_code,
            )
            return []
        if not isinstance(data, dict):
            return []
        content = data.get("content")
        if not isinstance(content, list):
            return []
        return [item for item in content if isinstance(item, dict)]

    async def _find_knowledge_set_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the first listed knowledge set whose name equals *name* exactly."""
        if not name:
            return None
        items = await self._list_knowledge_sets(name)
        # The keyword search may be fuzzy; keep only exact name matches.
        return next((item for item in items if item.get("name") == name), None)

    async def _create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]:
        """Create a DRAFT knowledge set; return the response body or None on HTTP error."""
        payload = {
            "name": name,
            "description": "标注项目自动创建的知识集",
            "status": "DRAFT",
            "metadata": metadata,
        }
        try:
            return await self._request("POST", "/data-management/knowledge-sets", json=payload)
        except httpx.HTTPStatusError as exc:
            logger.warning(
                "创建知识集失败:name=%s status=%s detail=%s",
                name,
                exc.response.status_code,
                self._safe_response_text(exc.response),
            )
            return None

    async def _get_item_by_source(
        self,
        set_id: str,
        dataset_id: str,
        file_id: str,
    ) -> Optional[Dict[str, Any]]:
        """Return the item in *set_id* sourced from (*dataset_id*, *file_id*), if any.

        HTTP error statuses and malformed responses yield None.
        """
        params = {
            "page": 0,
            "size": 1,
            "sourceDatasetId": dataset_id,
            "sourceFileId": file_id,
        }
        try:
            data = await self._request("GET", f"/data-management/knowledge-sets/{set_id}/items", params=params)
        except httpx.HTTPStatusError as exc:
            logger.warning(
                "查询知识条目失败:set_id=%s status=%s",
                set_id,
                exc.response.status_code,
            )
            return None

        if not isinstance(data, dict):
            return None
        content = data.get("content")
        if not isinstance(content, list) or not content:
            return None
        first = content[0]
        # Callers use dict access on the item; ignore malformed entries.
        return first if isinstance(first, dict) else None

    async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None:
        """Create a knowledge item in *set_id*."""
        await self._request("POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload)

    async def _update_item(self, set_id: str, item_id: str, payload: Dict[str, Any]) -> None:
        """Update an existing knowledge item, omitting the source identifiers."""
        update_payload = dict(payload)
        # Source ids identify the item and are only sent on creation.
        update_payload.pop("sourceDatasetId", None)
        update_payload.pop("sourceFileId", None)
        await self._request(
            "PUT",
            f"/data-management/knowledge-sets/{set_id}/items/{item_id}",
            json=update_payload,
        )

    async def _build_item_payload(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> Optional[Dict[str, Any]]:
        """Build the create/update payload for a knowledge item.

        For TEXT datasets the content is the file text (fetched via the
        download API) followed by the annotation JSON; if fetching fails, or
        for other dataset types, the content is the annotation JSON alone.
        Markdown files in TEXT datasets get content type MARKDOWN.
        Currently always returns a dict.
        """
        dataset_type = await self._get_dataset_type(project.dataset_id)
        annotation_json = self._safe_json_dumps(annotation)
        metadata = self._build_item_metadata(
            project=project,
            file_record=file_record,
            annotation=annotation,
            annotation_updated_at=annotation_updated_at,
        )

        title = self._strip_extension(getattr(file_record, "file_name", "")) or "未命名"

        content_type = "TEXT"
        if dataset_type == "TEXT" and self._is_markdown_file(file_record):
            content_type = "MARKDOWN"

        content = annotation_json
        if dataset_type == "TEXT":
            try:
                content = await fetch_text_content_via_download_api(
                    project.dataset_id,
                    str(file_record.id),
                )
                content = self._append_annotation_to_content(content, annotation_json, content_type)
            except Exception as exc:
                logger.warning("读取文本失败,改为仅存标注JSON:%s", exc)
                content = annotation_json

        payload: Dict[str, Any] = {
            "title": title,
            "content": content,
            "contentType": content_type,
            "metadata": metadata,
            "sourceDatasetId": project.dataset_id,
            "sourceFileId": str(file_record.id),
        }
        return payload

    async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
        """Return the dataset's type upper-cased, or None if missing/empty."""
        result = await self.db.execute(
            select(Dataset.dataset_type).where(Dataset.id == dataset_id)
        )
        dataset_type = result.scalar_one_or_none()
        return str(dataset_type).upper() if dataset_type else None

    def _is_markdown_file(self, file_record: DatasetFiles) -> bool:
        """Return True if the file's extension is ``md`` or ``markdown`` (case-insensitive).

        The extension comes from the file name when it contains a dot,
        otherwise from ``file_type`` (with an optional leading dot).
        """
        file_name = getattr(file_record, "file_name", "") or ""
        file_type = getattr(file_record, "file_type", "") or ""
        if "." in file_name:
            extension = file_name.rsplit(".", 1)[-1]
        elif file_type.startswith("."):
            extension = file_type[1:]
        else:
            extension = file_type
        return extension.lower() in {"md", "markdown"}

    def _append_annotation_to_content(self, content: str, annotation_json: str, content_type: str) -> str:
        """Append the annotation JSON after the file content, formatted per content type."""
        if content_type == "MARKDOWN":
            return (
                f"{content}\n\n---\n\n## 标注结果\n\n```json\n"
                f"{annotation_json}\n```"
            )
        return f"{content}\n\n---\n\n标注结果(JSON):\n{annotation_json}"

    def _strip_extension(self, file_name: str) -> str:
        """Return *file_name* without its last ``.ext`` suffix ("" for empty input)."""
        if not file_name:
            return ""
        if "." not in file_name:
            return file_name
        return file_name.rsplit(".", 1)[0]

    def _build_set_metadata(self, project: LabelingProject) -> str:
        """Serialize knowledge-set metadata identifying the source project/dataset."""
        payload = {
            "source": "annotation",
            "project_id": project.id,
            "dataset_id": project.dataset_id,
        }
        return self._safe_json_dumps(payload)

    def _build_item_metadata(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> str:
        """Serialize item metadata: source identifiers, the annotation, and its update time."""
        payload: Dict[str, Any] = {
            "source": {
                "type": "annotation",
                "project_id": project.id,
                "dataset_id": project.dataset_id,
                "file_id": str(file_record.id),
                "file_name": getattr(file_record, "file_name", ""),
            },
            "annotation": annotation,
        }
        if annotation_updated_at:
            payload["annotation_updated_at"] = annotation_updated_at.isoformat()
        return self._safe_json_dumps(payload)

    def _build_fallback_set_name(self, base_name: str, project_id: str) -> str:
        """Return ``<base_name>-annotation-<first 8 id chars without dashes>``.

        *project_id* is stringified first so non-str ids (e.g. UUID/int) work.
        """
        short_id = str(project_id).replace("-", "")[:8]
        return f"{base_name}-annotation-{short_id}"

    async def _update_project_config(self, project: LabelingProject, updates: Dict[str, Any]) -> None:
        """Merge *updates* into ``project.configuration``, commit, and refresh *project*."""
        current = project.configuration if isinstance(project.configuration, dict) else {}
        # Assign a new dict instead of mutating in place: SQLAlchemy does not
        # track in-place mutation of JSON values unless the column uses a
        # mutable extension (e.g. MutableDict), so re-assigning the same
        # mutated object may not be persisted. A fresh object is detected
        # either way.
        project.configuration = {**current, **updates}
        await self.db.commit()
        await self.db.refresh(project)

    async def _request(self, method: str, path: str, **kwargs) -> Any:
        """Send an HTTP request to the backend and return the decoded JSON body.

        Returns None for an empty body.

        Raises:
            httpx.HTTPStatusError: For error response statuses.
            httpx.RequestError: For transport-level failures.
        """
        url = f"{self.base_url}{path}"
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            response = await client.request(method, url, **kwargs)
            response.raise_for_status()
            if response.content:
                return response.json()
            return None

    def _safe_json_dumps(self, payload: Any) -> str:
        """Serialize *payload* as pretty-printed JSON, falling back to an error stub."""
        try:
            return json.dumps(payload, ensure_ascii=False, indent=2)
        except Exception as exc:
            # Keep the sync going, but leave a trace instead of silently losing data.
            logger.warning("JSON序列化失败:%s", exc)
            return json.dumps({"error": "failed to serialize"}, ensure_ascii=False)

    def _safe_response_text(self, response: httpx.Response) -> str:
        """Return the response body text, or "" if it cannot be decoded."""
        try:
            return response.text
        except Exception:
            return ""
|