You need to enable JavaScript to run this app.
导航
串联大模型最佳实践
最近更新时间:2025.09.30 15:22:29首次发布时间:2025.08.15 17:43:46
复制全文
我的收藏
有用
有用
无用
无用

本章将以火山引擎“豆包大模型”为例,向您展示如何将记忆库与大模型相结合,构建一个具备长期记忆能力的对话式AI。

准备工作与API封装

在开始之前,请确保您已经:

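  1. 安装了必要的库(示例程序还使用 python-dotenv 从 .env 文件加载环境变量):
pip install volcengine
pip install --upgrade "volcengine-python-sdk[ark]"
pip install python-dotenv
  2. 在火山引擎获取用于记忆库鉴权的 AK/SK,在火山方舟获取用于大模型鉴权的 API Key,并分别写入环境变量 VOLC_ACCESSKEY、VOLC_SECRETKEY 和 ARK_API_KEY(也可以写在项目目录下的 .env 文件中)。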

为了方便调用,建议您将记忆库的API调用逻辑封装成一个服务类。以下是一个完整的示例封装,您可以将其保存为 step1_memory_service.py 文件,以便在主程序中引入。
step1_memory_service.py

import json
import threading

from volcengine.ApiInfo import ApiInfo
from volcengine.Credentials import Credentials
from volcengine.base.Service import Service
from volcengine.ServiceInfo import ServiceInfo
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Request import Request

class VikingDBMemoryException(Exception):
    """Error raised by VikingDBMemoryService when a memory API call fails.

    Attributes:
        code: Error code reported by the service (1000028 is used for
            client-side failures in this module).
        request_id: Request id reported by the service, or "missed".
        message: Human-readable text in the form
            "<message>, code:<code>,request_id:<request_id>".
    """

    def __init__(self, code, request_id, message=None):
        # Exception.__init__ is deliberately not called: args keeps the raw
        # (code, request_id, message) tuple, which also keeps pickling working.
        self.code, self.request_id = code, request_id
        self.message = f"{message}, code:{code},request_id:{request_id}"

    def __str__(self):
        return self.message

class VikingDBMemoryService(Service):
    """Client for the VikingDB memory (记忆库) HTTP API, built on volcengine's ``Service``.

    The class is a process-wide singleton: ``__new__`` always returns the same
    instance (created under a lock), but ``__init__`` still runs on every
    construction, so the most recent constructor arguments win.

    Request signing and the ``json``/``get`` transport helpers are inherited
    from the volcengine ``Service`` base class.
    """

    _instance_lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # The unlocked check avoids taking the lock once the instance exists;
        # the second check under the lock prevents two threads both creating it.
        if not hasattr(VikingDBMemoryService, "_instance"):
            with VikingDBMemoryService._instance_lock:
                if not hasattr(VikingDBMemoryService, "_instance"):
                    VikingDBMemoryService._instance = object.__new__(cls)
        return VikingDBMemoryService._instance

    def __init__(self, host="api-knowledgebase.mlp.cn-beijing.volces.com", region="cn-beijing", ak="", sk="", sts_token="", scheme='https',
                 connection_timeout=30, socket_timeout=30):
        """Configure endpoint and credentials, then verify connectivity with a Ping.

        Args:
            host: API host name.
            region: Region used for request signing.
            ak, sk: Access key / secret key; empty values are not applied.
            sts_token: Optional session token for temporary credentials.
            scheme: "https" or "http".
            connection_timeout, socket_timeout: Request timeouts in seconds.

        Raises:
            VikingDBMemoryException: code 1000028 if the Ping request fails. The
                underlying error text is included in the message.
        """
        self.service_info = VikingDBMemoryService.get_service_info(host, region, scheme, connection_timeout, socket_timeout)
        self.api_info = VikingDBMemoryService.get_api_info()
        super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info)
        if ak:
            self.set_ak(ak)
        if sk:
            self.set_sk(sk)
        if sts_token:
            self.set_session_token(session_token=sts_token)
        try:
            self.get_body("Ping", {}, json.dumps({}))
        except Exception as e:
            # Keep the underlying error in the message: a failed Ping is not
            # always a wrong host/region, and the cause is needed to tell.
            raise VikingDBMemoryException(
                1000028, "missed", "host or region is incorrect: {}".format(str(e))
            ) from e

    def setHeader(self, header):
        """Apply extra HTTP headers to every API.

        A fresh API table is built, so headers set by earlier calls are
        discarded. Every key/value of ``header`` is then copied into each entry.
        """
        api_info = VikingDBMemoryService.get_api_info()
        for info in api_info.values():
            for name in header:
                info.header[name] = header[name]
        self.api_info = api_info

    @staticmethod
    def get_service_info(host, region, scheme, connection_timeout, socket_timeout):
        """Build the volcengine ServiceInfo (host header, signing credentials, timeouts)."""
        # Credentials start empty and are filled in later via set_ak/set_sk.
        # 'air' is passed as the signing service name.
        service_info = ServiceInfo(host, {"Host": host},
                                   Credentials('', '', 'air', region), connection_timeout, socket_timeout,
                                   scheme=scheme)
        return service_info

    @staticmethod
    def get_api_info():
        """Return a fresh mapping of API name -> ApiInfo for every supported endpoint.

        Each entry gets its own header dict, so ``setHeader`` can mutate one
        entry's headers without the entries sharing state.
        """
        def json_api(method, path):
            return ApiInfo(method, path, {}, {},
                           {'Accept': 'application/json', 'Content-Type': 'application/json'})

        return {
            "CreateCollection": json_api("POST", "/api/memory/collection/create"),
            "GetCollection": json_api("POST", "/api/memory/collection/info"),
            "DropCollection": json_api("POST", "/api/memory/collection/delete"),
            "UpdateCollection": json_api("POST", "/api/memory/collection/update"),

            "SearchMemory": json_api("POST", "/api/memory/search"),
            "AddSession": json_api("POST", "/api/memory/session/add"),

            "Ping": json_api("GET", "/api/memory/ping"),
        }

    def get_body(self, api, params, body):
        """Send a signed request for ``api`` with a raw body.

        Returns:
            The response JSON, re-serialized as a string.

        Raises:
            Exception: "no such api" for an unknown API name, or with the
                UTF-8-encoded response text as its only argument when the
                HTTP status is not 200.
        """
        if api not in self.api_info:
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers['Content-Type'] = 'application/json'
        r.headers['Traffic-Source'] = 'SDK'
        r.body = body

        SignerV4.sign(r, self.service_info.credentials)

        url = r.build()
        # Send with the method declared for this API (the one the request was
        # signed for). Always using GET would break POST endpoints.
        resp = self.session.request(api_info.method, url, headers=r.headers, data=r.body,
                                    timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return json.dumps(resp.json())
        raise Exception(resp.text.encode("utf-8"))

    @staticmethod
    def _error_from_response(e):
        """Translate an exception carrying a JSON error body into a VikingDBMemoryException.

        ``e.args[0]`` may be bytes (as raised by get_body) or str. If it does
        not decode to a JSON object, a generic 1000028 error embedding
        ``str(e)`` is returned instead.
        """
        payload = e.args[0] if e.args else None
        try:
            if isinstance(payload, bytes):
                payload = payload.decode("utf-8")
            res_json = json.loads(payload)
        except (TypeError, ValueError):  # ValueError covers JSON and Unicode decode errors
            res_json = None
        if not isinstance(res_json, dict):
            return VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e)))
        return VikingDBMemoryException(
            res_json.get("code", 1000028),
            res_json.get("request_id", "missed"),
            res_json.get("message", None),
        )

    def get_body_exception(self, api, params, body):
        """Like ``get_body`` but raises VikingDBMemoryException on any failure or empty response."""
        try:
            res = self.get_body(api, params, body)
        except Exception as e:
            raise self._error_from_response(e) from e
        if res == '':
            raise VikingDBMemoryException(1000028, "missed",
                                          "empty response due to unknown error, please contact customer service")
        return res

    def get_exception(self, api, params):
        """Call the inherited ``Service.get`` and raise VikingDBMemoryException on any failure or empty response."""
        try:
            res = self.get(api, params)
        except Exception as e:
            raise self._error_from_response(e) from e
        if res == '':
            raise VikingDBMemoryException(1000028, "missed",
                                          "empty response due to unknown error, please contact customer service")
        return res

    @staticmethod
    def _list_or_empty(value):
        """Return ``value``, or a new empty list when it is None (avoids shared mutable defaults)."""
        return [] if value is None else value

    def create_collection(self, collection_name, description="", custom_event_type_schemas=None, custom_profile_type_schemas=None, builtin_event_types=None, builtin_profile_types=None):
        """Create a memory collection and return the decoded JSON response.

        The schema/type list arguments default to empty lists.
        """
        params = {
            "CollectionName": collection_name, "Description": description,
            "CustomEventTypeSchemas": self._list_or_empty(custom_event_type_schemas),
            "CustomProfileTypeSchemas": self._list_or_empty(custom_profile_type_schemas),
            "BuiltinEventTypes": self._list_or_empty(builtin_event_types),
            "BuiltinProfileTypes": self._list_or_empty(builtin_profile_types),
        }
        res = self.json("CreateCollection", {}, json.dumps(params))
        return json.loads(res)

    def get_collection(self, collection_name):
        """Return the decoded JSON description of ``collection_name``."""
        params = {"CollectionName": collection_name}
        res = self.json("GetCollection", {}, json.dumps(params))
        return json.loads(res)

    def drop_collection(self, collection_name):
        """Delete ``collection_name`` and return the decoded JSON response."""
        params = {"CollectionName": collection_name}
        res = self.json("DropCollection", {}, json.dumps(params))
        return json.loads(res)

    def update_collection(self, collection_name, custom_event_type_schemas=None, custom_profile_type_schemas=None, builtin_event_types=None, builtin_profile_types=None):
        """Update a collection's memory type configuration and return the decoded JSON response.

        The schema/type list arguments default to empty lists.
        """
        params = {
            "CollectionName": collection_name,
            "CustomEventTypeSchemas": self._list_or_empty(custom_event_type_schemas),
            "CustomProfileTypeSchemas": self._list_or_empty(custom_profile_type_schemas),
            "BuiltinEventTypes": self._list_or_empty(builtin_event_types),
            "BuiltinProfileTypes": self._list_or_empty(builtin_profile_types),
        }
        res = self.json("UpdateCollection", {}, json.dumps(params))
        return json.loads(res)

    def search_memory(self, collection_name, query, filter, limit=10):
        """Search memories in ``collection_name`` and return the decoded JSON response.

        Args:
            filter: Filter dict sent verbatim to the API, e.g.
                {"user_id": [...], "memory_type": [...]}. The parameter name
                shadows the builtin but is kept for caller compatibility.
            limit: Maximum number of results to request.
        """
        params = {
            "collection_name": collection_name,
            "query": query,
            "limit": limit,
            "filter": filter,
        }
        res = self.json("SearchMemory", {}, json.dumps(params))
        return json.loads(res)

    def add_session(self, collection_name, session_id, messages, metadata, entities=None):
        """Write one conversation session into ``collection_name`` and return the decoded JSON response.

        ``entities`` is only sent when provided.
        """
        params = {
            "collection_name": collection_name,
            "session_id": session_id,
            "messages": messages,
            "metadata": metadata,
        }
        if entities is not None:
            params["entities"] = entities
        res = self.json("AddSession", {}, json.dumps(params))
        return json.loads(res)

进行一个完整话题的对话

在记忆库准备好后,我们先模拟一段包含两轮的完整对话。对话结束后,把这段对话的历史消息写入记忆库。然后开启一个新话题,提出与刚才相关的问题,AI 就能利用刚写入的记忆来回答。注意:首次写入后需要 3–5 分钟建立索引,在此期间检索会返回错误码 1000023,示例代码会每隔 60 秒自动重试。
step2_first_conversation.py

import json
import os
import time
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
from step1_memory_service import VikingDBMemoryService

def initialize_services():
    """Create the memory-store client and the Ark LLM client from environment credentials.

    Variables are loaded from a .env file first (via load_dotenv), then read
    from the environment.

    Returns:
        (memory_service, llm_client) tuple.

    Raises:
        ValueError: if VOLC_ACCESSKEY, VOLC_SECRETKEY or ARK_API_KEY is missing or empty.
    """
    load_dotenv()
    env = os.environ
    ak, sk, ark_api_key = (env.get(name) for name in ("VOLC_ACCESSKEY", "VOLC_SECRETKEY", "ARK_API_KEY"))

    if not (ak and sk and ark_api_key):
        raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 ARK_API_KEY。")

    return (
        VikingDBMemoryService(ak=ak, sk=sk),
        Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=ark_api_key),
    )

def ensure_collection_exists(memory_service, collection_name):
    """Make sure the memory collection exists, creating it if the service reports it missing.

    A lookup error whose text contains "collection not exist" triggers
    creation with the built-in event/profile memory types. Any other lookup
    error, and any creation error, is printed and re-raised.
    """
    try:
        memory_service.get_collection(collection_name)
    except Exception as lookup_error:
        if "collection not exist" not in str(lookup_error):
            print(f"检查集合时出错: {lookup_error}")
            raise

        print(f"记忆集合 '{collection_name}' 未找到,正在创建...")
        try:
            memory_service.create_collection(
                collection_name=collection_name,
                description="中文情感陪伴场景测试",
                builtin_event_types=["sys_event_v1"],
                builtin_profile_types=["sys_profile_v1"]
            )
        except Exception as create_error:
            print(f"创建集合失败: {create_error}")
            raise
        print(f"记忆集合 '{collection_name}' 创建成功。")
        # Informational only: nothing here waits for the collection.
        print("等待集合准备就绪...")
    else:
        print(f"记忆集合 '{collection_name}' 已存在。")

def search_relevant_memories(memory_service, collection_name, user_id, query, max_retries=10, retry_interval=60):
    """Search memories relevant to the user's query, retrying while the index is being built.

    Args:
        memory_service: Object exposing
            ``search_memory(collection_name=..., query=..., filter=..., limit=...)``
            that returns a decoded JSON dict.
        collection_name: Memory collection to search.
        user_id: Only memories belonging to this user are requested.
        query: Free-text query (typically the user's latest message).
        max_retries: How many times to retry after an "index is building"
            error (code 1000023). ``None`` retries forever. The default of
            10 x 60 s comfortably covers the 3-5 minute initial indexing window.
        retry_interval: Seconds to sleep between retries.

    Returns:
        A list of ``{'memory_info': ..., 'score': ...}`` dicts, possibly empty.
        Errors are printed and yield an empty list rather than being raised.
    """
    print(f"正在搜索与 '{query}' 相关的记忆...")
    filter_params = {
        "user_id": [user_id],
        "memory_type": ["sys_event_v1", "sys_profile_v1"]
    }

    retry_attempt = 0
    while True:
        # Only the remote call is guarded, so that parsing problems below are
        # not misreported as search failures.
        try:
            response = memory_service.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=3
            )
            break
        except Exception as e:
            if "1000023" not in str(e):
                print(f"搜索记忆时出错 (不可重试): {e}")
                return []
            if max_retries is not None and retry_attempt >= max_retries:
                print(f"记忆索引仍在构建中,已重试 {retry_attempt} 次,放弃本次搜索。")
                return []
            retry_attempt += 1
            print(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_interval)

    memories = []
    data = response.get('data') or {}
    if (data.get('count') or 0) > 0:
        for result in data.get('result_list') or []:
            if result.get('memory_info'):
                memories.append({
                    'memory_info': result['memory_info'],
                    # A hit without a score should not discard all other hits.
                    'score': result.get('score') or 0.0,
                })

    if memories:
        if retry_attempt > 0:
            print("重试后搜索成功。")
        print(f"找到 {len(memories)} 条相关记忆:")
        for i, memory in enumerate(memories, 1):
            print(f"  {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
    else:
        print("未找到相关记忆。")
    return memories

def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message, conversation_history):
    """Run one user turn: recall related memories, ask the LLM, and record the exchange.

    Retrieved memories are appended to the system prompt. The user message
    and the assistant reply are appended to ``conversation_history`` in
    place. If the LLM call fails, a fixed apology is used as the reply.

    Returns:
        The assistant's reply text.
    """
    print("\n" + "=" * 60)
    print(f"用户: {user_message}")

    recalled = search_relevant_memories(memory_service, collection_name, user_id, user_message)

    prompt_parts = ["你是一个富有同情心、善于倾听的AI伙伴,拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"]
    if recalled:
        bullet_lines = "\n".join(
            "- " + json.dumps(item['memory_info'], ensure_ascii=False) for item in recalled
        )
        prompt_parts.append(f"\n\n这是我们过去的一些对话记忆,请参考:\n{bullet_lines}\n\n请利用这些信息来更好地理解和回应用户。")
    system_prompt = "".join(prompt_parts)

    print("AI正在思考...")

    try:
        request_messages = (
            [{"role": "system", "content": system_prompt}]
            + conversation_history
            + [{"role": "user", "content": user_message}]
        )
        completion = llm_client.chat.completions.create(
            model="doubao-seed-1-6-flash-250715",
            messages=request_messages
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"LLM调用失败: {e}")
        reply = "抱歉,我现在有点混乱,无法回应。我们可以稍后再聊吗?"

    print(f"伙伴: {reply}")

    # NOTE(review): the fallback apology is recorded too, so it will be
    # archived with the rest of the history; consider skipping it on failure.
    conversation_history.append({"role": "user", "content": user_message})
    conversation_history.append({"role": "assistant", "content": reply})
    return reply

def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name):
    """Write a finished conversation into the memory collection as one session.

    The session id is ``"<topic_name>_<unix seconds>"``. The metadata carries
    the default user/assistant ids and the same instant in milliseconds.

    Returns:
        True if ``add_session`` succeeded. False if the history is empty or
        the call raised (the error is printed).
    """
    if not conversation_history:
        print("没有对话可以归档。")
        return False

    print(f"\n正在归档关于 '{topic_name}' 的对话...")
    # Read the clock once and derive seconds from milliseconds, so the
    # session id and the metadata timestamp always describe the same instant.
    now_ms = int(time.time() * 1000)
    session_id = f"{topic_name}_{now_ms // 1000}"
    metadata = {
        "default_user_id": user_id,
        "default_assistant_id": assistant_id,
        "time": now_ms,
    }

    try:
        memory_service.add_session(
            collection_name=collection_name,
            session_id=session_id,
            messages=conversation_history,
            metadata=metadata
        )
    except Exception as e:
        print(f"归档对话失败: {e}")
        return False

    print(f"对话已成功归档,会话ID: {session_id}")
    # Informational only: this function does not block on indexing.
    print("正在等待记忆索引更新...")
    return True

def main():
    """Run the end-to-end demo: chat, archive the chat to memory, then verify recall in a new topic."""
    print("开始端到端记忆测试...")

    collection_name = "emotional_support"
    user_id = "xiaoming"
    assistant_id = "assistant"

    try:
        memory_service, llm_client = initialize_services()
        ensure_collection_exists(memory_service, collection_name)
    except Exception as e:
        print(f"初始化失败: {e}")
        return

    def chat(text, history):
        # Every turn shares the same services, collection and user.
        handle_conversation_turn(memory_service, llm_client, collection_name, user_id, text, history)

    print("\n--- 阶段 1: 初始对话 ---")
    first_topic_history = []
    for text in ("你好,我是小明,今年18岁,马上就要高考了。", "家里人的期待好高。"):
        chat(text, first_topic_history)

    print("\n--- 阶段 2: 归档记忆 ---")
    archive_conversation(
        memory_service, collection_name, user_id, assistant_id,
        first_topic_history, "study_stress_discussion"
    )

    print("\n--- 阶段 3: 验证记忆 ---")
    # A fresh history: any recall here must come from the memory store.
    chat("我最近很焦虑,不知道该怎么办。", [])

    print("\n端到端记忆测试完成!")


if __name__ == "__main__":
    main()