From 651c9b4d0ac0f6615697b71623efb8a552d84dd4 Mon Sep 17 00:00:00 2001 From: xs Date: Fri, 16 Jan 2026 15:08:39 +0800 Subject: [PATCH] =?UTF-8?q?1.9.0=E5=90=8E=E7=AB=AF=20fix(AI):=E5=AE=8C?= =?UTF-8?q?=E6=88=90AI=E5=90=91=E9=87=8F=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/controller/AiAssistantController.java | 2 - .../ai/process/config/EmbeddingConfig.java | 111 ++ .../process/enums/EmbeddingProviderEnum.java | 37 + .../impl/BaseAIProviderProcessor.java | 3 - .../processor/impl/TencentLkeProcessor.java | 3 +- .../service/impl/DatabaseMetaServiceImpl.java | 290 ++--- .../ai/vectordb/config/VectorDBConfig.java | 1 - .../service/impl/MilvusServiceImpl.java | 1042 +++++++++-------- 8 files changed, 864 insertions(+), 625 deletions(-) create mode 100644 ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/config/EmbeddingConfig.java create mode 100644 ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/enums/EmbeddingProviderEnum.java diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/controller/AiAssistantController.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/controller/AiAssistantController.java index e77b9f5d..4ba98aeb 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/controller/AiAssistantController.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/controller/AiAssistantController.java @@ -16,8 +16,6 @@ import org.dromara.ai.domain.bo.AiModelBo; import org.dromara.ai.domain.vo.AiChatMessageVo; import org.dromara.ai.domain.vo.AiModelVo; import org.dromara.ai.process.dto.AIFillFormRequest; -import org.dromara.ai.test.vectorization.factory.EmbeddingServiceFactory; -import org.dromara.ai.test.vectorization.process.IEmbeddingProcessor; import org.dromara.ai.process.dto.AIMessage; import org.dromara.ai.process.dto.AIRequest; import org.dromara.ai.process.dto.AIResponse; diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/config/EmbeddingConfig.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/config/EmbeddingConfig.java new file mode 100644 index 00000000..0e51ef50 --- /dev/null +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/config/EmbeddingConfig.java @@ -0,0 +1,111 @@ +package org.dromara.ai.process.config; + +import lombok.Data; +import org.dromara.ai.process.enums.EmbeddingProviderEnum; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Configuration; + +/** + * @author xins + * @description 向量化配置类 + * @date 2025/8/5 16:30 + */ + +@Data +@Configuration +@ConfigurationProperties(prefix = "embedding") +public class EmbeddingConfig { + /** + * 默认向量化服务提供商 + */ + private EmbeddingProviderEnum defaultProvider = EmbeddingProviderEnum.TENCENTLKE; + + /** + * 腾讯云向量化服务配置 + */ + private TencentLKE tencentLke = new TencentLKE(); + + /** + * OpenAI向量化服务配置 + */ + private Openai openai = new Openai(); + + /** + * 本地模型向量化服务配置 + */ + private Local local = new Local(); + + /** + * 腾讯云配置内部类 + */ + @Data + public static class TencentLKE { + /** + * 腾讯云API密钥ID + */ + private String secretId; + + /** + * 腾讯云API密钥Key + */ + private String secretKey; + + /** + * 服务区域,默认ap-guangzhou + */ + private String region = "ap-guangzhou"; + + /** + * 使用的模型名称,默认lke-text-embedding-v1 + */ + private String model = "lke-text-embedding-v1"; + } + + /** + * OpenAI配置内部类 + */ + @Data + public static class Openai { + /** + * OpenAI API密钥 + */ + private String apiKey; + + /** + * 使用的模型名称,默认text-embedding-ada-002 + */ + private String model = "text-embedding-ada-002"; + + /** + * 组织ID(可选) + */ + private String organization; + + /** + * 代理主机(可选) + */ + private String proxyHost; + + /** + * 代理端口(可选) + */ + private Integer proxyPort; + } + + /** + * 本地模型配置内部类 + */ + @Data + public static class Local { + private String baseUrl = "http://localhost:11434"; + /** + * 本地模型,默认llama2 + */ + private String model = "llama2"; + + /** + * 超时时间,默认30 + */ + private Integer timeout = 30; + } +} diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/enums/EmbeddingProviderEnum.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/enums/EmbeddingProviderEnum.java new file mode 100644 index 00000000..2f3c512e --- /dev/null +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/enums/EmbeddingProviderEnum.java @@ -0,0 +1,37 @@ +package org.dromara.ai.process.enums; + +/** + * @Author xins + * @Date 2025/8/5 14:28 + * @Description:向量模型类型 + */ +public enum EmbeddingProviderEnum { + TENCENTLKE("tencentlke", "Tencent LKE"), + OPENAI("openai", "OpenAI"), + OLLAMA("ollma", "Local Model"); + + private final String code; + private final String name; + + EmbeddingProviderEnum(String code, String name) { + this.code = code; + this.name = name; + } + + public String getCode() { + return code; + } + + public String getName() { + return name; + } + + public static EmbeddingProviderEnum fromString(String code) { + for (EmbeddingProviderEnum provider : values()) { + if (provider.code.equalsIgnoreCase(code)) { + return provider; + } + } + throw new IllegalArgumentException("Unknown provider code: " + code); + } +} diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/BaseAIProviderProcessor.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/BaseAIProviderProcessor.java index cd348f8d..95083069 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/BaseAIProviderProcessor.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/BaseAIProviderProcessor.java @@ -19,12 +19,9 @@ import org.dromara.ai.process.dto.AIMessage; import org.dromara.ai.process.dto.AIRequest; import org.dromara.ai.process.dto.AIResponse; import org.dromara.ai.process.dto.TokenUsage; -import org.dromara.ai.process.enums.AIChatMessageTypeEnum; import org.dromara.ai.process.provider.processor.IUnifiedAIProviderProcessor; -import org.dromara.ai.test.ChatRequest; import org.dromara.common.constant.HwMomAiConstants; import org.dromara.common.encrypt.utils.EncryptUtils; -import org.dromara.common.satoken.utils.LoginHelper; import org.dromara.system.api.model.LoginUser; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpHeaders; diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/TencentLkeProcessor.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/TencentLkeProcessor.java index 31657acc..4c1ccabc 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/TencentLkeProcessor.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/process/provider/processor/impl/TencentLkeProcessor.java @@ -11,12 +11,11 @@ import com.tencentcloudapi.lkeap.v20240522.models.GetEmbeddingResponse; import com.tencentcloudapi.lkeap.v20240522.models.Usage; import lombok.extern.slf4j.Slf4j; import org.dromara.ai.domain.dto.StreamResult; +import org.dromara.ai.process.config.EmbeddingConfig; import org.dromara.ai.process.provider.processor.utils.ProcessorUtils; -import org.dromara.ai.test.vectorization.config.EmbeddingConfig; import org.dromara.ai.process.dto.AIRequest; import org.dromara.ai.process.dto.AIResponse; import org.dromara.ai.process.enums.AIProviderEnum; -import org.dromara.ai.process.provider.processor.IUnifiedAIProviderProcessor; import org.dromara.common.encrypt.utils.EncryptUtils; import org.dromara.system.api.model.LoginUser; import org.springframework.stereotype.Component; diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/service/impl/DatabaseMetaServiceImpl.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/service/impl/DatabaseMetaServiceImpl.java index 3cfc469a..ae03a633 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/service/impl/DatabaseMetaServiceImpl.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/service/impl/DatabaseMetaServiceImpl.java @@ -1,145 +1,145 @@ -package org.dromara.ai.service.impl; - -import org.dromara.ai.mapper.SQLServerDatabaseMetaMapper; -import org.dromara.ai.test.ChatRequest; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * @Author xins - * @Date 2025/7/8 9:25 - * @Description: - */ -@Service -public class DatabaseMetaServiceImpl { - - @Autowired - private SQLServerDatabaseMetaMapper databaseMetaMapper; - - /** - * 获取格式化的数据库结构描述 - */ - public String getFormattedSchema() { - List> allColumns = databaseMetaMapper.getAllTablesStructure(); - - Map>> tables = allColumns.stream() - .collect(Collectors.groupingBy( - col -> (String) col.get("tableName"), - LinkedHashMap::new, - Collectors.toList() - )); - - StringBuilder sb = new StringBuilder("SQL Server 数据库结构:\n\n"); - - for (Map.Entry>> entry : tables.entrySet()) { - String tableName = entry.getKey(); - List primaryKeys = databaseMetaMapper.getPrimaryKeys(tableName); - - sb.append("- 表名: ").append(tableName).append("\n"); - sb.append(" 主键: ").append(String.join(", ", primaryKeys)).append("\n"); - sb.append(" 字段:\n"); - - for (Map column : entry.getValue()) { - sb.append(" * ") - .append(column.get("columnName")) - .append(" (").append(column.get("dataType")); - - if (column.get("maxLength") != null && (short)column.get("maxLength") > 0) { - sb.append(", 长度: ").append(column.get("maxLength")); - } - if (column.get("precision") != null && (short)column.get("precision") > 0) { - sb.append(", 精度: ").append(column.get("precision")); - } - if (column.get("scale") != null && (short)column.get("scale") > 0) { - sb.append(", 小数位: ").append(column.get("scale")); - } - - sb.append(column.get("nullable").equals(1) ? ", 可空" : ", 非空"); - - if (!((String)column.get("defaultValue")).isEmpty()) { - sb.append(", 默认值: ").append(column.get("defaultValue")); - } - - if (!((String)column.get("description")).isEmpty()) { - sb.append(", 描述: ").append(column.get("description")); - } - - sb.append(")\n"); - } - sb.append("\n"); - } - - return sb.toString(); - } - - - public String generateSQL(String naturalLanguageQuery) { - // 1. 获取数据库结构 - String schemaDescription = this.getFormattedSchema(); - - // 2. 构建 AI 提示 - String prompt = String.format( - "你是一个专业的 SQL Server 数据库专家。根据以下数据库结构:\n\n%s\n\n" + - "请将以下自然语言查询转换为优化的 SQL Server T-SQL 语句:\n" + - "---\n%s\n---\n\n" + - "要求:\n" + - "1. 只返回 SQL 语句,不要包含解释\n" + - "2. 使用 SQL Server 特有的语法(如 TOP 而不是 LIMIT)\n" + - "3. 考虑性能优化\n" + - "4. 使用合适的索引提示(如果需要)\n" + - "5. 包含必要的 WITH(NOLOCK) 提示(适用于高并发环境)\n" + - "6. 使用 ANSI 标准的 JOIN 语法", - schemaDescription, naturalLanguageQuery - ); - - ChatRequest chatRequest = new ChatRequest(); - chatRequest.setPrompt(prompt); - - return prompt; - - } - - - /** - * 获取表的结构信息 - */ - public Map getTableDetail(String tableName) { - Map result = new LinkedHashMap<>(); - result.put("tableName", tableName); - result.put("primaryKeys", databaseMetaMapper.getPrimaryKeys(tableName)); - result.put("columns", databaseMetaMapper.getTableStructure(tableName)); - return result; - } - - - - - - public String generateSQL1(String naturalLanguageQuery) { - // 1. 获取数据库结构 - String schemaDescription = this.getFormattedSchema(); - - // 2. 构建 AI 提示 - String prompt = String.format( - "你是一个专业的 SQL Server 数据库专家。根据以下数据库结构:\n\n%s\n\n" + - "请将以下自然语言查询转换为优化的 SQL Server T-SQL 语句:\n" + - "---\n%s\n---\n\n" + - "要求:\n" + - "1. 只返回 SQL 语句,不要包含解释\n" + - "2. 使用 SQL Server 特有的语法(如 TOP 而不是 LIMIT)\n" + - "3. 考虑性能优化\n" + - "4. 使用合适的索引提示(如果需要)\n" + - "5. 包含必要的 WITH(NOLOCK) 提示(适用于高并发环境)\n" + - "6. 使用 ANSI 标准的 JOIN 语法", - schemaDescription, naturalLanguageQuery - ); - - return prompt; - } - -} +//package org.dromara.ai.service.impl; +// +//import org.dromara.ai.mapper.SQLServerDatabaseMetaMapper; +//import org.dromara.ai.test.ChatRequest; +//import org.springframework.beans.factory.annotation.Autowired; +//import org.springframework.stereotype.Service; +// +//import java.util.LinkedHashMap; +//import java.util.List; +//import java.util.Map; +//import java.util.stream.Collectors; +// +///** +// * @Author xins +// * @Date 2025/7/8 9:25 +// * @Description: +// */ +//@Service +//public class DatabaseMetaServiceImpl { +// +// @Autowired +// private SQLServerDatabaseMetaMapper databaseMetaMapper; +// +// /** +// * 获取格式化的数据库结构描述 +// */ +// public String getFormattedSchema() { +// List> allColumns = databaseMetaMapper.getAllTablesStructure(); +// +// Map>> tables = allColumns.stream() +// .collect(Collectors.groupingBy( +// col -> (String) col.get("tableName"), +// LinkedHashMap::new, +// Collectors.toList() +// )); +// +// StringBuilder sb = new StringBuilder("SQL Server 数据库结构:\n\n"); +// +// for (Map.Entry>> entry : tables.entrySet()) { +// String tableName = entry.getKey(); +// List primaryKeys = databaseMetaMapper.getPrimaryKeys(tableName); +// +// sb.append("- 表名: ").append(tableName).append("\n"); +// sb.append(" 主键: ").append(String.join(", ", primaryKeys)).append("\n"); +// sb.append(" 字段:\n"); +// +// for (Map column : entry.getValue()) { +// sb.append(" * ") +// .append(column.get("columnName")) +// .append(" (").append(column.get("dataType")); +// +// if (column.get("maxLength") != null && (short)column.get("maxLength") > 0) { +// sb.append(", 长度: ").append(column.get("maxLength")); +// } +// if (column.get("precision") != null && (short)column.get("precision") > 0) { +// sb.append(", 精度: ").append(column.get("precision")); +// } +// if (column.get("scale") != null && (short)column.get("scale") > 0) { +// sb.append(", 小数位: ").append(column.get("scale")); +// } +// +// sb.append(column.get("nullable").equals(1) ? ", 可空" : ", 非空"); +// +// if (!((String)column.get("defaultValue")).isEmpty()) { +// sb.append(", 默认值: ").append(column.get("defaultValue")); +// } +// +// if (!((String)column.get("description")).isEmpty()) { +// sb.append(", 描述: ").append(column.get("description")); +// } +// +// sb.append(")\n"); +// } +// sb.append("\n"); +// } +// +// return sb.toString(); +// } +// +// +// public String generateSQL(String naturalLanguageQuery) { +// // 1. 获取数据库结构 +// String schemaDescription = this.getFormattedSchema(); +// +// // 2. 构建 AI 提示 +// String prompt = String.format( +// "你是一个专业的 SQL Server 数据库专家。根据以下数据库结构:\n\n%s\n\n" + +// "请将以下自然语言查询转换为优化的 SQL Server T-SQL 语句:\n" + +// "---\n%s\n---\n\n" + +// "要求:\n" + +// "1. 只返回 SQL 语句,不要包含解释\n" + +// "2. 使用 SQL Server 特有的语法(如 TOP 而不是 LIMIT)\n" + +// "3. 考虑性能优化\n" + +// "4. 使用合适的索引提示(如果需要)\n" + +// "5. 包含必要的 WITH(NOLOCK) 提示(适用于高并发环境)\n" + +// "6. 使用 ANSI 标准的 JOIN 语法", +// schemaDescription, naturalLanguageQuery +// ); +// +// ChatRequest chatRequest = new ChatRequest(); +// chatRequest.setPrompt(prompt); +// +// return prompt; +// +// } +// +// +// /** +// * 获取表的结构信息 +// */ +// public Map getTableDetail(String tableName) { +// Map result = new LinkedHashMap<>(); +// result.put("tableName", tableName); +// result.put("primaryKeys", databaseMetaMapper.getPrimaryKeys(tableName)); +// result.put("columns", databaseMetaMapper.getTableStructure(tableName)); +// return result; +// } +// +// +// +// +// +// public String generateSQL1(String naturalLanguageQuery) { +// // 1. 获取数据库结构 +// String schemaDescription = this.getFormattedSchema(); +// +// // 2. 构建 AI 提示 +// String prompt = String.format( +// "你是一个专业的 SQL Server 数据库专家。根据以下数据库结构:\n\n%s\n\n" + +// "请将以下自然语言查询转换为优化的 SQL Server T-SQL 语句:\n" + +// "---\n%s\n---\n\n" + +// "要求:\n" + +// "1. 只返回 SQL 语句,不要包含解释\n" + +// "2. 使用 SQL Server 特有的语法(如 TOP 而不是 LIMIT)\n" + +// "3. 考虑性能优化\n" + +// "4. 使用合适的索引提示(如果需要)\n" + +// "5. 包含必要的 WITH(NOLOCK) 提示(适用于高并发环境)\n" + +// "6. 使用 ANSI 标准的 JOIN 语法", +// schemaDescription, naturalLanguageQuery +// ); +// +// return prompt; +// } +// +//} diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/config/VectorDBConfig.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/config/VectorDBConfig.java index e288978a..bf15427e 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/config/VectorDBConfig.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/config/VectorDBConfig.java @@ -1,7 +1,6 @@ package org.dromara.ai.vectordb.config; import lombok.Data; -import org.dromara.ai.test.vectorization.enums.EmbeddingProviderEnum; import org.dromara.ai.vectordb.enums.VectorDBTypeEnum; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.context.annotation.Configuration; diff --git a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/service/impl/MilvusServiceImpl.java b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/service/impl/MilvusServiceImpl.java index 2021e03a..4bcc3020 100644 --- a/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/service/impl/MilvusServiceImpl.java +++ b/ruoyi-modules/hwmom-ai/src/main/java/org/dromara/ai/vectordb/service/impl/MilvusServiceImpl.java @@ -1,202 +1,400 @@ package org.dromara.ai.vectordb.service.impl; -import io.milvus.param.partition.CreatePartitionParam; -import io.milvus.param.partition.DropPartitionParam; -import io.milvus.param.partition.HasPartitionParam; -import io.milvus.param.partition.ReleasePartitionsParam; -import io.milvus.response.QueryResultsWrapper; -import org.apache.tika.utils.StringUtils; -import org.dromara.ai.vectordb.config.VectorDBConfig; -import org.dromara.ai.vectordb.service.IVectorDBService; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; import io.milvus.client.MilvusServiceClient; import io.milvus.grpc.*; import io.milvus.param.*; import io.milvus.param.collection.*; import io.milvus.param.dml.*; +import io.milvus.param.highlevel.collection.ListCollectionsParam; +import io.milvus.param.highlevel.collection.response.ListCollectionsResponse; import io.milvus.param.index.*; +import io.milvus.param.partition.*; +import io.milvus.response.QueryResultsWrapper; import io.milvus.response.SearchResultsWrapper; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; +import org.apache.tika.utils.StringUtils; +import org.dromara.ai.vectordb.config.VectorDBConfig; +import org.dromara.ai.vectordb.service.IVectorDBService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; import java.util.*; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; /** * @Author xins * @Date 2025/7/31 10:18 - * @Description:Milvus向量数据库服务实现 + * @Description:Milvus向量数据库服务实现(增强版,支持自动重连) */ @Service @Slf4j +@Lazy public class MilvusServiceImpl implements IVectorDBService { - private MilvusServiceClient milvusClient; + + private volatile MilvusServiceClient milvusClient; + private final AtomicBoolean isConnected = new AtomicBoolean(false); + private final AtomicBoolean isInitializing = new AtomicBoolean(false); + private final ScheduledExecutorService connectionScheduler = Executors.newSingleThreadScheduledExecutor(); private final ObjectMapper objectMapper = new ObjectMapper(); + // 连接重试配置 + private static final int INITIAL_RETRY_DELAY = 5; // 初始重试延迟(秒) + private static final int MAX_RETRY_DELAY = 300; // 最大重试延迟(秒) + private static final int CONNECTION_TIMEOUT = 10; // 连接超时(秒) + private static final int HEALTH_CHECK_INTERVAL = 30; // 健康检查间隔(秒) + @Autowired private VectorDBConfig vectorDBConfig; - // 向量维度,根据你的模型确定 -// private static final int VECTOR_DIMENSION = 768; -// public static final String host = "1.13.177.47"; -// public static final int port = 19530; -// public static final String collectionNamePrefix = "hwmom_embeddings_"; - public static final String partitionNamePrefix = "hwmom_partition_"; - - private static final String primaryKeyFieldName = "pid"; - private static final String vectorFieldName = "kvector"; - private static final String knowledgeBaseIdFieldName = "kid"; - private static final String contentIdFieldName = "contentId"; - private static final String fragmentIdFieldName = "fid"; - private static final String metadataFieldName = "metadata"; + // 字段名常量 + private static final String PRIMARY_KEY_FIELD_NAME = "pid"; + private static final String VECTOR_FIELD_NAME = "kvector"; + private static final String KNOWLEDGE_BASE_ID_FIELD_NAME = "kid"; + private static final String CONTENT_ID_FIELD_NAME = "contentId"; + private static final String FRAGMENT_ID_FIELD_NAME = "fid"; + private static final String METADATA_FIELD_NAME = "metadata"; + private static final String PARTITION_NAME_PREFIX = "hwmom_partition_"; + // 已加载到内存的集合缓存 + private final Set loadedCollections = Collections.synchronizedSet(new HashSet<>()); @PostConstruct public void init() { - // 初始化Milvus连接 - milvusClient = new MilvusServiceClient( - ConnectParam.newBuilder() - .withHost(vectorDBConfig.getMilvus().getHost()) - .withPort(vectorDBConfig.getMilvus().getPort()) - .build() - ); + log.info("Starting Milvus service initialization..."); + startConnectionManager(); + } - // 检查集合是否存在,不存在则创建 -// if (!hasCollection()) { -// createCollection(); -// createIndex(); -// } + @PreDestroy + public void destroy() { + log.info("Shutting down Milvus service..."); + connectionScheduler.shutdownNow(); + closeConnection(); + log.info("Milvus service shutdown completed"); + } - log.info("Milvus initialized successfully"); + /** + * 启动连接管理器 + */ + private void startConnectionManager() { + // 延迟启动,等待其他服务初始化 + connectionScheduler.schedule(() -> { + if (!isInitializing.getAndSet(true)) { + attemptConnection(); + scheduleHealthCheck(); + log.info("Milvus connection manager started"); + } + }, 10, TimeUnit.SECONDS); + } + + /** + * 尝试连接Milvus + */ + private void attemptConnection() { + int attemptCount = 0; + int delay = INITIAL_RETRY_DELAY; + + while (!isConnected.get() && !Thread.currentThread().isInterrupted()) { + attemptCount++; + try { + log.info("Attempting to connect to Milvus (attempt {}) at {}:{}", + attemptCount, + vectorDBConfig.getMilvus().getHost(), + vectorDBConfig.getMilvus().getPort()); + + // 关闭旧连接 + closeConnection(); + + // 创建新连接 + milvusClient = new MilvusServiceClient( + ConnectParam.newBuilder() + .withHost(vectorDBConfig.getMilvus().getHost()) + .withPort(vectorDBConfig.getMilvus().getPort()) + .withConnectTimeout(CONNECTION_TIMEOUT,TimeUnit.SECONDS) + .build() + ); + + // 测试连接 + testConnection(); + + isConnected.set(true); + log.info("✅ Milvus connected successfully after {} attempts", attemptCount); + + // 重新加载之前已加载的集合 + reloadCollections(); + + return; + + } catch (Exception e) { + log.warn("⚠️ Failed to connect to Milvus (attempt {}): {}", attemptCount, e.getMessage()); + + // 指数退避重试 + delay = Math.min(delay * 2, MAX_RETRY_DELAY); + log.info("Will retry connection in {} seconds...", delay); + + try { + TimeUnit.SECONDS.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + break; + } + } + } + } + + /** + * 测试连接是否正常 + */ + private void testConnection() { + try { + R response = milvusClient.listCollections( + ListCollectionsParam.newBuilder().build() + ); + + if (response.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Milvus connection test failed with status: " + response.getStatus()); + } + + } catch (Exception e) { + throw new RuntimeException("Milvus connection test failed: " + e.getMessage(), e); + } + } + + /** + * 安排健康检查 + */ + private void scheduleHealthCheck() { + connectionScheduler.scheduleAtFixedRate(() -> { + if (isConnected.get()) { + try { + testConnection(); + log.debug("Milvus health check passed"); + } catch (Exception e) { + log.error("❌ Milvus health check failed: {}", e.getMessage()); + isConnected.set(false); + loadedCollections.clear(); + attemptConnection(); + } + } else { + log.info("Milvus is disconnected, attempting to reconnect..."); + attemptConnection(); + } + }, HEALTH_CHECK_INTERVAL, HEALTH_CHECK_INTERVAL, TimeUnit.SECONDS); + } + + /** + * 重新加载集合到内存 + */ + private void reloadCollections() { + if (!loadedCollections.isEmpty()) { + log.info("Reloading {} collections to memory", loadedCollections.size()); + for (String collectionName : loadedCollections) { + try { + loadCollectionToMemory(collectionName); + } catch (Exception e) { + log.warn("Failed to reload collection {}: {}", collectionName, e.getMessage()); + } + } + } + } + + /** + * 检查连接状态 + */ + private void checkConnection() { + if (!isConnected.get()) { + throw new IllegalStateException("Milvus is not connected. Please check the database status."); + } + } + + /** + * 获取客户端(带连接检查) + */ + private MilvusServiceClient getClient() { + checkConnection(); + return milvusClient; + } + + /** + * 关闭连接 + */ + private void closeConnection() { + if (milvusClient != null) { + try { + milvusClient.close(); + log.debug("Milvus connection closed"); + } catch (Exception e) { + log.warn("Error closing Milvus connection: {}", e.getMessage()); + } finally { + milvusClient = null; + } + } + } + + /** + * 加载集合到内存 + */ + private void loadCollectionToMemory(String collectionName) { + try { + LoadCollectionParam loadParam = LoadCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build(); + + R loadResponse = getClient().loadCollection(loadParam); + + if (loadResponse.getStatus() == R.Status.Success.getCode()) { + loadedCollections.add(collectionName); + log.debug("Collection {} loaded to memory", collectionName); + } else { + log.warn("Failed to load collection {} to memory: {}", collectionName, loadResponse.getMessage()); + } + } catch (Exception e) { + log.warn("Error loading collection {} to memory: {}", collectionName, e.getMessage()); + } } private boolean hasCollection(String collectionName) { - R response = milvusClient.hasCollection( - HasCollectionParam.newBuilder() - .withCollectionName(collectionName) - .build() - ); + try { + checkConnection(); + R response = getClient().hasCollection( + HasCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build() + ); - if (response.getStatus() != R.Status.Success.getCode()) { - log.error("Failed to check collection existence: {}", response.getMessage()); + return response.getStatus() == R.Status.Success.getCode() && + Boolean.TRUE.equals(response.getData()); + + } catch (Exception e) { + log.warn("Failed to check collection existence for {}: {}", collectionName, e.getMessage()); return false; } - - return response.getData(); } /** - * @param knowledgeBaseId 知识库ID - * @description 创建支持库Collection,每个知识库创建一个,在创建知识库时创建 + * 创建知识库Collection */ @Override public void createCollection(Long knowledgeBaseId) { - // 1. 检查集合是否已存在 String collectionName = vectorDBConfig.getMilvus().getCollectionNamePrefix() + knowledgeBaseId; - if (hasCollection(collectionName)) { - throw new RuntimeException("Collection already exists: " + collectionName); + + try { + checkConnection(); + + // 检查集合是否已存在 + if (hasCollection(collectionName)) { + log.warn("Collection already exists: {}", collectionName); + return; + } + + // 定义字段 + FieldType primaryKeyField = FieldType.newBuilder() + .withName(PRIMARY_KEY_FIELD_NAME) + .withDataType(DataType.Int64) + .withPrimaryKey(true) + .withAutoID(true) + .build(); + + FieldType vectorField = FieldType.newBuilder() + .withName(VECTOR_FIELD_NAME) + .withDataType(DataType.FloatVector) + .withDimension(vectorDBConfig.getMilvus().getVectorDimension()) + .build(); + + FieldType knowledgeBaseIdField = FieldType.newBuilder() + .withName(KNOWLEDGE_BASE_ID_FIELD_NAME) + .withDataType(DataType.VarChar) + .withMaxLength(20) + .build(); + + FieldType contentIdField = FieldType.newBuilder() + .withName(CONTENT_ID_FIELD_NAME) + .withDataType(DataType.VarChar) + .withMaxLength(20) + .build(); + + FieldType fragmentIdField = FieldType.newBuilder() + .withName(FRAGMENT_ID_FIELD_NAME) + .withDataType(DataType.VarChar) + .withMaxLength(20) + .build(); + + FieldType metadataField = FieldType.newBuilder() + .withName(METADATA_FIELD_NAME) + .withDataType(DataType.VarChar) + .withMaxLength(65535) + .build(); + + // 创建集合 + CreateCollectionParam createParam = CreateCollectionParam.newBuilder() + .withCollectionName(collectionName) + .withDescription("Knowledge base: " + knowledgeBaseId) + .withFieldTypes(Arrays.asList( + primaryKeyField, vectorField, knowledgeBaseIdField, + contentIdField, fragmentIdField, metadataField)) + .build(); + + R response = getClient().createCollection(createParam); + + if (response.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to create collection: " + response.getMessage()); + } + + log.info("Collection created successfully: {}", collectionName); + + // 创建索引 + createIndex(collectionName, VECTOR_FIELD_NAME); + + // 加载到内存 + loadCollectionToMemory(collectionName); + + } catch (Exception e) { + log.error("Failed to create collection {}: {}", collectionName, e.getMessage(), e); + throw new RuntimeException("Failed to create Milvus collection: " + e.getMessage(), e); } - - - // 定义字段 - //主键 - FieldType primaryKeyField = FieldType.newBuilder() - .withName(primaryKeyFieldName) - .withDataType(DataType.Int64) - .withPrimaryKey(true) - .withAutoID(true) - .build(); - - //向量字段 - FieldType vectorField = FieldType.newBuilder() - .withName(vectorFieldName) - .withDataType(DataType.FloatVector) - .withDimension(vectorDBConfig.getMilvus().getVectorDimension()) - .build(); - - //知识库ID - FieldType knowledgeBaseIdField = FieldType.newBuilder() - .withName(knowledgeBaseIdFieldName) - .withDataType(DataType.VarChar) - .withMaxLength(20) - .build(); - - //知识库内容ID - FieldType contentIdField = FieldType.newBuilder() - .withName(contentIdFieldName) - .withDataType(DataType.VarChar) - .withMaxLength(20) - .build(); - - //知识库内容片段ID - FieldType fragmentIdField = FieldType.newBuilder() - .withName(fragmentIdFieldName) - .withDataType(DataType.VarChar) - .withMaxLength(20) - .build(); - - //具体内容分段的源数据 - FieldType metadataField = FieldType.newBuilder() - .withName(metadataFieldName) - .withDataType(DataType.VarChar) - .withMaxLength(65535) - .build(); - - // 创建集合 - CreateCollectionParam createParam = CreateCollectionParam.newBuilder() - .withCollectionName(collectionName) - .withDescription("knowledge base") - .withFieldTypes(Arrays.asList(primaryKeyField, vectorField, knowledgeBaseIdField, contentIdField, fragmentIdField, metadataField)) - .build(); - - R response = milvusClient.createCollection(createParam); - - if (response.getStatus() != R.Status.Success.getCode()) { - log.error("Failed to create collection: {}", response.getMessage()); - throw new RuntimeException("Failed to create Milvus collection"); - } - - this.createIndex(collectionName, vectorFieldName); } /** - * @param collectionName - * @return null - * @description 创建向量的索引 + * 创建向量索引 */ private void createIndex(String collectionName, String vectorFieldName) { - // 创建向量索引 - IndexType indexType = IndexType.IVF_FLAT; - String indexParam = "{\"nlist\":1024}"; + try { + IndexType indexType = IndexType.IVF_FLAT; + String indexParam = "{\"nlist\":1024}"; - CreateIndexParam createIndexParam = CreateIndexParam.newBuilder() - .withCollectionName(collectionName) - .withFieldName(vectorFieldName) - .withIndexType(indexType) -// .withMetricType(MetricType.L2) - .withMetricType(MetricType.COSINE) - .withExtraParam(indexParam) - .withSyncMode(Boolean.FALSE) - .build(); + CreateIndexParam createIndexParam = CreateIndexParam.newBuilder() + .withCollectionName(collectionName) + .withFieldName(vectorFieldName) + .withIndexType(indexType) + .withMetricType(MetricType.COSINE) + .withExtraParam(indexParam) + .withSyncMode(false) + .build(); - R response = milvusClient.createIndex(createIndexParam); + R response = getClient().createIndex(createIndexParam); - if (response.getStatus() != R.Status.Success.getCode()) { - log.error("Failed to create index: {}", response.getMessage()); - throw new RuntimeException("Failed to create Milvus index"); + if (response.getStatus() != R.Status.Success.getCode()) { + log.error("Failed to create index for collection {}: {}", collectionName, response.getMessage()); + throw new RuntimeException("Failed to create Milvus index"); + } + + log.info("Index created successfully for collection: {}", collectionName); + + } catch (Exception e) { + log.error("Failed to create index for collection {}: {}", collectionName, e.getMessage(), e); + throw new RuntimeException("Failed to create index: " + e.getMessage(), e); } } - /** - * @param knowledgeBaseId 知识库ID - * @param contentId 知识库内容ID - * @param chunkList 知识库内容分段的源数据 - * @param vectorList 知识库内容分段的向量 - * @param fidList 知识库内容分段的ID - * @description 插入知识库内容的向量 + * 插入知识库向量 */ @Override public void insertKnowledgeEmbedding(Long knowledgeBaseId, Long contentId, @@ -204,73 +402,85 @@ public class MilvusServiceImpl implements IVectorDBService { List fidList) { String collectionName = vectorDBConfig.getMilvus().getCollectionNamePrefix() + knowledgeBaseId; - // 检查集合是否存在 - HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder() - .withCollectionName(collectionName) - .build(); - R booleanR = milvusClient.hasCollection(hasCollectionParam); + try { + checkConnection(); - if (booleanR.getStatus() == R.Status.Success.getCode()) { - boolean collectionExists = booleanR.getData().booleanValue(); - if (!collectionExists) {//集合不存在时创建集合 - this.createCollection(knowledgeBaseId); + // 检查并创建集合(如果不存在) + if (!hasCollection(collectionName)) { + log.info("Collection {} does not exist, creating...", collectionName); + createCollection(knowledgeBaseId); } - } else { - System.err.println("检查集合是否存在时出错: " + booleanR.getMessage()); - return; - } - if (contentId == null) { - throw new RuntimeException("知识库内容ID不能为空"); - } + if (contentId == null) { + throw new IllegalArgumentException("知识库内容ID不能为空"); + } - String partitionName = partitionNamePrefix + contentId; + String partitionName = PARTITION_NAME_PREFIX + contentId; - // 检查分区是否存在 - HasPartitionParam hasPartitionParam = HasPartitionParam.newBuilder() - .withCollectionName(collectionName) - .withPartitionName(partitionName) - .build(); - R hasPartition = milvusClient.hasPartition(hasPartitionParam); + // 检查并创建分区 + R hasPartition = getClient().hasPartition( + HasPartitionParam.newBuilder() + .withCollectionName(collectionName) + .withPartitionName(partitionName) + .build() + ); - if (hasPartition.getStatus() == R.Status.Success.getCode()) { - boolean partitionExists = hasPartition.getData().booleanValue(); - if (!partitionExists) {//分区不存在时创建分区 - milvusClient.createPartition( + if (hasPartition.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to check partition: " + hasPartition.getMessage()); + } + + boolean partitionExists = Boolean.TRUE.equals(hasPartition.getData()); + + if (!partitionExists) { + R createResponse = getClient().createPartition( CreatePartitionParam.newBuilder() .withCollectionName(collectionName) .withPartitionName(partitionName) .build() ); + + if (createResponse.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to create partition: " + createResponse.getMessage()); + } + + log.debug("Partition created: {}", partitionName); } - } else { - System.err.println("检查分区是否存在时出错: " + hasPartition.getMessage()); - return; - } + // 准备插入数据 + int batchSize = Math.min(chunkList.size(), vectorList.size()); + if (batchSize == 0) { + log.warn("No data to insert"); + return; + } + + if (fidList.size() != batchSize) { + throw new IllegalArgumentException("fidList size must match chunkList size"); + } - try { List> vectorFloatList = new ArrayList<>(); List kidList = new ArrayList<>(); List contentIdList = new ArrayList<>(); - for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) { + + for (int i = 0; i < batchSize; i++) { + // 转换向量 List vector = vectorList.get(i); - List vfList = new ArrayList<>(); - for (int j = 0; j < vector.size(); j++) { - Double value = vector.get(j); - vfList.add(value.floatValue()); - } - vectorFloatList.add(vfList); + List floatVector = vector.stream() + .map(Double::floatValue) + .collect(Collectors.toList()); + vectorFloatList.add(floatVector); + + // 添加ID kidList.add(String.valueOf(knowledgeBaseId)); contentIdList.add(String.valueOf(contentId)); } - List fields = new ArrayList<>(); - fields.add(new InsertParam.Field(metadataFieldName, chunkList)); - fields.add(new InsertParam.Field(knowledgeBaseIdFieldName, kidList)); - fields.add(new InsertParam.Field(contentIdFieldName, contentIdList)); - fields.add(new InsertParam.Field(fragmentIdFieldName, fidList)); - fields.add(new InsertParam.Field(vectorFieldName, vectorFloatList)); + // 构建插入参数 + List fields = new ArrayList<>(); + fields.add(new InsertParam.Field(METADATA_FIELD_NAME, chunkList.subList(0, batchSize))); + fields.add(new InsertParam.Field(KNOWLEDGE_BASE_ID_FIELD_NAME, kidList)); + fields.add(new InsertParam.Field(CONTENT_ID_FIELD_NAME, contentIdList)); + fields.add(new InsertParam.Field(FRAGMENT_ID_FIELD_NAME, fidList.subList(0, batchSize))); + fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, vectorFloatList)); InsertParam insertParam = InsertParam.newBuilder() .withCollectionName(collectionName) @@ -279,239 +489,150 @@ public class MilvusServiceImpl implements IVectorDBService { .build(); // 执行插入 - R response = milvusClient.insert(insertParam); + R response = getClient().insert(insertParam); if (response.getStatus() != R.Status.Success.getCode()) { - log.error("Failed to insert data: {}", response.getMessage()); - throw new RuntimeException("Failed to insert data into Milvus"); + throw new RuntimeException("Failed to insert data: " + response.getMessage()); } - // 刷新数据使可搜索 - milvusClient.flush(FlushParam.newBuilder() + log.info("Successfully inserted {} records into collection {}", batchSize, collectionName); + + // 刷新数据 + getClient().flush(FlushParam.newBuilder() .addCollectionName(collectionName) .build()); - // milvus在将数据装载到内存后才能进行向量计算. - LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder() - .withCollectionName(collectionName) - .build(); - R loadResponse = milvusClient.loadCollection(loadCollectionParam); - if (loadResponse.getStatus() != R.Status.Success.getCode()) { - System.err.println("加载集合 " + collectionName + " 到内存时出错:" + loadResponse.getMessage()); + // 确保集合已加载到内存 + if (!loadedCollections.contains(collectionName)) { + loadCollectionToMemory(collectionName); } } catch (Exception e) { - log.error("Failed to insert data:", e); - throw new RuntimeException("Failed to insert data:" + e); + log.error("Failed to insert knowledge embeddings for knowledgeBaseId {}, contentId {}: {}", + knowledgeBaseId, contentId, e.getMessage(), e); + throw new RuntimeException("Failed to insert data: " + e.getMessage(), e); } } - + /** + * 向量搜索 + */ @Override public List search(List queryVector, Long knowledgeBaseId, int topK) { String collectionName = vectorDBConfig.getMilvus().getCollectionNamePrefix() + knowledgeBaseId; - HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder() - .withCollectionName(collectionName) - .build(); - R booleanR = milvusClient.hasCollection(hasCollectionParam); - if (booleanR.getStatus() != R.Status.Success.getCode() || !booleanR.getData().booleanValue()) { - System.err.println("集合 " + collectionName + " 不存在或检查集合存在性时出错。"); - return new ArrayList<>(); - } - - DescribeIndexParam describeIndexParam = DescribeIndexParam.newBuilder().withCollectionName(collectionName).build(); - - R describeIndexResponseR = milvusClient.describeIndex(describeIndexParam); - - if (describeIndexResponseR.getStatus() == R.Status.Success.getCode()) { - System.out.println("索引信息: " + describeIndexResponseR.getData().getIndexDescriptionsCount()); - } else { - System.err.println("获取索引失败: " + describeIndexResponseR.getMessage()); - } - - - // 构建搜索参数 - List outputFields = Arrays.asList(primaryKeyFieldName, metadataFieldName); - - List fv = new ArrayList<>(); - for (int i = 0; i < queryVector.size(); i++) { - fv.add(queryVector.get(i).floatValue()); - } - List> vectors = new ArrayList<>(); - vectors.add(fv); - - String searchParams = "{\"nprobe\":10, \"offset\":0}"; - - SearchParam searchParam = SearchParam.newBuilder() - .withCollectionName(collectionName) -// .withMetricType(MetricType.L2) - .withMetricType(MetricType.COSINE) - .withOutFields(outputFields) - .withTopK(topK) - .withVectors(vectors) - .withVectorFieldName(vectorFieldName) - .withParams(searchParams) - .build(); - - R respSearch = milvusClient.search(searchParam); - - if (respSearch.getStatus() != R.Status.Success.getCode()) { - log.error("Failed to search vectors: {}", respSearch.getMessage()); - throw new RuntimeException("Failed to search vectors in Milvus"); - } - - - System.out.println("SearchParam: " + searchParam.toString()); - if (respSearch.getStatus() == R.Status.Success.getCode()) { - SearchResults searchResults = respSearch.getData(); - if (searchResults != null) { - System.out.println(searchResults.getResults()); - SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(searchResults.getResults()); - List rowRecords = wrapperSearch.getRowRecords(); - - List resultList = new ArrayList<>(); - if (rowRecords != null && !rowRecords.isEmpty()) { - for (QueryResultsWrapper.RowRecord rowRecord : rowRecords) { - String content = rowRecord.get(metadataFieldName).toString(); - resultList.add(content); - } - } - return resultList; - } else { - System.err.println("搜索结果为空"); - } - } else { - System.err.println("搜索操作失败: " + respSearch.getMessage()); - } - - return new ArrayList<>(); - - // 处理搜索结果 -// SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults()); - // 调试:检查搜索结果数量 -// List idScores = wrapper.getIDScore(0); - - // 调试:检查字段数据 -// List metadataData = wrapper.getFieldData(metadataFieldName, 0); -// System.out.println("元数据字段数据数量: " + metadataData.size()); - -// System.out.println("搜索返回结果数量: " + idScores.size()); -// for (int i = 0; i < idScores.size(); i++) { -// SearchResultsWrapper.IDScore idScore = idScores.get(i); -// System.out.println("结果 " + i + ": ID=" + idScore.getStrID() + ", Score=" + idScore.getScore()); -// -// if (i < metadataData.size()) { -// System.out.println("对应元数据: " + metadataData.get(i)); -// } -// } - -// return idScores.stream() -// .map(idScore -> { -// try { -// SearchResult result = new SearchResult(); -// result.setId(idScore.getStrID()); -// result.setScore(idScore.getScore()); -// -// // 获取元数据 -// // 使用索引位置获取元数据 -// int index = idScores.indexOf(idScore); -// if (index < metadataData.size()) { -// String metadataJson = metadataData.get(index).toString(); -// result.setMetadata(objectMapper.readValue(metadataJson, Object.class)); -// return result; -// } -// return null; -// } catch (Exception e) { -// log.error("Failed to process search result", e); -// return null; -// } -// }) -// .filter(Objects::nonNull) -// .collect(Collectors.toList()); - } - - private void processResult(SearchResultsWrapper wrapper) { - // 调试:检查搜索结果数量 - List idScores = wrapper.getIDScore(0); - System.out.println("搜索返回结果数量: " + idScores.size()); - } - - private String convertToJsonString(Object data, ObjectMapper objectMapper) { try { - if (data instanceof String) { - String strData = (String) data; - // 检查是否是有效的JSON - try { - objectMapper.readTree(strData); - return strData; // 已经是JSON,直接返回 - } catch (Exception e) { - // 不是JSON,包装成JSON对象 - Map jsonObject = new HashMap<>(); - jsonObject.put("text", strData); - return objectMapper.writeValueAsString(jsonObject); + checkConnection(); + + // 检查集合是否存在 + if (!hasCollection(collectionName)) { + log.warn("Collection {} does not exist", collectionName); + return Collections.emptyList(); + } + + // 确保集合已加载到内存 + if (!loadedCollections.contains(collectionName)) { + loadCollectionToMemory(collectionName); + } + + // 准备查询向量 + List floatVector = queryVector.stream() + .map(Double::floatValue) + .collect(Collectors.toList()); + + List> vectors = Collections.singletonList(floatVector); + + // 构建搜索参数 + List outputFields = Arrays.asList(PRIMARY_KEY_FIELD_NAME, METADATA_FIELD_NAME); + String searchParams = "{\"nprobe\":10, \"offset\":0}"; + + SearchParam searchParam = SearchParam.newBuilder() + .withCollectionName(collectionName) + .withMetricType(MetricType.COSINE) + .withOutFields(outputFields) + .withTopK(topK) + .withVectors(vectors) + .withVectorFieldName(VECTOR_FIELD_NAME) + .withParams(searchParams) + .build(); + + // 执行搜索 + R response = getClient().search(searchParam); + + if (response.getStatus() != R.Status.Success.getCode()) { + log.error("Search failed for collection {}: {}", collectionName, response.getMessage()); + throw new RuntimeException("Failed to search vectors in Milvus"); + } + + // 处理结果 + SearchResults searchResults = response.getData(); + if (searchResults == null) { + log.warn("Search returned no results for collection {}", collectionName); + return Collections.emptyList(); + } + + SearchResultsWrapper wrapper = new SearchResultsWrapper(searchResults.getResults()); + List rowRecords = wrapper.getRowRecords(); + + List resultList = new ArrayList<>(); + if (rowRecords != null && !rowRecords.isEmpty()) { + for (QueryResultsWrapper.RowRecord rowRecord : rowRecords) { + String content = rowRecord.get(METADATA_FIELD_NAME).toString(); + resultList.add(content); } - } else { - // 直接序列化对象 - return objectMapper.writeValueAsString(data); } + + log.debug("Search completed for collection {}, found {} results", collectionName, resultList.size()); + return resultList; + } catch (Exception e) { - log.warn("Failed to convert data to JSON, using default format", e); - try { - Map defaultObject = new HashMap<>(); - defaultObject.put("raw_data", data != null ? data.toString() : "null"); - return objectMapper.writeValueAsString(defaultObject); - } catch (Exception ex) { - return "{\"error\": \"failed_to_convert_to_json\"}"; + log.error("Search failed for knowledgeBaseId {}: {}", knowledgeBaseId, e.getMessage(), e); + + // 如果是连接问题,尝试重连 + if (e instanceof IllegalStateException && e.getMessage().contains("not connected")) { + isConnected.set(false); + attemptConnection(); } + + throw new RuntimeException("Search failed: " + e.getMessage(), e); } } - - // @Override -// public void removeByKidAndFid(String kid, String fid) { -// milvusServiceClient.delete -// DeleteParam.newBuilder() -// .withCollectionName(collectionName + kid) -// .withExpr("fid == " + fid) -// .build() -// ); -// } -// -// -// - /** - * 删除数据by 知识库ID和Content ID - * - * @param knowledgeBaseId - * @param contentId + * 根据内容ID删除数据 */ @Override public void removeByContentId(Long knowledgeBaseId, Long contentId) { String collectionName = vectorDBConfig.getMilvus().getCollectionNamePrefix() + knowledgeBaseId; - String partitionName = partitionNamePrefix + contentId; -// milvusClient.delete( -// DeleteParam.newBuilder() -// .withCollectionName(collectionName) -// .withExpr("contentId>0") -// .withPartitionName(partitionName) -// .build() -// ); + String partitionName = PARTITION_NAME_PREFIX + contentId; - // 可以先查询分区信息确认存在 - // 检查分区是否存在 - HasPartitionParam hasPartitionParam = HasPartitionParam.newBuilder() - .withCollectionName(collectionName) - .withPartitionName(partitionName) - .build(); - R hasPartition = milvusClient.hasPartition(hasPartitionParam); + try { + checkConnection(); - if (hasPartition.getStatus() == R.Status.Success.getCode()) { - boolean partitionExists = hasPartition.getData().booleanValue(); - if (partitionExists) {//分区存在时删除 - // 首先释放(卸载)分区 - milvusClient.releasePartitions( + // 检查集合是否存在 + if (!hasCollection(collectionName)) { + log.warn("Collection {} does not exist, nothing to delete", collectionName); + return; + } + + // 检查分区是否存在 + R hasPartition = getClient().hasPartition( + HasPartitionParam.newBuilder() + .withCollectionName(collectionName) + .withPartitionName(partitionName) + .build() + ); + + if (hasPartition.getStatus() != R.Status.Success.getCode()) { + log.error("Failed to check partition existence: {}", hasPartition.getMessage()); + return; + } + + boolean partitionExists = Boolean.TRUE.equals(hasPartition.getData()); + + if (partitionExists) { + // 首先释放分区 + getClient().releasePartitions( ReleasePartitionsParam.newBuilder() .withCollectionName(collectionName) .withPartitionNames(Collections.singletonList(partitionName)) @@ -519,112 +640,89 @@ public class MilvusServiceImpl implements IVectorDBService { ); // 然后删除分区 - milvusClient.dropPartition( + R dropResponse = getClient().dropPartition( DropPartitionParam.newBuilder() .withCollectionName(collectionName) .withPartitionName(partitionName) .build() ); + if (dropResponse.getStatus() == R.Status.Success.getCode()) { + log.info("Successfully deleted partition {} from collection {}", partitionName, collectionName); + } else { + log.error("Failed to delete partition {}: {}", partitionName, dropResponse.getMessage()); + } + } else { + log.info("Partition {} does not exist, nothing to delete", partitionName); } - } else { - System.err.println("检查分区是否存在时出错: " + hasPartition.getMessage()); - return; + + } catch (Exception e) { + log.error("Failed to remove content {} from knowledgeBase {}: {}", + contentId, knowledgeBaseId, e.getMessage(), e); + throw new RuntimeException("Failed to remove content: " + e.getMessage(), e); } } /** * 根据知识库ID删除数据 - * - * @param knowledgeBaseId */ @Override public void removeByKnowledgeBaseId(Long knowledgeBaseId) { String collectionName = vectorDBConfig.getMilvus().getCollectionNamePrefix() + knowledgeBaseId; - if(hasCollection(collectionName)){ - milvusClient.dropCollection( - DropCollectionParam.newBuilder() - .withCollectionName(collectionName) - .build() - ); + + try { + checkConnection(); + + if (hasCollection(collectionName)) { + R response = getClient().dropCollection( + DropCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build() + ); + + if (response.getStatus() == R.Status.Success.getCode()) { + loadedCollections.remove(collectionName); + log.info("Successfully deleted collection: {}", collectionName); + } else { + log.error("Failed to delete collection {}: {}", collectionName, response.getMessage()); + } + } else { + log.info("Collection {} does not exist, nothing to delete", collectionName); + } + + } catch (Exception e) { + log.error("Failed to remove knowledgeBase {}: {}", knowledgeBaseId, e.getMessage(), e); + throw new RuntimeException("Failed to remove knowledge base: " + e.getMessage(), e); } } -// @Override -// public void delete(String id) { -// // 构建删除表达式 -// String expr = "id == \"" + id + "\""; -// -// DeleteParam deleteParam = DeleteParam.newBuilder() -// .withCollectionName(collectionName) -// .withExpr(expr) -// .build(); -// -// R response = milvusClient.delete(deleteParam); -// -// if (response.getStatus() != R.Status.Success.getCode()) { -// log.error("Failed to delete data: {}", response.getMessage()); -// throw new RuntimeException("Failed to delete data from Milvus"); -// } -// -// // 刷新数据 -// milvusClient.flush(FlushParam.newBuilder() -// .addCollectionName(collectionName) -// .build()); -// } + /** + * 获取连接状态 + */ + public String getConnectionStatus() { + if (isConnected.get()) { + return String.format("Connected to %s:%s", + vectorDBConfig.getMilvus().getHost(), + vectorDBConfig.getMilvus().getPort()); + } else { + return "Disconnected"; + } + } + /** + * 获取已加载集合列表 + */ + public List getLoadedCollections() { + return new ArrayList<>(loadedCollections); + } - // @Override -// public List getVectorById(String id) { -// // 构建查询参数 -// List outputFields = Collections.singletonList("vector"); -// String expr = "id == \"" + id + "\""; -// -// QueryParam queryParam = QueryParam.newBuilder() -// .withCollectionName(collectionName) -// .withExpr(expr) -// .withOutFields(outputFields) -// .build(); -// -// R response = milvusClient.query(queryParam); -// -// if (response.getStatus() != R.Status.Success.getCode()) { -// log.error("Failed to query vector: {}", response.getMessage()); -// throw new RuntimeException("Failed to query vector from Milvus"); -// } -// -// QueryResultsWrapper wrapper = new QueryResultsWrapper(response.getData()); -// return wrapper.getFieldWrapper("vector").getFieldData().get(0).get(0); -// } - -// @Override -// public Object getMetadataById(String id) { -// try { -// // 构建查询参数 -// List outputFields = Collections.singletonList("metadata"); -// String expr = "id == \"" + id + "\""; -// -// QueryParam queryParam = QueryParam.newBuilder() -// .withCollectionName(collectionName) -// .withExpr(expr) -// .withOutFields(outputFields) -// .build(); -// -// R response = milvusClient.query(queryParam); -// -// if (response.getStatus() != R.Status.Success.getCode()) { -// log.error("Failed to query metadata: {}", response.getMessage()); -// throw new RuntimeException("Failed to query metadata from Milvus"); -// } -// -// QueryResultsWrapper wrapper = new QueryResultsWrapper(response.getData()); -// String metadataJson = wrapper.getFieldWrapper("metadata").getFieldData().get(0).get(0).toString(); -// -// return objectMapper.readValue(metadataJson, Object.class); -// } catch (Exception e) { -// log.error("Failed to deserialize metadata", e); -// throw new RuntimeException("Failed to deserialize metadata"); -// } -// } - + /** + * 手动触发重连 + */ + public void reconnect() { + log.info("Manual reconnection triggered"); + isConnected.set(false); + loadedCollections.clear(); + attemptConnection(); + } }