AI Conversational Memory Systems: From Principles to Implementation

What this guide covers

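This is a complete tutorial on AI memory systems, from the underlying principles to working code, walking you step by step through building an AI assistant that "remembers". Whether you are building a simple customer-service bot or a complex agent system, memory is one of its core capabilities.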

What is a memory system?

Why does AI need memory?

A plain LLM call looks like this:

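User: Hello
Assistant: Hello!

User: I bought a phone yesterday
Assistant: OK

User: When was it shipped?
Assistant: Sorry, I don't know anything about you buying a phone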

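Every request is independent: LLM APIs are stateless and see only the messages sent with the current request, so the AI does not remember anything said earlier. This is the "no memory" problem.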

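With a memory system:

User: Hello
Assistant: Hello!

User: I bought a phone yesterday, order number 12345
Assistant: Got it, I've noted that (stored in memory)

User: When was it shipped?
Assistant: Let me check (recalls order 12345 from memory, then looks it up)... It shipped at 9 a.m. today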
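The difference comes down to what is sent with each request. Below is a minimal Python sketch, assuming a stand-in `fake_model` function in place of a real LLM API (hypothetical, for illustration only): the stateless call sends only the current message, while `ChatWithMemory` resends the accumulated history on every turn.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]
# A "model" here is any function mapping a list of chat messages to a reply.
# It stands in for a real LLM API call (assumption for illustration).
Model = Callable[[List[Message]], str]


def stateless_turn(model: Model, user_text: str) -> str:
    """Send only the current message: the model sees no earlier turns."""
    return model([{"role": "user", "content": user_text}])


class ChatWithMemory:
    """Keep the running history and resend all of it on every turn."""

    def __init__(self, model: Model):
        self.model = model
        self.history: List[Message] = []

    def turn(self, user_text: str) -> str:
        self.history.append({"role": "user", "content": user_text})
        reply = self.model(self.history)
        self.history.append({"role": "assistant", "content": reply})
        return reply


def fake_model(messages: List[Message]) -> str:
    """Toy model: reports whether an order number appears anywhere in its input."""
    seen = " ".join(m["content"] for m in messages)
    return "found order 12345" if "12345" in seen else "no order info"


chat = ChatWithMemory(fake_model)
chat.turn("I bought a phone yesterday, order number 12345")
print(chat.turn("When was it shipped?"))                   # found order 12345
print(stateless_turn(fake_model, "When was it shipped?"))  # no order info
```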

Layers of a memory system

A well-designed memory system usually has several layers:

┌──────────────────────────────────────────────────────────────────────┐
│                            Context Window                            │
│  User input → short-term memory → long-term memory → LLM processing  │
└──────────────────────────────────────────────────────────────────────┘

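What each layer means:
- Context window: the number of tokens the model can see in one request (fixed by the model)
- Short-term memory: the conversation history of the current session (bounded by the context window)
- Long-term memory: persistent information across sessions (bounded only by storage)
- Knowledge graph: structured relationships between entities (bounded only by storage)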

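Memory types in detail

Short-term Memory

Short-term memory stores the conversation history of the current session and is the most basic form of memory.

Characteristics:

  • Held in process memory; lost when the session ends
  • Limited capacity (bounded by the token limit)
  • Fastest to access
  • Simplest to implement

Implementation approaches:

  • Buffer Memory: a fixed window that keeps only the most recent N messages
  • Sliding Window: a window sized by token count rather than message count, dropping the oldest messages as needed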
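In Python, the fixed-window behaviour maps directly onto `collections.deque` with `maxlen`, which discards the oldest item on overflow; the sliding window instead drops messages until a token estimate fits. A minimal sketch, where the 4-characters-per-token estimate is a rough assumption for English text, not a real tokenizer:

```python
from collections import deque
from typing import Deque, List, Tuple


def make_buffer(window_size: int) -> Deque[Tuple[str, str]]:
    # deque(maxlen=N) silently drops the oldest item when an append exceeds N
    return deque(maxlen=window_size)


def trim_by_tokens(messages: List[Tuple[str, str]], max_tokens: int) -> List[Tuple[str, str]]:
    """Sliding window: drop the oldest messages until the rough token estimate fits.

    May return an empty list if even the newest message exceeds the budget.
    """
    def estimate(text: str) -> int:
        # crude heuristic: about 4 characters per token for English text
        return max(1, len(text) // 4)

    kept = list(messages)
    while kept and sum(estimate(content) for _, content in kept) > max_tokens:
        kept.pop(0)
    return kept


buf = make_buffer(3)
for i in range(5):
    buf.append(("user", f"message {i}"))
print([content for _, content in buf])  # ['message 2', 'message 3', 'message 4']
```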

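Long-term Memory

Long-term memory is information persisted across sessions, letting the AI remember user preferences and past interactions.

Characteristics:

  • Persistent storage (database/files)
  • Large capacity, bounded only by storage
  • Must be retrieved before it can be used
  • Relatively complex to implement

Implementation approaches:

  • Vector storage: text is converted to embedding vectors and stored; retrieval ranks entries by similarity to the query
  • Key-value storage: structured data retrieved by exact key match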
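The rest of this guide implements the vector approach; for contrast, here is a minimal sketch of key-value long-term memory using Python's built-in `sqlite3` module. The table layout and the `KeyValueMemory` name are illustrative assumptions, not a prescribed schema.

```python
import sqlite3
from typing import Optional


class KeyValueMemory:
    """Exact-match long-term memory: one value per (user_id, key)."""

    def __init__(self, path: str = ":memory:"):
        # ":memory:" keeps the demo self-contained; pass a file path to persist
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS memory ("
            " user_id TEXT NOT NULL, key TEXT NOT NULL, value TEXT NOT NULL,"
            " PRIMARY KEY (user_id, key))"
        )

    def put(self, user_id: str, key: str, value: str) -> None:
        # INSERT OR REPLACE overwrites an existing (user_id, key) row;
        # "with self.conn" commits the transaction
        with self.conn:
            self.conn.execute(
                "INSERT OR REPLACE INTO memory (user_id, key, value) VALUES (?, ?, ?)",
                (user_id, key, value),
            )

    def get(self, user_id: str, key: str) -> Optional[str]:
        row = self.conn.execute(
            "SELECT value FROM memory WHERE user_id = ? AND key = ?", (user_id, key)
        ).fetchone()
        return row[0] if row else None


kv = KeyValueMemory()
kv.put("user_123", "preferred_brand", "Huawei")
kv.put("user_123", "preferred_brand", "Apple")  # overwrites the earlier value
print(kv.get("user_123", "preferred_brand"))  # Apple
print(kv.get("user_123", "budget"))           # None
```

Unlike a similarity search, a lookup here either hits the exact key or returns nothing, which suits facts such as preferences or account details.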

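Working Memory

Working memory is a temporary storage area an agent uses while executing a task.

Characteristics:

  • Valid only for the duration of a task
  • Small capacity
  • Used for multi-step reasoning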
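The guide does not implement working memory separately; one possible shape is a per-task scratchpad object like the sketch below. The `WorkingMemory` class, its `max_steps` cap, and the rendering format are hypothetical choices for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class WorkingMemory:
    """Scratchpad for one task; discard it when the task finishes."""
    goal: str
    steps: List[str] = field(default_factory=list)           # reasoning/actions so far
    variables: Dict[str, Any] = field(default_factory=dict)  # intermediate results
    max_steps: int = 5  # small capacity: the oldest steps are dropped

    def record(self, step: str) -> None:
        self.steps.append(step)
        if len(self.steps) > self.max_steps:
            self.steps.pop(0)

    def render(self) -> str:
        """Text fed back to the model before its next reasoning step."""
        lines = [f"Goal: {self.goal}"]
        lines += [f"Step {i + 1}: {s}" for i, s in enumerate(self.steps)]
        lines += [f"{k} = {v}" for k, v in self.variables.items()]
        return "\n".join(lines)


wm = WorkingMemory(goal="Check the shipping status of order 12345")
wm.record("Look up order 12345 in the order system")
wm.variables["status"] = "shipped"
print(wm.render())
```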

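Episodic Memory

Episodic memory records events in the interaction history between the user and the AI.

Characteristics:

  • Organized as a time series
  • Specific events can be retrieved
  • Supports "recalling" past events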
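Because episodes are organized by time, keeping them sorted turns a question like "what happened last week?" into a range lookup. A minimal sketch using the standard `bisect` module; `Episode` and `EpisodicMemory` are hypothetical names:

```python
import bisect
from dataclasses import dataclass
from datetime import datetime
from typing import List


@dataclass
class Episode:
    timestamp: datetime
    description: str


class EpisodicMemory:
    """Events kept sorted by time so time ranges can be found by binary search."""

    def __init__(self) -> None:
        self._times: List[datetime] = []
        self._episodes: List[Episode] = []

    def record(self, episode: Episode) -> None:
        # insert at the position that keeps both parallel lists sorted by timestamp
        i = bisect.bisect_right(self._times, episode.timestamp)
        self._times.insert(i, episode.timestamp)
        self._episodes.insert(i, episode)

    def between(self, start: datetime, end: datetime) -> List[Episode]:
        """All episodes with start <= timestamp <= end, oldest first."""
        lo = bisect.bisect_left(self._times, start)
        hi = bisect.bisect_right(self._times, end)
        return self._episodes[lo:hi]


mem = EpisodicMemory()
mem.record(Episode(datetime(2024, 1, 15, 10, 0), "User bought a Huawei Mate60"))
mem.record(Episode(datetime(2024, 1, 10, 9, 0), "User asked about phone prices"))
mem.record(Episode(datetime(2024, 1, 20, 14, 0), "User asked about shipping"))
for ep in mem.between(datetime(2024, 1, 12), datetime(2024, 1, 31)):
    print(ep.timestamp.date(), ep.description)
```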

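Short-term memory implementation

Buffer Memory (fixed window)

The simplest memory implementation keeps only a fixed number of the most recent messages: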

from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime

@dataclass
class Message:
    """A single message"""
    role: str  # system/user/assistant
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

class BufferMemory:
    """Buffer Memory: fixed-window short-term memory"""
    
    def __init__(self, window_size: int = 10):
        """
        Initialize the memory
        :param window_size: number of messages to keep
        """
        self.window_size = window_size
        self.messages: List[Message] = []
    
    def add_message(self, role: str, content: str, metadata: Optional[Dict] = None) -> None:
        """Add a message"""
        message = Message(
            role=role,
            content=content,
            metadata=metadata or {}
        )
        self.messages.append(message)
        
        # Keep the window size: drop the oldest message once over the limit
        if len(self.messages) > self.window_size:
            self.messages.pop(0)
    
    def get_messages(self) -> List[Message]:
        """Return a copy of all messages"""
        return self.messages.copy()
    
    def get_context(self) -> str:
        """Return the formatted context"""
        return "\n".join(
            f"{msg.role}: {msg.content}"
            for msg in self.messages
        )
    
    def clear(self) -> None:
        """Clear the memory"""
        self.messages.clear()
    
    def get_token_count(self) -> int:
        """Estimate the token count (rough)"""
        # Rough heuristic: 1 token ≈ 4 characters of English text.
        # It underestimates Chinese text, where a token covers far fewer characters.
        return sum(len(msg.content) // 4 for msg in self.messages)


# Usage example
memory = BufferMemory(window_size=10)

memory.add_message("user", "Hi, I'd like to buy a phone")
memory.add_message("assistant", "Sure, do you have a preferred brand or budget?")
memory.add_message("user", "Huawei, with a budget of around 5000")

print(memory.get_context())
# Output:
# user: Hi, I'd like to buy a phone
# assistant: Sure, do you have a preferred brand or budget?
# user: Huawei, with a budget of around 5000

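Sliding Window Memory

A memory bounded by token count rather than message count. System messages are stored separately and never trimmed, so the system prompt is not lost: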

class SlidingWindowMemory:
    """Sliding-window memory, bounded by token count"""
    
    def __init__(self, max_tokens: int = 4000):
        self.max_tokens = max_tokens
        self.messages: List[Message] = []
        self.system_messages: List[Message] = []  # kept separately and never trimmed
    
    def add_message(self, role: str, content: str) -> None:
        """Add a message and slide the window if needed"""
        message = Message(role=role, content=content)
        
        # Store system messages separately
        if role == "system":
            self.system_messages.append(message)
        else:
            self.messages.append(message)
        
        # Trim to stay within the token limit
        self._trim_to_token_limit()
    
    def _trim_to_token_limit(self) -> None:
        """Drop the oldest user/assistant messages until under the token limit"""
        # System messages are never removed; the newest message is always kept
        while self.get_token_count() > self.max_tokens and len(self.messages) > 1:
            self.messages.pop(0)  # remove the oldest user/assistant message
    
    def get_messages(self) -> List[Message]:
        """Return all messages, system messages first"""
        return self.system_messages + self.messages
    
    def get_context(self) -> str:
        """Return the formatted context"""
        return "\n".join(f"{msg.role}: {msg.content}" for msg in self.get_messages())
    
    def get_token_count(self) -> int:
        """Rough token estimate, including system messages"""
        return sum(len(m.content) // 4 for m in self.get_messages())


# Usage example
memory = SlidingWindowMemory(max_tokens=4000)

memory.add_message("system", "You are a phone sales assistant")
memory.add_message("user", "Hello")
memory.add_message("assistant", "Hello, how can I help you?")
memory.add_message("user", "I'd like to buy a Huawei phone")
memory.add_message("assistant", "Sure, what is your budget?")
memory.add_message("user", "Around 5000")

print(f"Current token count: {memory.get_token_count()}")
print(memory.get_context())

Summarizing Memory

When a conversation gets too long, compress the older history into a summary:

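class SummarizingMemory:
    """Summarizing memory: compresses older history when the conversation gets long"""
    
    def __init__(
        self,
        summary_trigger: int = 15,
        keep_recent: int = 5,
        llm_client=None
    ):
        """
        :param summary_trigger: summarize once the message count reaches this value
        :param keep_recent: number of most recent messages kept verbatim
        :param llm_client: any object with an async chat(prompt) -> str method (assumed interface)
        """
        self.summary_trigger = summary_trigger
        self.keep_recent = keep_recent
        self.llm_client = llm_client
        
        self.messages: List[Message] = []
        self.summary: Optional[str] = None
    
    async def summarize_old_messages(self) -> None:
        """Compress older messages into a summary"""
        if not self.llm_client or len(self.messages) < self.summary_trigger:
            return
        
        # Everything except the most recent keep_recent messages is summarized;
        # an earlier summary message, if present, is included and carried forward
        split = len(self.messages) - self.keep_recent
        old, recent = self.messages[:split], self.messages[split:]
        if not old:
            return
        
        # Build the transcript outside the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12
        transcript = "\n".join(f"{m.role}: {m.content}" for m in old)
        prompt = (
            "Summarize the key points of the following conversation:\n\n"
            f"{transcript}\n\n"
            "Summarize the core content concisely and keep the key information."
        )
        
        response = await self.llm_client.chat(prompt)
        
        # Compress: the summary message plus the most recent messages
        self.summary = response
        summary_msg = Message(role="system", content=f"[Conversation summary] {response}")
        self.messages = [summary_msg] + recent
    
    async def add_message(self, role: str, content: str) -> None:
        """Add a message and summarize once the trigger is reached"""
        self.messages.append(Message(role=role, content=content))
        
        # Awaiting here, instead of calling asyncio.create_task from a sync method,
        # avoids "no running event loop" errors and overlapping summarization tasks
        if len(self.messages) >= self.summary_trigger:
            await self.summarize_old_messages()


# Usage example
# openai_client stands for a thin wrapper exposing async chat(prompt) -> str (assumed)
memory = SummarizingMemory(
    summary_trigger=15,
    keep_recent=5,
    llm_client=openai_client
)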
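The compression step itself does not depend on any particular LLM. The sketch below isolates it as a pure function that takes a `summarize` callable (an assumed stand-in for the LLM call) and carries any earlier summary forward into the new one, so repeated compressions do not lose what was already summarized. `compress` and `fake_summarize` are hypothetical names.

```python
from typing import Callable, List, Optional, Tuple

Turn = Tuple[str, str]  # (role, content)


def compress(
    summary: Optional[str],
    turns: List[Turn],
    summarize: Callable[[str], str],
    keep_recent: int = 5,
) -> Tuple[Optional[str], List[Turn]]:
    """Fold all but the last keep_recent turns (plus any earlier summary) into a new summary."""
    if len(turns) <= keep_recent:
        return summary, turns  # nothing old enough to fold yet
    split = len(turns) - keep_recent
    old, recent = turns[:split], turns[split:]
    transcript = "\n".join(f"{role}: {content}" for role, content in old)
    if summary:
        # carry the earlier summary forward so previously compressed facts survive
        transcript = f"Previous summary: {summary}\n{transcript}"
    return summarize(transcript), recent


def fake_summarize(text: str) -> str:
    # stand-in for an LLM call: just reports how many lines it was given
    return f"summary of {len(text.splitlines())} lines"


turns = [("user", f"question {i}") for i in range(8)]
summary, recent = compress(None, turns, fake_summarize, keep_recent=5)
print(summary)      # summary of 3 lines
print(len(recent))  # 5
```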

Long-term memory implementation

Vector store architecture

Vector storage is the core technique behind long-term memory:

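┌──────────────────────────────────────────────────────────────────┐
│                            Write path                            │
│                                                                  │
│  New dialogue → Extract info → Embed → Store in vector DB        │
│                         ↓                                        │
│                   Stored per entry:                              │
│                   - Original text                                │
│                   - Vector                                       │
│                   - Metadata (user, time, etc.)                  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│                            Read path                             │
│                                                                  │
│  User query → Embed query → Similarity search → Recall memories  │
│                         ↓                                        │
│                   Recalled results:                              │
│                   - Relevant past dialogue                       │
│                   - User preferences                             │
│                   - Past events                                  │
└──────────────────────────────────────────────────────────────────┘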
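To see both paths end to end without calling an embedding API, here is a toy sketch. `toy_embed` uses the hashing trick on a bag of words; it is a stand-in assumed only for illustration and captures word overlap, not meaning the way a real embedding model does. Because the vectors are L2-normalised, the dot product equals cosine similarity. `ToyMemoryStore` is a hypothetical name.

```python
import hashlib
import math
from typing import Dict, List, Tuple

DIM = 256  # tiny dimension for illustration; real embedding models use far more


def toy_embed(text: str) -> List[float]:
    """Hashing-trick bag-of-words vector, L2-normalised (not semantic)."""
    vec = [0.0] * DIM
    for word in text.lower().split():
        bucket = int(hashlib.md5(word.encode()).hexdigest(), 16) % DIM
        vec[bucket] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


class ToyMemoryStore:
    def __init__(self) -> None:
        self.items: List[Tuple[List[float], Dict[str, str]]] = []

    def write(self, text: str, user_id: str) -> None:
        # write path: embed, then store the vector together with text and metadata
        self.items.append((toy_embed(text), {"text": text, "user_id": user_id}))

    def read(self, query: str, user_id: str, top_k: int = 2) -> List[Tuple[float, str]]:
        # read path: embed the query, filter by metadata, rank by cosine similarity
        q = toy_embed(query)
        scored = [
            (sum(a * b for a, b in zip(q, vec)), meta["text"])
            for vec, meta in self.items
            if meta["user_id"] == user_id
        ]
        scored.sort(reverse=True)
        return scored[:top_k]


store = ToyMemoryStore()
store.write("user wants a huawei phone with a budget of 5000", "user_123")
store.write("user prefers dark mode in the app", "user_123")
store.write("user wants a huawei phone", "user_456")
top = store.read("huawei phone budget 5000", "user_123")
print(top[0][1])
```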

Vector store implementation

from typing import List, Dict, Any, Optional, Tuple
import asyncio
import hashlib
import json
from abc import ABC, abstractmethod
from datetime import datetime

import numpy as np

class VectorStore(ABC):
    """Abstract base class for vector stores"""
    
    @abstractmethod
    async def add(
        self,
        id: str,
        embedding: List[float],
        metadata: Dict[str, Any]
    ) -> None:
        """Add a vector"""
        pass
    
    @abstractmethod
    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter: Optional[Dict] = None
    ) -> List[Dict]:
        """Search for similar vectors"""
        pass
    
    @abstractmethod
    async def delete(self, id: str) -> None:
        """Delete a vector"""
        pass

class InMemoryVectorStore(VectorStore):
    """In-memory vector store (for development)"""
    
    def __init__(self, dimension: int = 1536):
        self.dimension = dimension
        self.vectors: Dict[str, np.ndarray] = {}
        self.metadata: Dict[str, Dict] = {}
    
    async def add(
        self,
        id: str,
        embedding: List[float],
        metadata: Dict[str, Any]
    ) -> None:
        if len(embedding) != self.dimension:
            raise ValueError(f"expected {self.dimension} dimensions, got {len(embedding)}")
        self.vectors[id] = np.array(embedding)
        self.metadata[id] = metadata
    
    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter: Optional[Dict] = None
    ) -> List[Dict]:
        if not self.vectors:
            return []
        
        query_vec = np.array(query_embedding)
        
        results = []
        for id, vec in self.vectors.items():
            # Metadata filter: every key must match exactly
            if filter:
                if not all(
                    self.metadata[id].get(k) == v
                    for k, v in filter.items()
                ):
                    continue
            
            # Cosine similarity
            similarity = self._cosine_similarity(query_vec, vec)
            results.append({
                'id': id,
                'score': float(similarity),
                'metadata': self.metadata[id]
            })
        
        # Sort and return the top k
        results.sort(key=lambda x: x['score'], reverse=True)
        return results[:top_k]
    
    async def delete(self, id: str) -> None:
        self.vectors.pop(id, None)
        self.metadata.pop(id, None)
    
    @staticmethod
    def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity (the small epsilon avoids division by zero)"""
        dot = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        return dot / (norm_a * norm_b + 1e-8)


class PineconeVectorStore(VectorStore):
    """Pinecone vector store (production)

    Written for the Pinecone Python SDK v3 and later (the Pinecone class);
    the older pinecone.init(api_key=..., environment=...) call no longer exists there.
    The SDK's index methods are synchronous, so they run in a worker thread
    via asyncio.to_thread instead of being awaited directly.
    """
    
    def __init__(self, api_key: str, index_name: str):
        from pinecone import Pinecone
        self.pc = Pinecone(api_key=api_key)
        self.index = self.pc.Index(index_name)
    
    async def add(
        self,
        id: str,
        embedding: List[float],
        metadata: Dict[str, Any]
    ) -> None:
        await asyncio.to_thread(self.index.upsert, vectors=[(id, embedding, metadata)])
    
    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        filter: Optional[Dict] = None
    ) -> List[Dict]:
        result = await asyncio.to_thread(
            self.index.query,
            vector=query_embedding,
            top_k=top_k,
            filter=filter,
            include_metadata=True
        )
        
        return [
            {
                'id': match.id,
                'score': match.score,
                'metadata': match.metadata or {}
            }
            for match in result.matches
        ]
    
    async def delete(self, id: str) -> None:
        await asyncio.to_thread(self.index.delete, ids=[id])


# Usage example
async def example():
    # Create the vector store
    store = InMemoryVectorStore(dimension=1536)
    
    # Add a memory
    await store.add(
        id="memory_001",
        embedding=[0.1] * 1536,  # placeholder; use a real embedding in practice
        metadata={
            'user_id': 'user_123',
            'content': 'The user mentioned wanting to buy a Huawei phone',
            'timestamp': datetime.now().isoformat()
        }
    )
    
    # Search for related memories
    results = await store.search(
        query_embedding=[0.1] * 1536,
        top_k=3,
        filter={'user_id': 'user_123'}
    )
    
    print(results)

asyncio.run(example())
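`InMemoryVectorStore.search` scores vectors one at a time in a Python loop, which is fine for development. With NumPy the same cosine ranking can be computed in one matrix-vector product; a minimal sketch, assuming the stored vectors are stacked into a 2-D array (`top_k_cosine` is a hypothetical helper, not part of the classes above):

```python
from typing import List, Tuple

import numpy as np


def top_k_cosine(query: np.ndarray, matrix: np.ndarray, k: int) -> List[Tuple[int, float]]:
    """Return (row_index, score) for the k rows of matrix most similar to query."""
    if k <= 0 or len(matrix) == 0:
        return []
    # normalise once; a single matrix-vector product then yields every cosine similarity
    m = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)
    q = query / (np.linalg.norm(query) + 1e-8)
    scores = m @ q
    k = min(k, len(scores))
    # argpartition selects the top k without a full sort; only those k are then sorted
    idx = np.argpartition(-scores, k - 1)[:k]
    idx = idx[np.argsort(-scores[idx])]
    return [(int(i), float(scores[i])) for i in idx]


matrix = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(top_k_cosine(np.array([1.0, 0.1]), matrix, k=2))  # rows 0 and 2, most similar first
```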

Memory service wrapper

class MemoryService:
    """Memory service: manages short-term and long-term memory in one place"""
    
    def __init__(
        self,
        vector_store: VectorStore,
        embedding_model: str = "text-embedding-3-small",
        short_term_max: int = 10
    ):
        self.vector_store = vector_store
        self.embedding_model = embedding_model
        # Note: one short-term buffer shared by every caller of this service;
        # a multi-user deployment needs one buffer per (user_id, session_id)
        self.short_term = BufferMemory(window_size=short_term_max)
        self._client = None  # OpenAI client, created lazily
    
    async def store_interaction(
        self,
        user_id: str,
        session_id: str,
        user_message: str,
        assistant_response: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """Store one user/assistant exchange in long-term memory"""
        # Build the content and derive a deterministic ID from it
        content = f"User: {user_message}\nAssistant: {assistant_response}"
        memory_id = self._generate_id(user_id, session_id, content)
        
        # Embed
        embedding = await self._get_embedding(content)
        
        # Store
        await self.vector_store.add(
            id=memory_id,
            embedding=embedding,
            metadata={
                'user_id': user_id,
                'session_id': session_id,
                'user_message': user_message,
                'assistant_response': assistant_response,
                'timestamp': datetime.now().isoformat(),
                **(metadata or {})
            }
        )
        
        return memory_id
    
    async def retrieve_memories(
        self,
        user_id: str,
        query: str,
        top_k: int = 5,
        session_id: Optional[str] = None
    ) -> List[Dict]:
        """Retrieve relevant memories"""
        # Embed the query
        query_embedding = await self._get_embedding(query)
        
        # Build the filter conditions
        filter_conditions = {'user_id': user_id}
        if session_id:
            filter_conditions['session_id'] = session_id
        
        # Search
        results = await self.vector_store.search(
            query_embedding=query_embedding,
            top_k=top_k,
            filter=filter_conditions
        )
        
        return results
    
    async def _get_embedding(self, text: str) -> List[float]:
        """Get the embedding vector for a text"""
        if self._client is None:
            # AsyncOpenAI does not block the event loop; the synchronous
            # OpenAI client would block inside this async method
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        
        response = await self._client.embeddings.create(
            model=self.embedding_model,
            input=text
        )
        
        return response.data[0].embedding
    
    def _generate_id(self, *parts) -> str:
        """Generate a deterministic ID (identical input maps to the same ID)"""
        content = "|".join(str(p) for p in parts)
        return hashlib.md5(content.encode()).hexdigest()[:16]
    
    def add_short_term(self, role: str, content: str) -> None:
        """Add a message to short-term memory"""
        self.short_term.add_message(role, content)
    
    def get_short_term_context(self) -> str:
        """Return the short-term memory context"""
        return self.short_term.get_context()
    
    async def get_full_context(
        self,
        user_id: str,
        query: str,
        short_term: bool = True,
        long_term: bool = True,
        top_k: int = 3
    ) -> Dict[str, Any]:
        """Build the full context"""
        context = {
            'short_term': '',
            'long_term': [],
            'combined': ''
        }
        
        # Short-term memory
        if short_term:
            context['short_term'] = self.get_short_term_context()
        
        # Long-term memory
        if long_term:
            long_term_results = await self.retrieve_memories(
                user_id=user_id,
                query=query,
                top_k=top_k
            )
            context['long_term'] = long_term_results
        
        # Assemble the combined context
        parts = []
        
        if context['short_term']:
            parts.append(f"[Recent conversation]\n{context['short_term']}")
        
        if context['long_term']:
            parts.append("[Related history]")
            for item in context['long_term']:
                meta = item['metadata']
                parts.append(f"- The user said: {meta.get('user_message', '')}")
                parts.append(f"  The assistant replied: {meta.get('assistant_response', '')}")
        
        context['combined'] = "\n\n".join(parts)
        
        return context


# Usage example (requires the openai package and OPENAI_API_KEY in the environment)
import asyncio

async def main():
    # Initialize
    vector_store = InMemoryVectorStore()
    memory_service = MemoryService(
        vector_store=vector_store,
        embedding_model="text-embedding-3-small"
    )
    
    user_id = "user_123"
    session_id = "session_001"
    
    # Add to short-term memory
    memory_service.add_short_term("user", "Hi, I'd like to buy a phone")
    memory_service.add_short_term("assistant", "Sure, do you have a preferred brand?")
    memory_service.add_short_term("user", "Huawei, with a budget of around 5000")
    
    # Store in long-term memory
    await memory_service.store_interaction(
        user_id=user_id,
        session_id=session_id,
        user_message="Hi, I'd like to buy a phone",
        assistant_response="Sure, do you have a preferred brand?",
        metadata={'topic': 'phone inquiry'}
    )
    
    # Build the full context
    context = await memory_service.get_full_context(
        user_id=user_id,
        query="Did I buy my phone?"
    )
    
    print("Short-term memory:")
    print(context['short_term'])
    print("\nLong-term memory:")
    print(context['long_term'])
    print("\nCombined context:")
    print(context['combined'])

asyncio.run(main())
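`MemoryService` holds a single `BufferMemory`, so if one service instance handles several users, their recent messages end up in the same window. A minimal sketch of keeping one bounded buffer per `(user_id, session_id)` pair instead; `SessionShortTermMemory` and its methods are hypothetical names:

```python
from collections import defaultdict, deque
from typing import Deque, Dict, Tuple


class SessionShortTermMemory:
    """One bounded message buffer per (user_id, session_id) instead of one shared buffer."""

    def __init__(self, window_size: int = 10):
        # each new key gets its own deque that keeps only the last window_size messages
        self._buffers: Dict[Tuple[str, str], Deque[Tuple[str, str]]] = defaultdict(
            lambda: deque(maxlen=window_size)
        )

    def add(self, user_id: str, session_id: str, role: str, content: str) -> None:
        self._buffers[(user_id, session_id)].append((role, content))

    def context(self, user_id: str, session_id: str) -> str:
        # .get() avoids creating an empty buffer for unknown sessions
        buf = self._buffers.get((user_id, session_id), ())
        return "\n".join(f"{role}: {content}" for role, content in buf)

    def end_session(self, user_id: str, session_id: str) -> None:
        # short-term memory is discarded when the session ends
        self._buffers.pop((user_id, session_id), None)


stm = SessionShortTermMemory(window_size=2)
stm.add("user_123", "s1", "user", "I want a Huawei phone")
stm.add("user_456", "s9", "user", "Where is my refund?")
stm.add("user_123", "s1", "assistant", "What is your budget?")
stm.add("user_123", "s1", "user", "Around 5000")
print(stm.context("user_123", "s1"))
```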

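Knowledge graph memory

Knowledge graph concepts

A knowledge graph is a structured form of memory that stores entities and the relationships between them: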

┌──────────────────────────────────────────────────┐
│             Example knowledge graph              │
│                                                  │
│                 ┌──────────┐                     │
│                 │  User A  │                     │
│                 └────┬─────┘                     │
│                      │ purchased                 │
│            ┌─────────┴─────────┐                 │
│            ↓                   ↓                 │
│     ┌──────────────┐    ┌─────────────┐          │
│     │ Huawei phone │    │ Apple phone │          │
│     └──────┬───────┘    └─────────────┘          │
│            │ uses                                │
│            ↓                                     │
│     ┌──────────────┐                             │
│     │     EMUI     │                             │
│     └──────────────┘                             │
└──────────────────────────────────────────────────┘
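The same graph can be written down as (subject, relation, object) triples, which is also how many graph stores model data. A minimal sketch with a wildcard pattern query; `match` and the two-hop lookup are illustrative, not any particular database's API:

```python
from typing import List, Optional, Tuple

Triple = Tuple[str, str, str]  # (subject, relation, object)

# The example graph as triples
TRIPLES: List[Triple] = [
    ("User A", "purchased", "Huawei phone"),
    ("User A", "purchased", "Apple phone"),
    ("Huawei phone", "uses", "EMUI"),
]


def match(
    triples: List[Triple],
    s: Optional[str] = None,
    p: Optional[str] = None,
    o: Optional[str] = None,
) -> List[Triple]:
    """Return the triples matching the pattern; None acts as a wildcard."""
    return [
        t for t in triples
        if (s is None or t[0] == s) and (p is None or t[1] == p) and (o is None or t[2] == o)
    ]


# One hop: what did User A purchase?
purchases = [obj for _, _, obj in match(TRIPLES, s="User A", p="purchased")]
# Two hops: which systems run on the phones User A purchased?
systems = [obj for phone in purchases for _, _, obj in match(TRIPLES, s=phone, p="uses")]
print(purchases)  # ['Huawei phone', 'Apple phone']
print(systems)    # ['EMUI']
```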

Knowledge graph implementation

from typing import List, Dict, Any, Optional, Set
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum

class EntityType(Enum):
    """Entity types"""
    PERSON = "person"
    PRODUCT = "product"
    ORGANIZATION = "organization"
    LOCATION = "location"
    EVENT = "event"
    CONCEPT = "concept"

class RelationType(Enum):
    """Relation types"""
    PURCHASED = "purchased"
    USES = "uses"
    WORKS_FOR = "works_for"
    LOCATED_IN = "located_in"
    RELATED_TO = "related_to"
    IS_A = "is_a"

@dataclass
class Entity:
    """An entity"""
    id: str
    type: EntityType
    name: str
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

@dataclass
class Relation:
    """A relation between two entities"""
    id: str
    source_id: str
    target_id: str
    type: RelationType
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

class KnowledgeGraph:
    """Knowledge graph"""
    
    def __init__(self):
        self.entities: Dict[str, Entity] = {}
        self.relations: Dict[str, Relation] = {}
        self.adjacency: Dict[str, Set[str]] = {}  # adjacency list: entity_id -> neighbor IDs
        self.entity_index: Dict[str, List[str]] = {}  # lowercase name -> entity IDs
    
    def add_entity(self, entity: Entity) -> None:
        """Add an entity"""
        is_new = entity.id not in self.entities
        self.entities[entity.id] = entity
        self.adjacency.setdefault(entity.id, set())
        
        # Update the name index once per entity ID, so re-adding does not duplicate it
        if is_new:
            self.entity_index.setdefault(entity.name.lower(), []).append(entity.id)
    
    def add_relation(self, relation: Relation) -> None:
        """Add a relation"""
        self.relations[relation.id] = relation
        
        # Update the adjacency list in both directions (traversal ignores direction)
        self.adjacency.setdefault(relation.source_id, set()).add(relation.target_id)
        self.adjacency.setdefault(relation.target_id, set()).add(relation.source_id)
    
    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID"""
        return self.entities.get(entity_id)
    
    def find_entity_by_name(self, name: str) -> Optional[Entity]:
        """Find an entity by exact, case-insensitive name match"""
        name_key = name.lower()
        entity_ids = self.entity_index.get(name_key, [])
        if entity_ids:
            return self.entities.get(entity_ids[0])
        return None
    
    def get_neighbors(
        self,
        entity_id: str,
        depth: int = 1,
        relation_type: Optional[RelationType] = None
    ) -> List[tuple]:
        """Return (neighbor_id, relation) pairs for entities within `depth` hops"""
        result = []
        visited = {entity_id}  # never report the start entity as its own neighbor
        frontier = [entity_id]
        
        # Breadth-first, one hop per round: depth=1 returns direct neighbors only,
        # and every entity within `depth` hops is found
        for _ in range(depth):
            next_frontier = []
            for current_id in frontier:
                for relation in self._get_relations(current_id):
                    if relation_type and relation.type != relation_type:
                        continue
                    
                    neighbor_id = relation.target_id if relation.source_id == current_id else relation.source_id
                    
                    if neighbor_id not in visited:
                        visited.add(neighbor_id)
                        result.append((neighbor_id, relation))
                        next_frontier.append(neighbor_id)
            frontier = next_frontier
        
        return result
    
    def _get_relations(self, entity_id: str) -> List[Relation]:
        """Return every relation that touches an entity (scans all relations)"""
        return [
            r for r in self.relations.values()
            if r.source_id == entity_id or r.target_id == entity_id
        ]
    
    def query_path(
        self,
        source_id: str,
        target_id: str,
        max_depth: int = 3
    ) -> List[List[str]]:
        """Find all simple paths from source to target with at most max_depth hops"""
        paths = []
        
        def dfs(current: str, path: List[str]):
            if current == target_id:
                paths.append(path.copy())
                return
            # path holds len(path) - 1 hops; stop when no hop budget is left
            if len(path) - 1 >= max_depth:
                return
            
            for neighbor in self.adjacency.get(current, set()):
                if neighbor not in path:  # simple paths only, no cycles
                    path.append(neighbor)
                    dfs(neighbor, path)
                    path.pop()
        
        dfs(source_id, [source_id])
        return paths
    
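    def get_subgraph(
        self,
        entity_id: str,
        depth: int = 2
    ) -> Dict[str, Any]:
        """Return the subgraph within `depth` hops of an entity (the entity included)"""
        visited = {entity_id}
        order = [entity_id]  # discovery order, for stable output
        edges: Dict[str, Relation] = {}  # keyed by relation ID, so each edge appears once
        frontier = [entity_id]
        
        for _ in range(depth):
            next_frontier = []
            for current_id in frontier:
                for relation in self._get_relations(current_id):
                    edges[relation.id] = relation
                    neighbor_id = relation.target_id if relation.source_id == current_id else relation.source_id
                    if neighbor_id not in visited:
                        visited.add(neighbor_id)
                        order.append(neighbor_id)
                        next_frontier.append(neighbor_id)
            frontier = next_frontier
        
        return {
            'entities': [self.entities[i] for i in order if i in self.entities],
            'relations': [
                {'source': r.source_id, 'target': r.target_id, 'type': r.type.value}
                for r in edges.values()
            ]
        }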
 
 
# Usage example
def example():
    kg = KnowledgeGraph()
    
    # Add entities
    user = Entity(
        id="user_001",
        type=EntityType.PERSON,
        name="Zhang San",
        properties={'preference': 'Huawei phones'}
    )
    phone = Entity(
        id="product_001",
        type=EntityType.PRODUCT,
        name="Huawei Mate60",
        properties={'price': 6999, 'brand': 'Huawei'}
    )
    
    kg.add_entity(user)
    kg.add_entity(phone)
    
    # Add a relation
    relation = Relation(
        id="rel_001",
        source_id="user_001",
        target_id="product_001",
        type=RelationType.PURCHASED,
        properties={'date': '2024-01-15', 'price': 6999}
    )
    kg.add_relation(relation)
    
    # Query
    print(kg.get_entity("user_001"))
    # find_entity_by_name needs the full name; a partial name such as "Huawei" returns None
    print(kg.find_entity_by_name("Huawei Mate60"))
    
    # Neighbors
    neighbors = kg.get_neighbors("user_001")
    for neighbor_id, relation in neighbors:
        entity = kg.entities.get(neighbor_id)
        print(f"Related entity: {entity.name}, relation: {relation.type.value}")
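`_get_relations` scans every relation in the graph on each call, so a neighbor lookup costs time proportional to the whole graph rather than to the entity's degree. One common remedy is an index from entity ID to relation IDs, maintained on insert; a minimal standalone sketch with a simplified, hypothetical `Edge` type and `IndexedRelations` class:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Edge:
    id: str
    source_id: str
    target_id: str
    type: str


class IndexedRelations:
    """Relations plus a per-entity index: lookups cost O(degree), not O(all relations)."""

    def __init__(self) -> None:
        self.relations: Dict[str, Edge] = {}
        self.by_entity: Dict[str, List[str]] = defaultdict(list)  # entity_id -> relation IDs

    def add(self, edge: Edge) -> None:
        if edge.id in self.relations:
            return  # ignore duplicates so the index never lists the same ID twice
        self.relations[edge.id] = edge
        self.by_entity[edge.source_id].append(edge.id)
        if edge.target_id != edge.source_id:
            self.by_entity[edge.target_id].append(edge.id)

    def relations_of(self, entity_id: str) -> List[Edge]:
        # .get() on the defaultdict avoids inserting empty lists for unknown IDs
        return [self.relations[rid] for rid in self.by_entity.get(entity_id, [])]


idx = IndexedRelations()
idx.add(Edge("rel_001", "user_001", "product_001", "purchased"))
idx.add(Edge("rel_002", "user_001", "product_002", "purchased"))
idx.add(Edge("rel_001", "user_001", "product_001", "purchased"))  # duplicate, ignored
print([e.target_id for e in idx.relations_of("user_001")])  # ['product_001', 'product_002']
```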

Extracting knowledge from conversations

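import hashlib
import json

class KnowledgeExtractor:
    """Knowledge extractor: pulls entities and relations out of a conversation"""
    
    def __init__(self, llm_client):
        self.llm = llm_client  # any object with an async chat(prompt) -> str method (assumed)
    
    async def extract_from_conversation(
        self,
        conversation: str
    ) -> Dict[str, Any]:
        """Extract knowledge from a conversation"""
        prompt = f"""
Extract the entities and relations from the following conversation:

{conversation}

Return JSON in this format:
{{
    "entities": [
        {{
            "name": "entity name",
            "type": "entity type (person/product/organization/location/event/concept)",
            "properties": {{}}
        }}
    ],
    "relations": [
        {{
            "source": "name of entity 1",
            "target": "name of entity 2",
            "type": "relation type (purchased/uses/works_for/located_in/related_to/is_a)"
        }}
    ]
}}

Entity types: person, product, organization, location, event, concept
Relation types: purchased, uses, works_for, located_in, related_to, is_a
Return only the JSON, with no other text.
"""
        
        response = await self.llm.chat(prompt)
        
        # Parse the JSON response; fall back to an empty result on malformed output
        try:
            data = json.loads(response)
        except json.JSONDecodeError:
            return {"entities": [], "relations": []}
        return data if isinstance(data, dict) else {"entities": [], "relations": []}
    
    async def extract_and_store(
        self,
        conversation: str,
        knowledge_graph: KnowledgeGraph
    ) -> None:
        """Extract knowledge and store it in the knowledge graph"""
        extraction = await self.extract_from_conversation(conversation)
        entity_types = {t.value for t in EntityType}
        relation_types = {t.value for t in RelationType}
        
        # Store entities; skip malformed items and unknown types
        # (EntityType(...) raises ValueError on an unknown value)
        for entity_data in extraction.get("entities") or []:
            if not isinstance(entity_data, dict) or not entity_data.get("name"):
                continue
            if entity_data.get("type") not in entity_types:
                continue
            entity_id = self._generate_entity_id(entity_data["name"])
            if knowledge_graph.get_entity(entity_id) is not None:
                continue  # already known
            entity = Entity(
                id=entity_id,
                type=EntityType(entity_data["type"]),
                name=entity_data["name"],
                properties=entity_data.get("properties") or {}
            )
            knowledge_graph.add_entity(entity)
        
        # Store relations whose endpoints are both known entities
        for relation_data in extraction.get("relations") or []:
            if not isinstance(relation_data, dict) or relation_data.get("type") not in relation_types:
                continue
            source = knowledge_graph.find_entity_by_name(str(relation_data.get("source", "")))
            target = knowledge_graph.find_entity_by_name(str(relation_data.get("target", "")))
            
            if source and target:
                relation = Relation(
                    id=self._generate_relation_id(source.id, target.id, relation_data["type"]),
                    source_id=source.id,
                    target_id=target.id,
                    type=RelationType(relation_data["type"])
                )
                knowledge_graph.add_relation(relation)
    
    def _generate_entity_id(self, name: str) -> str:
        return hashlib.md5(f"entity:{name}".encode()).hexdigest()[:16]
    
    def _generate_relation_id(self, source: str, target: str, rel_type: str) -> str:
        # Include the type so different relations between the same pair get distinct IDs
        return hashlib.md5(f"rel:{source}:{target}:{rel_type}".encode()).hexdigest()[:16]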
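Model output is not guaranteed to be bare JSON: it may be wrapped in a Markdown code fence or contain types outside the allowed lists, and `EntityType(...)`/`RelationType(...)` raise `ValueError` on unknown values. A defensive parsing sketch that could run on the raw response before anything is stored; the fence-stripping regex and the filtering rules are assumptions about typical model output, and `parse_extraction` is a hypothetical name:

```python
import json
import re
from typing import Any, Dict

FENCE = "`" * 3  # a Markdown code fence marker
ENTITY_TYPES = {"person", "product", "organization", "location", "event", "concept"}
RELATION_TYPES = {"purchased", "uses", "works_for", "located_in", "related_to", "is_a"}


def parse_extraction(raw: str) -> Dict[str, Any]:
    """Parse LLM output defensively: strip a code fence, drop items with unknown types."""
    empty: Dict[str, Any] = {"entities": [], "relations": []}
    text = raw.strip()
    # if the JSON is wrapped in a (possibly json-tagged) fence, keep only the inside
    fenced = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return empty
    if not isinstance(data, dict):
        return empty
    entities = [
        e for e in data.get("entities") or []
        if isinstance(e, dict) and e.get("name") and e.get("type") in ENTITY_TYPES
    ]
    relations = [
        r for r in data.get("relations") or []
        if isinstance(r, dict) and r.get("source") and r.get("target")
        and r.get("type") in RELATION_TYPES
    ]
    return {"entities": entities, "relations": relations}
```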

Memory retrieval and fusion

Retrieving from multiple memory sources

class MemoryRetriever:
    """Memory retriever: combines several memory sources"""
    
    def __init__(
        self,
        short_term: BufferMemory,
        long_term: MemoryService,
        knowledge_graph: KnowledgeGraph
    ):
        self.short_term = short_term
        self.long_term = long_term
        self.knowledge_graph = knowledge_graph
    
    async def retrieve(
        self,
        user_id: str,
        query: str,
        include_short_term: bool = True,
        include_long_term: bool = True,
        include_graph: bool = True,
        top_k: int = 5
    ) -> Dict[str, Any]:
        """Retrieve from multiple sources"""
        results = {
            'short_term': [],
            'long_term': [],
            'graph': [],
            'context': []
        }
        
        # Short-term memory
        if include_short_term:
            results['short_term'] = self._search_short_term(query)
        
        # Long-term memory
        if include_long_term:
            results['long_term'] = await self.long_term.retrieve_memories(
                user_id=user_id,
                query=query,
                top_k=top_k
            )
        
        # Knowledge graph
        if include_graph:
            results['graph'] = self._search_graph(query)
        
        # Fuse everything into one context
        results['context'] = self._fuse_context(results)
        
        return results
    
    def _search_short_term(self, query: str) -> List[Dict]:
        """Search short-term memory by word overlap (Jaccard similarity)"""
        # Whitespace tokenization works for English, but Chinese has no spaces
        # between words and needs a word segmenter or character n-grams instead
        messages = self.short_term.get_messages()
        relevant = []
        query_words = set(query.lower().split())
        
        for msg in messages:
            msg_words = set(msg.content.lower().split())
            # Overlap between query and message words
            overlap = len(query_words & msg_words)
            if overlap > 0:
                relevant.append({
                    'role': msg.role,
                    'content': msg.content,
                    'score': overlap / len(query_words | msg_words)
                })
        
        return relevant
    
    def _search_graph(self, query: str) -> List[Dict]:
        """Search the knowledge graph for entities whose name contains a query word"""
        results = []
        query_words = query.lower().split()
        
        for entity in self.knowledge_graph.entities.values():
            if any(word in entity.name.lower() for word in query_words):
                neighbors = self.knowledge_graph.get_neighbors(entity.id, depth=1)
                results.append({
                    'entity': entity,
                    'neighbors': neighbors  # list of (neighbor_id, relation) pairs
                })
        
        return results
    
    def _fuse_context(self, results: Dict) -> str:
        """Fuse the memory sources into one context string"""
        context_parts = []
        
        # Background knowledge from the graph
        if results['graph']:
            context_parts.append("[Relevant background knowledge]")
            for item in results['graph'][:2]:
                entity = item['entity']
                context_parts.append(f"- {entity.name} ({entity.type.value})")
                for neighbor_id, rel in item['neighbors'][:2]:
                    neighbor = self.knowledge_graph.entities.get(neighbor_id)
                    if neighbor:
                        context_parts.append(f"  - relation to {neighbor.name}: {rel.type.value}")
        
        # Past conversations
        if results['long_term']:
            context_parts.append("\n[Relevant past conversations]")
            for item in results['long_term'][:2]:
                meta = item['metadata']
                context_parts.append(f"User: {meta.get('user_message', '')}")
                context_parts.append(f"Assistant: {meta.get('assistant_response', '')}")
        
        # Current conversation
        if results['short_term']:
            context_parts.append("\n[Current conversation]")
            for item in results['short_term']:
                context_parts.append(f"{item['role']}: {item['content']}")
        
        return "\n".join(context_parts)
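`_search_short_term` and `_search_graph` split text on whitespace, which works for English but not for Chinese: `"我想买华为手机".split()` returns the whole sentence as a single token, so a query like "华为手机" never overlaps with it. A common workaround is to compare overlapping character bigrams (a Chinese word segmenter is another option); a minimal sketch with hypothetical helper names:

```python
from typing import List, Set, Tuple


def char_bigrams(text: str) -> Set[str]:
    """Overlapping 2-character chunks; needs no word boundaries (e.g. Chinese)."""
    chars = [c for c in text.lower() if not c.isspace()]
    if len(chars) < 2:
        return set(chars)
    return {chars[i] + chars[i + 1] for i in range(len(chars) - 1)}


def jaccard(a: Set[str], b: Set[str]) -> float:
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def rank_messages(query: str, messages: List[str], min_score: float = 0.0) -> List[Tuple[float, str]]:
    """Score each message by bigram Jaccard similarity; keep those above min_score, best first."""
    q = char_bigrams(query)
    scored = [(jaccard(q, char_bigrams(m)), m) for m in messages]
    return sorted([s for s in scored if s[0] > min_score], reverse=True)


msgs = ["我想买华为手机,预算5000", "今天天气不错"]
print(rank_messages("华为手机", msgs))
```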

Context assembly

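class ContextBuilder:
    """Context builder: assembles the full message list sent to the LLM"""
    
    def __init__(
        self,
        max_tokens: int = 8000,
        system_prompt: str = ""
    ):
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt
    
    def build(
        self,
        system_template: Optional[str],
        memory_context: str,
        recent_messages: List[Message],
        user_input: str
    ) -> List[Dict]:
        """Build the full conversation context.

        :param system_template: system prompt for this call; None or "" falls
            back to the system_prompt given to the constructor
        """
        messages = []
        system_prompt = system_template or self.system_prompt
        
        # System prompt
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })
        
        # Retrieved memory context, injected as a second system message
        if memory_context:
            messages.append({
                "role": "system",
                "content": f"[Relevant memories]\n{memory_context}"
            })
        
        # Recent messages from short-term memory, oldest first
        for msg in recent_messages:
            messages.append({
                "role": msg.role,
                "content": msg.content
            })
        
        # Current user input
        messages.append({
            "role": "user",
            "content": user_input
        })
        
        # Trim to the token budget, never dropping the system prompt
        return self._trim_to_token_limit(messages, keep_first=bool(system_prompt))
    
    def _trim_to_token_limit(
        self,
        messages: List[Dict],
        keep_first: bool = False
    ) -> List[Dict]:
        """Trim messages until they fit within max_tokens.

        The system prompt (when keep_first is True) and the final user message
        are always kept. The memory context is truncated once to 500 characters.
        After that, the oldest remaining messages are dropped first.
        """
        start = 1 if keep_first else 0  # index of the first trimmable message
        memory_truncated = False
        while (self._count_tokens(messages) > self.max_tokens
               and len(messages) - start > 1):
            first = messages[start]
            if (first["role"] == "system" and not memory_truncated
                    and len(first["content"]) > 500):
                # Shorten the memory context once instead of dropping it outright.
                # The flag prevents re-truncating the same text forever.
                first["content"] = first["content"][:500] + "\n...(truncated)"
                memory_truncated = True
            else:
                # Drop the oldest trimmable message (memory context or early history)
                messages.pop(start)
        
        return messages
    
    def _count_tokens(self, messages: List[Dict]) -> int:
        """Rough token estimate: about 4 characters per token.

        This ratio roughly fits English text. CJK text usually needs
        considerably more tokens per character, so use the model's real
        tokenizer (e.g. tiktoken for OpenAI models) when the budget matters.
        """
        return sum(
            len(m["content"]) // 4
            for m in messages
        )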
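`ContextBuilder` drops history one message at a time and estimates tokens from character counts. The sketch below shows a variant you might consider. It drops whole turns (a message plus the assistant reply that follows it), so the model never sees an answer without its question. It also takes the token counter as a parameter, so a real tokenizer can replace the estimate. `trim_by_turns`, `approx_tokens`, and the `count_tokens` callback are hypothetical names, not part of any library.

```python
from typing import Callable, Dict, List

Msg = Dict[str, str]


def approx_tokens(msgs: List[Msg]) -> int:
    """Rough estimate, same ratio as ContextBuilder: ~4 characters per token."""
    return sum(len(m["content"]) // 4 for m in msgs)


def trim_by_turns(
    system: List[Msg],
    history: List[Msg],
    user_input: Msg,
    max_tokens: int,
    count_tokens: Callable[[List[Msg]], int] = approx_tokens,
) -> List[Msg]:
    """Drop the oldest history turns until the context fits max_tokens.

    A turn is a message plus the assistant reply that follows it. System
    messages and the current user input are always kept, even if they alone
    exceed the budget.
    """
    history = list(history)  # work on a copy; the caller's list is untouched
    while history and count_tokens(system + history + [user_input]) > max_tokens:
        history.pop(0)  # oldest message
        if history and history[0]["role"] == "assistant":
            history.pop(0)  # ...and the reply that answered it
    return system + history + [user_input]
```

In production you would pass a tokenizer-backed function as `count_tokens`. The trimming logic stays the same.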

Personalized Memory

User Profile Storage

from dataclasses import dataclass, asdict, field
from datetime import datetime
from typing import Any, Dict, List, Optional
 
@dataclass
class UserProfile:
    """User profile"""
    user_id: str
    name: Optional[str] = None
    
    # Preference settings
    preferences: Dict[str, Any] = field(default_factory=dict)
    
    # Preferred interaction style: formal / casual / technical
    interaction_style: str = "formal"
    
    # Topics of interest
    topics_of_interest: List[str] = field(default_factory=list)
    
    # Observed communication patterns
    communication_patterns: Dict[str, Any] = field(default_factory=dict)
    
    # Time of the last update
    last_updated: datetime = field(default_factory=datetime.now)
 
class UserProfileManager:
    """User profile manager with an in-process cache in front of the database"""
    
    def __init__(self, db=None):
        # Async MongoDB-style database handle (e.g. Motor); None = cache only
        self.db = db
        self._cache: Dict[str, UserProfile] = {}
    
    async def get_profile(self, user_id: str) -> UserProfile:
        """Get a user's profile, creating an empty one if none exists"""
        # Check the cache first
        if user_id in self._cache:
            return self._cache[user_id]
        
        profile = UserProfile(user_id=user_id)
        # Then fall back to the database
        # (compare with None: PyMongo 4 database objects reject bool())
        if self.db is not None:
            data = await self.db.user_profiles.find_one({"user_id": user_id})
            if data:
                # MongoDB adds an _id field, which UserProfile does not accept
                data.pop("_id", None)
                profile = UserProfile(**data)
        
        self._cache[user_id] = profile
        return profile
    
    async def update_profile(
        self,
        user_id: str,
        updates: Dict[str, Any]
    ) -> None:
        """Update a user's profile"""
        profile = await self.get_profile(user_id)
        
        # Only update fields that exist on UserProfile
        for key, value in updates.items():
            if hasattr(profile, key):
                setattr(profile, key, value)
        
        profile.last_updated = datetime.now()
        
        # Persist to the database (upsert creates the document if it is missing)
        if self.db is not None:
            await self.db.user_profiles.update_one(
                {"user_id": user_id},
                {"$set": asdict(profile)},
                upsert=True
            )
        
        # Refresh the cache
        self._cache[user_id] = profile
    
    async def learn_from_interaction(
        self,
        user_id: str,
        user_message: str,
        assistant_response: str
    ) -> None:
        """Learn the user's preferences from one exchange (keyword heuristics)"""
        profile = await self.get_profile(user_id)
        text = user_message.lower()
        
        # Learn the interaction style (the latest matching message wins).
        # These keyword lists are crude; adapt them to your users' language.
        casual_indicators = ["!", "?", "haha", "~", "lol"]
        technical_indicators = ["analyze", "evaluate", "recommend", "cause"]
        
        if any(ind in text for ind in casual_indicators):
            profile.interaction_style = "casual"
        elif any(ind in text for ind in technical_indicators):
            profile.interaction_style = "technical"
        
        # Learn topics of interest (keywords are lowercase to match `text`)
        topic_keywords = {
            "tech": ["code", "develop", "api", "system", "tech"],
            "business": ["market", "sales", "operations", "growth", "business"],
            "creative": ["design", "creative", "inspiration", "idea"]
        }
        
        for topic, keywords in topic_keywords.items():
            if any(kw in text for kw in keywords):
                if topic not in profile.topics_of_interest:
                    profile.topics_of_interest.append(topic)
        
        # Persist only the fields this method may have changed
        await self.update_profile(user_id, {
            "interaction_style": profile.interaction_style,
            "topics_of_interest": profile.topics_of_interest,
        })

Personalized Response Adaptation

class PersonalizedResponseBuilder:
    """Personalized response builder"""
    
    def __init__(self, profile_manager: UserProfileManager):
        self.profile_manager = profile_manager
    
    async def adapt_response(
        self,
        base_response: str,
        user_id: str
    ) -> str:
        """Adapt a response to the user's profile"""
        profile = await self.profile_manager.get_profile(user_id)
        
        response = base_response
        
        # Adjust to the user's preferred style
        if profile.interaction_style == "casual":
            response = self._make_casual(response)
        elif profile.interaction_style == "technical":
            response = self._add_technical_depth(response)
        
        # Add topic-specific follow-up offers
        if profile.topics_of_interest:
            response = self._add_personalized_content(response, profile)
        
        return response
    
    def _make_casual(self, text: str) -> str:
        """Switch to a casual register by swapping formal phrases"""
        replacements = {
            "Hello": "Hi",
            "Could you please": "Can you",
            "Thank you": "Thanks",
            "I would be happy to": "Happy to",
            "assistance": "help"
        }
        for formal, casual in replacements.items():
            text = text.replace(formal, casual)
        return text
    
    def _add_technical_depth(self, text: str) -> str:
        """Offer more technical detail (appends a follow-up offer only)"""
        return text + "\n\nIf you need a more detailed technical explanation, just let me know."
    
    def _add_personalized_content(
        self,
        text: str,
        profile: UserProfile
    ) -> str:
        """Append follow-up offers that match the user's topics of interest"""
        if "tech" in profile.topics_of_interest:
            text += "\n\nAlso, for technical questions, I can help you dig deeper."
        if "business" in profile.topics_of_interest:
            text += "\n\nIf you need market analysis or business advice, feel free to ask as well."
        
        return text
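`PersonalizedResponseBuilder` edits the reply after the model has written it. That only covers the phrases in its replacement table, and it also rewrites matching text inside quotes or code in the reply. An alternative worth weighing is to turn the profile into an instruction in the system prompt, so the model writes in the right register from the start. The sketch below assumes a profile with `interaction_style` and `topics_of_interest` fields, as in `UserProfile`. `Profile`, `STYLE_HINTS`, `profile_instruction`, and `with_profile` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Profile:
    """Stand-in for UserProfile with only the fields this sketch reads."""
    interaction_style: str = "formal"
    topics_of_interest: List[str] = field(default_factory=list)


# Hypothetical mapping from a stored style to an instruction for the LLM
STYLE_HINTS = {
    "formal": "Reply in a polite, formal tone.",
    "casual": "Reply in a relaxed, conversational tone.",
    "technical": "Include technical detail: causes, trade-offs, and specifics.",
}


def profile_instruction(profile: Profile) -> str:
    """Turn a profile into extra system-prompt text (unknown styles fall back to formal)."""
    parts = [STYLE_HINTS.get(profile.interaction_style, STYLE_HINTS["formal"])]
    if profile.topics_of_interest:
        topics = ", ".join(profile.topics_of_interest)
        parts.append(f"The user is interested in: {topics}.")
    return " ".join(parts)


def with_profile(base_prompt: str, profile: Profile) -> str:
    """Append the profile instruction to a base system prompt."""
    return f"{base_prompt}\n\n{profile_instruction(profile)}".strip()
```

The resulting string would replace the plain system prompt for that user's LLM call. How closely the model follows the instruction depends on the model, so it is worth spot-checking outputs.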

Complete Hands-on Example

Conversational Agent

import time
from typing import Optional
 
class ConversationalAgent:
    """Conversational agent with full memory capabilities.

    Assumes the components from earlier sections (SlidingWindowMemory,
    MemoryService, KnowledgeGraph, MemoryRetriever, KnowledgeExtractor) and the
    client/store wrappers (OpenAIChatClient, PineconeVectorStore,
    InMemoryVectorStore) are available.
    """
    
    def __init__(self, config: dict):
        self.system_prompt = config.get('system_prompt', '')
        
        # Initialize the components.
        # NOTE: this single short-term buffer is shared by every user and
        # session; a multi-user service needs one buffer per session.
        self.short_term = SlidingWindowMemory(
            max_tokens=config.get('short_term_tokens', 4000)
        )
        self.long_term = MemoryService(
            vector_store=self._init_vector_store(config),
            embedding_model=config.get('embedding_model', 'text-embedding-3-small')
        )
        self.knowledge_graph = KnowledgeGraph()
        self.retriever = MemoryRetriever(
            short_term=self.short_term,
            long_term=self.long_term,
            knowledge_graph=self.knowledge_graph
        )
        self.context_builder = ContextBuilder(
            max_tokens=config.get('context_limit', 8000),
            system_prompt=self.system_prompt
        )
        self.llm = OpenAIChatClient(config['model'])
        # Create the extractor once instead of on every turn
        self.knowledge_extractor = KnowledgeExtractor(self.llm)
        self.profile_manager = UserProfileManager()
        self.response_adapter = PersonalizedResponseBuilder(self.profile_manager)
    
    def _init_vector_store(self, config: dict):
        """Initialize the vector store: Pinecone if configured, else in-memory"""
        if config.get('use_pinecone'):
            return PineconeVectorStore(
                api_key=config['pinecone_api_key'],
                environment=config['pinecone_env'],
                index_name=config['pinecone_index']
            )
        return InMemoryVectorStore()
    
    async def chat(
        self,
        user_id: str,
        message: str,
        session_id: Optional[str] = None
    ) -> str:
        """Handle one conversation turn"""
        # Kept local (not on self) so concurrent calls don't overwrite each other
        session_id = session_id or f"session_{int(time.time())}"
        
        # 1. Retrieve relevant memories
        memories = await self.retriever.retrieve(
            user_id=user_id,
            query=message,
            include_long_term=True,
            include_graph=True
        )
        
        # 2. Build the context
        messages = self.context_builder.build(
            system_template=self.system_prompt,
            memory_context=memories['context'],
            recent_messages=self.short_term.get_messages(),
            user_input=message
        )
        
        # 3. Call the LLM
        response = await self.llm.chat(messages)
        
        # 4. Update memories
        await self._update_memories(
            user_id=user_id,
            session_id=session_id,
            user_message=message,
            assistant_response=response
        )
        
        # 5. Learn user preferences
        await self.profile_manager.learn_from_interaction(
            user_id=user_id,
            user_message=message,
            assistant_response=response
        )
        
        # 6. Personalize the response
        response = await self.response_adapter.adapt_response(
            base_response=response,
            user_id=user_id
        )
        
        return response
    
    async def _update_memories(
        self,
        user_id: str,
        session_id: str,
        user_message: str,
        assistant_response: str
    ) -> None:
        """Update every memory layer"""
        # Short-term memory
        self.short_term.add_message("user", user_message)
        self.short_term.add_message("assistant", assistant_response)
        
        # Long-term memory
        await self.long_term.store_interaction(
            user_id=user_id,
            session_id=session_id,
            user_message=user_message,
            assistant_response=assistant_response
        )
        
        # Knowledge graph: extract entities and relations from this exchange
        knowledge = await self.knowledge_extractor.extract_from_conversation(
            f"User: {user_message}\nAssistant: {assistant_response}"
        )
        # Writing `knowledge` into self.knowledge_graph is simplified away here;
        # a real project needs to implement this step.
 
 
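# Usage example
import asyncio

async def main():
    agent = ConversationalAgent({
        'model': 'gpt-4o',
        'short_term_tokens': 4000,
        'context_limit': 8000,
        'embedding_model': 'text-embedding-3-small',
        'system_prompt': 'You are a professional, friendly AI assistant.'
    })
    
    # First turn
    response1 = await agent.chat(
        user_id="user_123",
        message="Hi, I'm Zhang San, and I like tech content",
        session_id="session_001"
    )
    print(f"Agent: {response1}")
    
    # Second turn: no phone was mentioned above, so the answer depends on
    # what long-term memory already holds for user_123
    response2 = await agent.chat(
        user_id="user_123",
        message="What phone did I buy earlier?",
        session_id="session_001"
    )
    print(f"Agent: {response2}")

if __name__ == "__main__":
    asyncio.run(main())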
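Running `main()` requires the components from the earlier sections plus API credentials. To watch the retrieve, build, call, and store steps in isolation, here is a self-contained toy version. It is a sketch, not the agent above. `echo_llm` is a stub standing in for the real LLM client, word overlap stands in for embedding similarity, and plain dictionaries stand in for the vector store. Unlike the agent above, it keeps one short-term buffer per session. `ToyAgent`, `echo_llm`, and `recall` are hypothetical names.

```python
import asyncio
from collections import defaultdict, deque
from typing import Awaitable, Callable, Deque, Dict, List

Msg = Dict[str, str]


async def echo_llm(messages: List[Msg]) -> str:
    """Stub LLM: just reports how many context messages it received."""
    return f"(saw {len(messages)} messages) ok"


class ToyAgent:
    """Toy memory loop: retrieve -> build context -> call LLM -> store."""

    def __init__(self, llm: Callable[[List[Msg]], Awaitable[str]], window: int = 6):
        self.llm = llm
        # Short-term memory: a bounded buffer per session (oldest messages fall off)
        self.short_term: Dict[str, Deque[Msg]] = defaultdict(lambda: deque(maxlen=window))
        # Long-term memory: every user message, per user, across sessions
        self.long_term: Dict[str, List[str]] = defaultdict(list)

    def recall(self, user_id: str, query: str, k: int = 2) -> List[str]:
        """Rank stored user messages by word overlap with the query.

        A stand-in for embedding similarity search in a vector store.
        """
        words = set(query.lower().split())
        scored = [(len(words & set(t.lower().split())), t) for t in self.long_term[user_id]]
        ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
        return [text for score, text in ranked if score > 0][:k]

    async def chat(self, user_id: str, session_id: str, message: str) -> str:
        # 1. Retrieve relevant long-term memories
        memories = self.recall(user_id, message)
        # 2. Build the context: memories, recent messages, current input
        context: List[Msg] = []
        if memories:
            context.append({"role": "system", "content": "Relevant memories:\n" + "\n".join(memories)})
        context.extend(self.short_term[session_id])
        context.append({"role": "user", "content": message})
        # 3. Call the (stub) LLM
        reply = await self.llm(context)
        # 4. Update short-term and long-term memory
        self.short_term[session_id].append({"role": "user", "content": message})
        self.short_term[session_id].append({"role": "assistant", "content": reply})
        self.long_term[user_id].append(message)
        return reply


if __name__ == "__main__":
    agent = ToyAgent(echo_llm)
    print(asyncio.run(agent.chat("user_123", "session_001", "my order number is 12345")))
```

`recall` searches all of a user's earlier messages, not just the current session's. A fact stated in one session (for example an order number) can therefore surface in another session for the same user. That is the cross-session behavior the long-term layer is meant to provide.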

Summary

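A well-rounded memory system should include these layers:

Layer             Storage              Capacity               Lifetime
Context window    Model input          Token limit            Single call
Short-term        In-process memory    Token limit            Within a session
Long-term         Vector database      Effectively unbounded  Persistent
Knowledge graph   Graph database       Effectively unbounded  Persistent
User profile      Structured storage   Limited                Persistent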

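Design recommendations:

  • Short-term memory: a must; simple to implement, with an immediately visible payoff
  • Long-term memory: add it when the use case calls for it; complex scenarios need it
  • Knowledge graph: costly to build, best suited to specialized domains
  • User profile: improves the personalized experience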
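One way to keep these recommendations explicit in code is a small set of feature flags. The object below is a hypothetical configuration (`MemoryConfig` and its field names are illustrative, not from any framework). It enables short-term memory by default and leaves the costlier layers opt-in.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class MemoryConfig:
    """Hypothetical feature flags that encode the recommendations above."""
    short_term: bool = True         # a must: simple to build, immediate payoff
    long_term: bool = False         # turn on when cross-session recall is needed
    knowledge_graph: bool = False   # expensive; mainly for specialized domains
    user_profile: bool = False      # turn on for personalization

    def __post_init__(self) -> None:
        if not self.short_term:
            raise ValueError("short-term memory should always be enabled")

    def enabled_layers(self) -> List[str]:
        """Names of the enabled layers, in the order of the table above."""
        order = ["short_term", "long_term", "knowledge_graph", "user_profile"]
        return [name for name in order if getattr(self, name)]
```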


This document was generated by the 归愚 (Guiyu) knowledge system. Last updated: 2026-04-24