feature: deer-flow supports fetching externally integrated models from datamate (#83)

* feature: deer-flow supports fetching externally integrated models from datamate

Author: hhhhsc701
Date: 2025-11-13 20:13:16 +08:00
Committed by: GitHub
Parent: 604fd019d5
Commit: 5cef9cb273

19 changed files with 177 additions and 383 deletions

diff --git a/src/rag/milvus.py b/src/rag/milvus.py
index de589d4..c1b9b98 100644
--- a/src/rag/milvus.py
+++ b/src/rag/milvus.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
 from langchain_milvus.vectorstores import Milvus as LangchainMilvus
 from langchain_openai import OpenAIEmbeddings
 from openai import OpenAI
-from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
+from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient, utility
 from src.config.loader import get_bool_env, get_int_env, get_str_env
 from src.rag.retriever import Chunk, Document, Resource, Retriever
@@ -397,6 +397,36 @@ class MilvusRetriever(Retriever):
         except Exception as e:
             raise ConnectionError(f"Failed to connect to Milvus: {str(e)}")
+    def _connect_with_collection(self, collection_name) -> None:
+        """Create the underlying Milvus client (idempotent)."""
+        try:
+            # Check if using Milvus Lite (file-based) vs server-based Milvus
+            if self._is_milvus_lite():
+                # Use MilvusClient for Milvus Lite (local file database)
+                self.client = MilvusClient(self.uri)
+                # Ensure collection exists
+                self._ensure_collection_exists()
+            else:
+                connection_args = {
+                    "uri": self.uri,
+                }
+                # Add user/password only if provided
+                if self.user:
+                    connection_args["user"] = self.user
+                if self.password:
+                    connection_args["password"] = self.password
+
+                # Create LangChain client (it will handle collection creation automatically)
+                self.client = LangchainMilvus(
+                    embedding_function=self.embedding_model,
+                    collection_name=collection_name,
+                    connection_args=connection_args,
+                    # optional (if collection already exists with different schema, be careful)
+                    drop_old=False,
+                )
+        except Exception as e:
+            raise ConnectionError(f"Failed to connect to Milvus: {str(e)}")
+
     def _is_milvus_lite(self) -> bool:
         """Return True if the URI points to a local Milvus Lite file.
         Milvus Lite uses local file paths (often ``*.db``) without an HTTP/HTTPS
@@ -476,26 +506,12 @@ class MilvusRetriever(Retriever):
         else:
             # Use similarity_search_by_vector for lightweight listing.
             # If a query is provided embed it; else use a zero vector.
-            docs: Iterable[Any] = self.client.similarity_search(
-                query,
-                k=100,
-                expr="source == 'examples'",  # Limit to 100 results
-            )
-            for d in docs:
-                meta = getattr(d, "metadata", {}) or {}
-                # check if the resource is in the list of resources
-                if resources and any(
-                    r.uri == meta.get(self.url_field, "")
-                    or r.uri == f"milvus://{meta.get(self.id_field, '')}"
-                    for r in resources
-                ):
-                    continue
+            connections = utility.list_collections(using=f"{self.uri}-{self.user}")
+            for connection in connections:
                 resources.append(
                     Resource(
-                        uri=meta.get(self.url_field, "")
-                        or f"milvus://{meta.get(self.id_field, '')}",
-                        title=meta.get(self.title_field, "")
-                        or meta.get(self.id_field, "Unnamed"),
+                        uri=f"milvus://{connection}",
+                        title=connection,
                         description="Stored Milvus document",
                     )
                 )
@@ -621,38 +637,32 @@ class MilvusRetriever(Retriever):
         else:
             # For LangChain Milvus, use similarity search
-            search_results = self.client.similarity_search_with_score(
-                query=query, k=self.top_k
-            )
+            if not resources:
+                return []
             documents = {}
+            for resource in resources:
+                self._connect_with_collection(resource.title)
+                search_results = self.client.similarity_search_with_score(
+                    query=query, k=self.top_k
+                )
-            for doc, score in search_results:
-                metadata = doc.metadata or {}
-                doc_id = metadata.get(self.id_field, "")
-                title = metadata.get(self.title_field, "")
-                url = metadata.get(self.url_field, "")
-                content = doc.page_content
-
-                # Skip if resource filtering is requested and this doc is not in the list
-                if resources:
-                    doc_in_resources = False
-                    for resource in resources:
-                        if (url and url in resource.uri) or doc_id in resource.uri:
-                            doc_in_resources = True
-                            break
-                    if not doc_in_resources:
-                        continue
-
-                # Create or update document
-                if doc_id not in documents:
-                    documents[doc_id] = Document(
-                        id=doc_id, url=url, title=title, chunks=[]
-                    )
+                for doc, score in search_results:
+                    metadata = doc.metadata or {}
+                    doc_id = metadata.get(self.id_field, "")
+                    title = metadata.get(self.title_field, "")
+                    url = metadata.get(self.url_field, "")
+                    content = doc.page_content
+
+                    # Create or update document
+                    if doc_id not in documents:
+                        documents[doc_id] = Document(
+                            id=doc_id, url=url, title=title, chunks=[]
+                        )
-                # Add chunk to document
-                chunk = Chunk(content=content, similarity=score)
-                documents[doc_id].chunks.append(chunk)
+                    # Add chunk to document
+                    chunk = Chunk(content=content, similarity=score)
+                    documents[doc_id].chunks.append(chunk)
             return list(documents.values())
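
For orientation, the listing path above now enumerates Milvus collections and exposes each one as a resource, and query_relevant_documents then reconnects per collection via _connect_with_collection(resource.title). Below is a minimal, standalone sketch of that listing step. It assumes a server-based Milvus (not Milvus Lite) and that a pymilvus connection is registered under the same f"{uri}-{user}" alias later passed to utility.list_collections(using=...); that registration is not shown in this diff, and the Resource class below is only a simplified stand-in for src.rag.retriever.Resource.

# Sketch only (not part of the commit): list Milvus collections as RAG resources.
from dataclasses import dataclass

from pymilvus import connections, utility


@dataclass
class Resource:
    """Simplified stand-in for src.rag.retriever.Resource."""
    uri: str
    title: str
    description: str = ""


def list_collection_resources(uri: str, user: str = "", password: str = "") -> list[Resource]:
    # Register (or reuse) a connection under the alias convention used above.
    alias = f"{uri}-{user}"
    connections.connect(alias=alias, uri=uri, user=user, password=password)
    return [
        Resource(uri=f"milvus://{name}", title=name, description="Stored Milvus document")
        for name in utility.list_collections(using=alias)
    ]
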
diff --git a/web/src/components/deer-flow/theme-provider-wrapper.tsx b/web/src/components/deer-flow/theme-provider-wrapper.tsx
index 6da0db8..1a99bcf 100644
--- a/web/src/components/deer-flow/theme-provider-wrapper.tsx
+++ b/web/src/components/deer-flow/theme-provider-wrapper.tsx
@@ -18,9 +18,9 @@ export function ThemeProviderWrapper({
   return (
     <ThemeProvider
       attribute="class"
-      defaultTheme={"dark"}
+      defaultTheme={"light"}
       enableSystem={isChatPage}
-      forcedTheme={isChatPage ? undefined : "dark"}
+      forcedTheme={isChatPage ? undefined : "light"}
       disableTransitionOnChange
     >
       {children}
diff --git a/web/src/core/api/resolve-service-url.ts b/web/src/core/api/resolve-service-url.ts
index a87b777..d93e987 100644
--- a/web/src/core/api/resolve-service-url.ts
+++ b/web/src/core/api/resolve-service-url.ts
@@ -4,9 +4,13 @@
 import { env } from "~/env";
 export function resolveServiceURL(path: string) {
-  let BASE_URL = env.NEXT_PUBLIC_API_URL ?? "http://localhost:8000/api/";
+  let BASE_URL = env.NEXT_PUBLIC_API_URL ?? "/api/";
   if (!BASE_URL.endsWith("/")) {
     BASE_URL += "/";
   }
+
+  const origin = window.location.origin;
+  BASE_URL = origin + BASE_URL;
+
   return new URL(path, BASE_URL).toString();
 }


@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-

from datamate.core.base_op import OPERATORS

OPERATORS.register_module(module_name='TestMapper',
                          module_path="ops.user.test_operator.process")


@@ -0,0 +1,85 @@
name: 'Test Operator'
description: 'This is a test operator.'
language: 'python'
vendor: 'huawei'
raw_id: 'TestMapper'
version: '1.0.0'
modal: 'text'
effect:
  before: 'Usage is simple: just put the code into the Markdown text; in rich-text format, emojis 😀 can be copied and used directly.'
  after: 'Usage is simple: just put the code into the Markdown text; in rich-text format, emojis can be copied and used directly.'
inputs: 'text'
outputs: 'text'
settings:
  sliderTest:
    name: 'Slider test'
    description: 'This is a test slider.'
    type: 'slider'
    defaultVal: 0.5
    min: 0
    max: 1
    step: 0.1
  switchTest:
    name: 'Switch test'
    description: 'This is a switch test.'
    type: 'switch'
    defaultVal: 'true'
    required: false
    checkedLabel: 'Checked'
    unCheckedLabel: 'Unchecked'
  radioTest:
    name: 'Radio test'
    description: 'This is a radio-button test.'
    type: 'radio'
    defaultVal: 'option1'
    required: false
    options:
      - label: 'Option 1'
        value: 'option1'
      - label: 'Option 2'
        value: 'option2'
  selectTest:
    name: 'Dropdown test'
    description: 'This is a dropdown test.'
    type: 'select'
    defaultVal: 'option1'
    required: false
    options:
      - label: 'Option 1'
        value: 'option1'
      - label: 'Option 2'
        value: 'option2'
  rangeTest:
    name: 'Range test'
    description: 'This is a range test.'
    type: 'range'
    properties:
      - name: 'rangeLeft'
        type: 'inputNumber'
        defaultVal: 100
        min: 0
        max: 10000
        step: 1
      - name: 'rangeRight'
        type: 'inputNumber'
        defaultVal: 8000
        min: 0
        max: 10000
        step: 1
  checkboxTest:
    name: 'Checkbox test'
    description: 'This is a checkbox test.'
    type: 'checkbox'
    defaultVal: 'option1,option2'
    required: false
    options:
      - label: 'Option 1'
        value: 'option1'
      - label: 'Option 2'
        value: 'option2'
  inputTest:
    name: 'Input box test'
    description: 'This is an input box test.'
    type: 'input'
    defaultVal: 'Test Input'
    required: false
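
The settings block above only declares UI controls and their defaults; the loader that turns them into runtime values is not part of this commit. As a rough illustration, the following sketch collects each control's defaultVal into a flat dict with PyYAML, assuming the file is saved locally as config.yaml.

# Sketch only: flatten the operator settings above into {setting_name: defaultVal}.
import yaml  # PyYAML

with open("config.yaml", encoding="utf-8") as f:
    spec = yaml.safe_load(f)

defaults = {
    key: control.get("defaultVal")
    for key, control in (spec.get("settings") or {}).items()
}
print(defaults)
# e.g. {'sliderTest': 0.5, 'switchTest': 'true', 'radioTest': 'option1', ...}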


@@ -0,0 +1,10 @@
from typing import Dict, Any

from datamate.core.base_op import Mapper


class TestMapper(Mapper):
    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[self.text_key] += "\n####################\n"
        return sample
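
datamate's Mapper base class is not included in this commit, so the following is a dependency-free sketch of what execute() does to a sample; the _MapperStub stand-in and the "text" default for text_key are assumptions, not datamate's actual API.

# Sketch only: the effect of TestMapper.execute on a sample, using a fake base class.
from typing import Any, Dict


class _MapperStub:
    text_key = "text"  # assumed default; the real value comes from datamate's Mapper


class TestMapperSketch(_MapperStub):
    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[self.text_key] += "\n####################\n"
        return sample


print(TestMapperSketch().execute({"text": "hello"}))
# -> {'text': 'hello\n####################\n'}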


@@ -1,49 +0,0 @@
{
  "name": "text_length_filter",
  "displayName": "Text Length Filter",
  "version": "1.0.0",
  "author": "DataMate Team",
  "description": "Filters data by text length, with support for minimum and maximum length limits",
  "category": "Data cleaning",
  "type": "CUSTOM",
  "inputs": [
    {
      "name": "input_data",
      "type": "array",
      "description": "Input text array",
      "required": true
    }
  ],
  "outputs": [
    {
      "name": "filtered_data",
      "type": "array",
      "description": "Filtered text array"
    }
  ],
  "parameters": [
    {
      "name": "min_length",
      "type": "integer",
      "description": "Minimum text length",
      "default": 10,
      "min": 0
    },
    {
      "name": "max_length",
      "type": "integer",
      "description": "Maximum text length",
      "default": 1000,
      "min": 1
    },
    {
      "name": "text_field",
      "type": "string",
      "description": "Name of the text field (if the input is an array of objects)",
      "default": "text"
    }
  ],
  "tags": ["Text processing", "Data filtering", "Length check"],
  "documentation": "https://docs.datamate.com/operators/text-length-filter",
  "repository": "https://github.com/datamate/operators/tree/main/text-length-filter"
}


@@ -1,135 +0,0 @@
"""
文本长度过滤器算子
根据设定的最小和最大长度过滤文本数据
"""
import json
import logging
from typing import Dict, Any, List, Union
logger = logging.getLogger(__name__)
class TextLengthFilter:
"""文本长度过滤器算子"""
def __init__(self):
self.name = "text_length_filter"
self.version = "1.0.0"
def execute(self, config: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""执行文本长度过滤"""
logger.info(f"开始执行算子: {self.name}")
# 获取参数
parameters = config.get('parameters', {})
min_length = parameters.get('min_length', 10)
max_length = parameters.get('max_length', 1000)
text_field = parameters.get('text_field', 'text')
logger.info(f"过滤参数: min_length={min_length}, max_length={max_length}, text_field={text_field}")
# 验证参数
if min_length < 0:
raise ValueError("min_length must be >= 0")
if max_length < min_length:
raise ValueError("max_length must be >= min_length")
# 读取输入数据
input_path = context['input_path']
with open(input_path, 'r', encoding='utf-8') as f:
input_data = json.load(f)
if not isinstance(input_data, list):
raise ValueError("输入数据必须是数组格式")
logger.info(f"输入数据条数: {len(input_data)}")
# 执行过滤
filtered_data = []
stats = {
'total_input': len(input_data),
'too_short': 0,
'too_long': 0,
'filtered_out': 0,
'kept': 0
}
for i, item in enumerate(input_data):
try:
# 提取文本内容
if isinstance(item, str):
text = item
elif isinstance(item, dict) and text_field in item:
text = str(item[text_field])
else:
logger.warning(f"跳过无法处理的数据项 {i}: {type(item)}")
stats['filtered_out'] += 1
continue
# 检查长度
text_length = len(text)
if text_length < min_length:
stats['too_short'] += 1
stats['filtered_out'] += 1
elif text_length > max_length:
stats['too_long'] += 1
stats['filtered_out'] += 1
else:
filtered_data.append(item)
stats['kept'] += 1
# 进度报告
if (i + 1) % 1000 == 0:
progress = (i + 1) / len(input_data) * 100
logger.info(f"处理进度: {progress:.1f}% ({i + 1}/{len(input_data)})")
except Exception as e:
logger.warning(f"处理数据项 {i} 时出错: {e}")
stats['filtered_out'] += 1
continue
# 保存结果
output_path = context['output_path']
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(filtered_data, f, ensure_ascii=False, indent=2)
# 准备返回结果
result = {
'status': 'success',
'statistics': stats,
'filter_rate': stats['filtered_out'] / stats['total_input'] * 100 if stats['total_input'] > 0 else 0,
'output_path': output_path
}
logger.info(f"过滤完成: {stats}")
logger.info(f"过滤率: {result['filter_rate']:.2f}%")
return result
def validate_config(self, config: Dict[str, Any]) -> List[str]:
"""验证配置参数"""
errors = []
parameters = config.get('parameters', {})
min_length = parameters.get('min_length')
max_length = parameters.get('max_length')
if min_length is not None and not isinstance(min_length, int):
errors.append("min_length must be an integer")
if max_length is not None and not isinstance(max_length, int):
errors.append("max_length must be an integer")
if min_length is not None and min_length < 0:
errors.append("min_length must be >= 0")
if min_length is not None and max_length is not None and max_length < min_length:
errors.append("max_length must be >= min_length")
return errors
def create_operator():
"""算子工厂函数"""
return TextLengthFilter()
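
For completeness, a short usage sketch continuing from the definitions above, inferred only from the execute() and validate_config() signatures; the runner that normally supplies config and context is not part of this commit, and the file paths are placeholders.

# Sketch only: drive the (removed) operator end to end with a tiny input file.
import json

op = create_operator()
config = {"parameters": {"min_length": 10, "max_length": 1000, "text_field": "text"}}

errors = op.validate_config(config)
if errors:
    raise ValueError(f"invalid config: {errors}")

with open("input.json", "w", encoding="utf-8") as f:
    json.dump([{"text": "short"}, {"text": "x" * 50}], f, ensure_ascii=False)

result = op.execute(config, {"input_path": "input.json", "output_path": "output.json"})
print(result["statistics"])
# {'total_input': 2, 'too_short': 1, 'too_long': 0, 'filtered_out': 1, 'kept': 1}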