知识图谱与大模型结合:RAG系统的设计与优化

摘要

检索增强生成(Retrieval-Augmented Generation, RAG)系统通过将外部知识库与大语言模型相结合,有效解决了大模型知识更新滞后、幻觉问题和领域知识不足等挑战。本文深入探讨了知识图谱与大语言模型结合的RAG系统设计原理、技术架构、优化策略和实际应用。文章首先介绍了知识图谱的基本概念和构建方法,然后详细阐述了RAG系统的核心组件、工作流程和关键技术,接着讨论了系统优化的多个维度,最后通过实际案例展示了RAG系统在不同领域的应用效果。研究表明,合理设计的RAG系统能够显著提升大语言模型在特定领域的准确性和可靠性,为构建更加智能和可信的AI应用提供了重要技术路径。

关键词:知识图谱, RAG, 检索增强生成, 大语言模型, 知识融合, 向量检索

1. 引言

1.1 背景与动机

大语言模型(Large Language Models, LLMs)在自然语言处理领域取得了革命性进展,展现出强大的语言理解和生成能力。然而,这些模型也面临着一些关键挑战:

  1. 知识时效性问题:模型训练数据存在时间截止点,无法获取最新信息
  2. 幻觉现象:模型可能生成看似合理但实际错误的信息
  3. 领域知识局限:在特定专业领域的知识深度和准确性有限
  4. 可解释性不足:难以追溯生成内容的知识来源

检索增强生成(RAG)技术通过将外部知识库与大语言模型相结合,为解决这些问题提供了有效途径。RAG系统能够在生成过程中动态检索相关知识,确保输出内容的准确性和时效性。
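
为了更直观地说明"先检索、再拼接上下文、最后生成"的基本流程,下面给出一个极简的Python示意。其中的知识库、检索逻辑与"生成"逻辑都是本文为演示而虚构的占位实现,真实系统中应分别替换为向量检索/知识图谱检索和大语言模型调用。

from typing import List, Dict

# 极简RAG流程示意:用字符重叠模拟"检索",用模板拼接模拟"生成",
# 仅用于演示"先检索、再拼接上下文、最后生成"的整体流程。
KNOWLEDGE_BASE: List[Dict[str, str]] = [
    {"content": "RAG在生成回答之前,先从外部知识库检索与问题相关的内容。"},
    {"content": "检索到的内容作为上下文拼入提示词,用来约束模型输出、降低幻觉。"},
]

def retrieve_docs(question: str, top_k: int = 2) -> List[Dict[str, str]]:
    # 占位检索:按与问题的字符重叠数排序(真实系统应使用向量检索或图检索)
    scored = sorted(KNOWLEDGE_BASE,
                    key=lambda d: len(set(question) & set(d["content"])),
                    reverse=True)
    return scored[:top_k]

def rag_answer(question: str) -> str:
    docs = retrieve_docs(question)
    context = "\n".join(d["content"] for d in docs)
    # 占位生成:真实系统中此处应把context与question组装成提示词交给大语言模型
    return f"[基于{len(docs)}条检索结果回答] {question}\n参考资料:\n{context}"

print(rag_answer("RAG如何缓解大模型的幻觉问题?"))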

1.2 知识图谱的价值

知识图谱作为一种结构化的知识表示方法,具有以下优势:

  • 结构化表示:以实体、关系、属性的形式组织知识
  • 语义丰富性:包含丰富的语义关系和推理能力
  • 可解释性强:提供清晰的知识来源和推理路径
  • 更新便捷:支持增量更新和知识维护

将知识图谱与大语言模型结合,能够充分发挥两者的优势,构建更加智能和可靠的AI系统。

2. 知识图谱基础

2.1 知识图谱的基本概念

知识图谱是一种用于表示现实世界中实体及其关系的图结构数据模型。其基本组成要素包括:

  • 实体(Entity):现实世界中的对象或概念,如人物、公司、项目、技术
  • 关系(Relation):连接两个实体的有向边,如"任职于""参与""使用"
  • 属性(Property):附加在实体或关系上的键值对,如年龄、行业、开始时间

下面的代码给出这些要素的一个简化Python实现,并演示邻居查询、路径查找和子图可视化:

import networkx as nx
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Any
import json
import numpy as np
from dataclasses import dataclass

@dataclass
class Entity:
    """实体类"""
    id: str
    name: str
    type: str
    properties: Dict[str, Any]

    def __hash__(self):
        return hash(self.id)

@dataclass
class Relation:
    """关系类"""
    id: str
    name: str
    source: str
    target: str
    properties: Dict[str, Any]

class KnowledgeGraph:
    """知识图谱类"""

    def __init__(self):
        self.entities: Dict[str, Entity] = {}
        self.relations: Dict[str, Relation] = {}
        self.graph = nx.MultiDiGraph()

    def add_entity(self, entity: Entity):
        """添加实体"""
        self.entities[entity.id] = entity
        self.graph.add_node(entity.id, **entity.properties, name=entity.name, type=entity.type)

    def add_relation(self, relation: Relation):
        """添加关系"""
        self.relations[relation.id] = relation
        self.graph.add_edge(
            relation.source,
            relation.target,
            key=relation.id,
            name=relation.name,
            **relation.properties
        )

    def get_neighbors(self, entity_id: str, relation_type: str = None) -> List[str]:
        """获取邻居实体"""
        neighbors = []
        for neighbor in self.graph.neighbors(entity_id):
            if relation_type is None:
                neighbors.append(neighbor)
            else:
                # 检查关系类型
                edges = self.graph.get_edge_data(entity_id, neighbor)
                for edge_data in edges.values():
                    if edge_data.get('name') == relation_type:
                        neighbors.append(neighbor)
                        break
        return neighbors

    def find_path(self, start_entity: str, end_entity: str, max_length: int = 3) -> List[List[str]]:
        """查找实体间的路径"""
        try:
            paths = list(nx.all_simple_paths(
                self.graph, start_entity, end_entity, cutoff=max_length
            ))
            return paths
        except (nx.NodeNotFound, nx.NetworkXNoPath):
            return []

    def get_entity_info(self, entity_id: str) -> Dict[str, Any]:
        """获取实体详细信息"""
        if entity_id not in self.entities:
            return None

        entity = self.entities[entity_id]
        node_data = self.graph.nodes[entity_id]

        # 获取相关关系
        outgoing_relations = []
        incoming_relations = []

        for neighbor in self.graph.neighbors(entity_id):
            edges = self.graph.get_edge_data(entity_id, neighbor)
            for edge_key, edge_data in edges.items():
                outgoing_relations.append({
                    'relation': edge_data.get('name'),
                    'target': neighbor,
                    'target_name': self.entities[neighbor].name if neighbor in self.entities else neighbor
                })

        for predecessor in self.graph.predecessors(entity_id):
            edges = self.graph.get_edge_data(predecessor, entity_id)
            for edge_key, edge_data in edges.items():
                incoming_relations.append({
                    'relation': edge_data.get('name'),
                    'source': predecessor,
                    'source_name': self.entities[predecessor].name if predecessor in self.entities else predecessor
                })

        return {
            'entity': entity,
            'properties': node_data,
            'outgoing_relations': outgoing_relations,
            'incoming_relations': incoming_relations
        }

    def visualize_subgraph(self, center_entity: str, depth: int = 2):
        """可视化子图"""
        # 获取子图节点
        subgraph_nodes = set([center_entity])
        current_level = set([center_entity])

        for _ in range(depth):
            next_level = set()
            for node in current_level:
                neighbors = list(self.graph.neighbors(node)) + list(self.graph.predecessors(node))
                next_level.update(neighbors)
            subgraph_nodes.update(next_level)
            current_level = next_level

        # 创建子图
        subgraph = self.graph.subgraph(subgraph_nodes)

        # 绘制图形
        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(subgraph, k=2, iterations=50)

        # 绘制节点
        node_colors = ['red' if node == center_entity else 'lightblue' for node in subgraph.nodes()]
        nx.draw_networkx_nodes(subgraph, pos, node_color=node_colors, node_size=1000, alpha=0.7)

        # 绘制边
        nx.draw_networkx_edges(subgraph, pos, alpha=0.5, arrows=True, arrowsize=20)

        # 绘制标签
        labels = {node: self.entities[node].name if node in self.entities else node
                  for node in subgraph.nodes()}
        nx.draw_networkx_labels(subgraph, pos, labels, font_size=8)

        # 绘制边标签
        edge_labels = {}
        for u, v, data in subgraph.edges(data=True):
            edge_labels[(u, v)] = data.get('name', '')
        nx.draw_networkx_edge_labels(subgraph, pos, edge_labels, font_size=6)

        plt.title(f"Knowledge Graph Subgraph (Center: {self.entities[center_entity].name})")
        plt.axis('off')
        plt.tight_layout()
        plt.show()

# 示例:构建一个简单的知识图谱
def build_sample_kg():
    """构建示例知识图谱"""
    kg = KnowledgeGraph()

    # 添加实体
    entities = [
        Entity("person_1", "张三", "Person", {"age": 30, "occupation": "工程师"}),
        Entity("person_2", "李四", "Person", {"age": 28, "occupation": "设计师"}),
        Entity("company_1", "科技公司A", "Company", {"industry": "软件开发", "size": "大型"}),
        Entity("project_1", "AI项目", "Project", {"status": "进行中", "budget": 1000000}),
        Entity("tech_1", "机器学习", "Technology", {"category": "AI", "maturity": "成熟"})
    ]

    for entity in entities:
        kg.add_entity(entity)

    # 添加关系
    relations = [
        Relation("rel_1", "works_for", "person_1", "company_1", {"start_date": "2020-01-01"}),
        Relation("rel_2", "works_for", "person_2", "company_1", {"start_date": "2021-06-01"}),
        Relation("rel_3", "manages", "person_1", "project_1", {"role": "项目经理"}),
        Relation("rel_4", "participates_in", "person_2", "project_1", {"role": "设计师"}),
        Relation("rel_5", "uses_technology", "project_1", "tech_1", {"usage_level": "核心"}),
        Relation("rel_6", "knows", "person_1", "tech_1", {"proficiency": "专家"})
    ]

    for relation in relations:
        kg.add_relation(relation)

    return kg

# 测试知识图谱
if __name__ == "__main__":
    kg = build_sample_kg()

    # 查询示例
    print("张三的邻居实体:", kg.get_neighbors("person_1"))
    print("张三工作的公司:", kg.get_neighbors("person_1", "works_for"))

    # 获取实体详细信息
    info = kg.get_entity_info("person_1")
    print("\n张三的详细信息:")
    print(f"姓名: {info['entity'].name}")
    print(f"类型: {info['entity'].type}")
    print(f"属性: {info['entity'].properties}")
    print(f"出度关系: {info['outgoing_relations']}")
    print(f"入度关系: {info['incoming_relations']}")

2.2 知识图谱构建流程

知识图谱的构建是一个复杂的过程,主要包括以下步骤:

  • 实体抽取:利用命名实体识别(NER)和规则(如邮箱、电话的正则匹配)从文本中识别实体
  • 关系抽取:基于依存句法分析等方法识别实体之间的语义关系
  • 实体链接:将抽取出的实体对齐到图谱中已有的标准实体,无匹配时创建新实体
  • 知识入库:将抽取并链接后的实体与关系写入知识图谱

下面的代码演示了这一流程的简化实现(需要先安装spaCy中文模型:python -m spacy download zh_core_web_sm):

import re
import spacy
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_similarity

class KnowledgeGraphBuilder:
"""知识图谱构建器"""

def __init__(self, model_name="bert-base-chinese"):
self.nlp = spacy.load("zh_core_web_sm") # 中文NLP模型
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
self.kg = KnowledgeGraph()

def extract_entities_from_text(self, text: str) -> List[Dict[str, Any]]:
"""从文本中提取实体"""
doc = self.nlp(text)
entities = []

# 使用spaCy的命名实体识别
for ent in doc.ents:
entities.append({
'text': ent.text,
'label': ent.label_,
'start': ent.start_char,
'end': ent.end_char,
'confidence': 1.0 # spaCy不直接提供置信度
})

# 补充基于规则的实体提取
# 提取邮箱
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
for match in re.finditer(email_pattern, text):
entities.append({
'text': match.group(),
'label': 'EMAIL',
'start': match.start(),
'end': match.end(),
'confidence': 0.9
})

# 提取电话号码
phone_pattern = r'\b(?:\+86)?\s*1[3-9]\d{9}\b'
for match in re.finditer(phone_pattern, text):
entities.append({
'text': match.group(),
'label': 'PHONE',
'start': match.start(),
'end': match.end(),
'confidence': 0.9
})

return entities

def extract_relations_from_text(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""从文本中提取关系"""
doc = self.nlp(text)
relations = []

# 基于依存句法分析提取关系
for token in doc:
if token.dep_ in ['nsubj', 'dobj', 'pobj']: # 主语、直接宾语、介词宾语
head = token.head
if head.pos_ == 'VERB': # 动词作为关系
# 查找对应的实体
subj_entity = self._find_entity_by_position(entities, token.idx, token.idx + len(token.text))
obj_entities = []

# 查找与该动词相关的其他实体
for child in head.children:
if child != token and child.dep_ in ['dobj', 'pobj']:
obj_entity = self._find_entity_by_position(entities, child.idx, child.idx + len(child.text))
if obj_entity:
obj_entities.append(obj_entity)

# 创建关系
if subj_entity and obj_entities:
for obj_entity in obj_entities:
relations.append({
'subject': subj_entity['text'],
'predicate': head.lemma_,
'object': obj_entity['text'],
'confidence': 0.7
})

return relations

def _find_entity_by_position(self, entities: List[Dict[str, Any]], start: int, end: int) -> Dict[str, Any]:
"""根据位置查找实体"""
for entity in entities:
if entity['start'] <= start < entity['end'] or entity['start'] < end <= entity['end']:
return entity
return None

def entity_linking(self, extracted_entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""实体链接:将提取的实体链接到知识图谱中的标准实体"""
linked_entities = []

for entity in extracted_entities:
# 简化的实体链接:基于文本相似度
best_match = None
best_score = 0.0

for kg_entity_id, kg_entity in self.kg.entities.items():
# 计算文本相似度
similarity = self._calculate_text_similarity(entity['text'], kg_entity.name)
if similarity > best_score and similarity > 0.8: # 阈值
best_score = similarity
best_match = kg_entity_id

if best_match:
linked_entities.append({
**entity,
'linked_entity_id': best_match,
'linking_confidence': best_score
})
else:
# 创建新实体
new_entity_id = f"entity_{len(self.kg.entities)}"
new_entity = Entity(
id=new_entity_id,
name=entity['text'],
type=entity['label'],
properties={'confidence': entity['confidence']}
)
self.kg.add_entity(new_entity)

linked_entities.append({
**entity,
'linked_entity_id': new_entity_id,
'linking_confidence': 1.0
})

return linked_entities

def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度"""
# 使用BERT计算语义相似度
inputs1 = self.tokenizer(text1, return_tensors="pt", padding=True, truncation=True)
inputs2 = self.tokenizer(text2, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
outputs1 = self.model(**inputs1)
outputs2 = self.model(**inputs2)

# 使用[CLS]标记的嵌入表示
emb1 = outputs1.last_hidden_state[:, 0, :].numpy()
emb2 = outputs2.last_hidden_state[:, 0, :].numpy()

similarity = cosine_similarity(emb1, emb2)[0][0]
return float(similarity)

def build_from_text(self, text: str):
"""从文本构建知识图谱"""
# 1. 实体提取
entities = self.extract_entities_from_text(text)
print(f"提取到 {len(entities)} 个实体")

# 2. 关系提取
relations = self.extract_relations_from_text(text, entities)
print(f"提取到 {len(relations)} 个关系")

# 3. 实体链接
linked_entities = self.entity_linking(entities)
print(f"链接了 {len(linked_entities)} 个实体")

# 4. 添加关系到知识图谱
for relation in relations:
# 查找对应的实体ID
subj_id = None
obj_id = None

for entity in linked_entities:
if entity['text'] == relation['subject']:
subj_id = entity['linked_entity_id']
if entity['text'] == relation['object']:
obj_id = entity['linked_entity_id']

if subj_id and obj_id:
relation_id = f"rel_{len(self.kg.relations)}"
kg_relation = Relation(
id=relation_id,
name=relation['predicate'],
source=subj_id,
target=obj_id,
properties={'confidence': relation['confidence']}
)
self.kg.add_relation(kg_relation)

return self.kg

# 示例使用
if __name__ == "__main__":
builder = KnowledgeGraphBuilder()

sample_text = """
张三是一名软件工程师,他在北京的科技公司工作。
张三负责开发人工智能项目,该项目使用了机器学习技术。
李四是张三的同事,他们一起参与了这个项目的开发。
公司的邮箱是contact@techcompany.com,联系电话是13800138000。
"""

kg = builder.build_from_text(sample_text)
print(f"\n构建完成!知识图谱包含 {len(kg.entities)} 个实体和 {len(kg.relations)} 个关系")

2.3 知识图谱存储与查询

高效的存储和查询是知识图谱落地应用的关键。下面的示例用SQLite实现实体表与关系表的持久化存储,通过索引加速按名称、类型的检索,并以多表联结提供类似SPARQL三元组模式(主语-谓语-宾语)的查询能力;生产环境中也可以替换为专门的图数据库:

import sqlite3
from typing import Optional
import json

class KnowledgeGraphDB:
"""知识图谱数据库管理器"""

def __init__(self, db_path: str = "knowledge_graph.db"):
self.db_path = db_path
self.init_database()

def init_database(self):
"""初始化数据库"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

# 创建实体表
cursor.execute("""
CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
properties TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")

# 创建关系表
cursor.execute("""
CREATE TABLE IF NOT EXISTS relations (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
source_id TEXT NOT NULL,
target_id TEXT NOT NULL,
properties TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (source_id) REFERENCES entities (id),
FOREIGN KEY (target_id) REFERENCES entities (id)
)
""")

# 创建索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_entity_name ON entities (name)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_entity_type ON entities (type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_relation_name ON relations (name)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_relation_source ON relations (source_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_relation_target ON relations (target_id)")

conn.commit()
conn.close()

def save_entity(self, entity: Entity):
"""保存实体"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

cursor.execute("""
INSERT OR REPLACE INTO entities (id, name, type, properties)
VALUES (?, ?, ?, ?)
""", (entity.id, entity.name, entity.type, json.dumps(entity.properties)))

conn.commit()
conn.close()

def save_relation(self, relation: Relation):
"""保存关系"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

cursor.execute("""
INSERT OR REPLACE INTO relations (id, name, source_id, target_id, properties)
VALUES (?, ?, ?, ?, ?)
""", (relation.id, relation.name, relation.source, relation.target, json.dumps(relation.properties)))

conn.commit()
conn.close()

def get_entity(self, entity_id: str) -> Optional[Entity]:
"""获取实体"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

cursor.execute("SELECT id, name, type, properties FROM entities WHERE id = ?", (entity_id,))
row = cursor.fetchone()

conn.close()

if row:
return Entity(
id=row[0],
name=row[1],
type=row[2],
properties=json.loads(row[3]) if row[3] else {}
)
return None

def search_entities(self, name_pattern: str = None, entity_type: str = None, limit: int = 100) -> List[Entity]:
"""搜索实体"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

query = "SELECT id, name, type, properties FROM entities WHERE 1=1"
params = []

if name_pattern:
query += " AND name LIKE ?"
params.append(f"%{name_pattern}%")

if entity_type:
query += " AND type = ?"
params.append(entity_type)

query += f" LIMIT {limit}"

cursor.execute(query, params)
rows = cursor.fetchall()

conn.close()

entities = []
for row in rows:
entities.append(Entity(
id=row[0],
name=row[1],
type=row[2],
properties=json.loads(row[3]) if row[3] else {}
))

return entities

def get_relations(self, source_id: str = None, target_id: str = None, relation_name: str = None) -> List[Relation]:
"""获取关系"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

query = "SELECT id, name, source_id, target_id, properties FROM relations WHERE 1=1"
params = []

if source_id:
query += " AND source_id = ?"
params.append(source_id)

if target_id:
query += " AND target_id = ?"
params.append(target_id)

if relation_name:
query += " AND name = ?"
params.append(relation_name)

cursor.execute(query, params)
rows = cursor.fetchall()

conn.close()

relations = []
for row in rows:
relations.append(Relation(
id=row[0],
name=row[1],
source=row[2],
target=row[3],
properties=json.loads(row[4]) if row[4] else {}
))

return relations

def execute_sparql_like_query(self, subject: str = None, predicate: str = None, object_: str = None) -> List[Dict[str, str]]:
"""执行类SPARQL查询"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

# 构建查询
query = """
SELECT
e1.name as subject_name,
r.name as predicate_name,
e2.name as object_name,
e1.id as subject_id,
r.id as relation_id,
e2.id as object_id
FROM relations r
JOIN entities e1 ON r.source_id = e1.id
JOIN entities e2 ON r.target_id = e2.id
WHERE 1=1
"""

params = []

if subject:
query += " AND e1.name LIKE ?"
params.append(f"%{subject}%")

if predicate:
query += " AND r.name LIKE ?"
params.append(f"%{predicate}%")

if object_:
query += " AND e2.name LIKE ?"
params.append(f"%{object_}%")

cursor.execute(query, params)
rows = cursor.fetchall()

conn.close()

results = []
for row in rows:
results.append({
'subject_name': row[0],
'predicate_name': row[1],
'object_name': row[2],
'subject_id': row[3],
'relation_id': row[4],
'object_id': row[5]
})

return results

# 示例使用
if __name__ == "__main__":
# 创建数据库管理器
db = KnowledgeGraphDB()

# 保存示例数据
entity1 = Entity("person_1", "张三", "Person", {"age": 30})
entity2 = Entity("company_1", "科技公司", "Company", {"industry": "IT"})

db.save_entity(entity1)
db.save_entity(entity2)

relation1 = Relation("rel_1", "works_for", "person_1", "company_1", {})
db.save_relation(relation1)

# 查询示例
print("搜索实体:", [e.name for e in db.search_entities(name_pattern="张")])
print("查询关系:", [(r.name, r.source, r.target) for r in db.get_relations(source_id="person_1")])
print("SPARQL查询:", db.execute_sparql_like_query(subject="张三"))

3. RAG系统架构设计

3.1 RAG系统核心组件

RAG系统主要由以下核心组件构成:

  • 检索器(Retriever):负责从外部知识源获取相关内容,下面实现了向量检索器(VectorRetriever)、知识图谱检索器(KnowledgeGraphRetriever)以及二者加权融合的混合检索器(HybridRetriever)
  • 生成器(Generator):将检索结果组织成上下文,调用语言模型生成最终回答
  • RAG主控(RAGSystem):串联检索与生成,对外提供单条与批量查询接口

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class Retriever(ABC):
"""检索器抽象基类"""

@abstractmethod
def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""检索相关文档"""
pass

class VectorRetriever(Retriever):
"""向量检索器"""

def __init__(self, embedding_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
self.embedding_model = SentenceTransformer(embedding_model_name)
self.index = None
self.documents = []
self.document_embeddings = None

def build_index(self, documents: List[Dict[str, Any]]):
"""构建向量索引"""
self.documents = documents

# 提取文档文本
texts = [doc.get('content', '') for doc in documents]

# 生成嵌入向量
self.document_embeddings = self.embedding_model.encode(texts)

# 构建FAISS索引
dimension = self.document_embeddings.shape[1]
self.index = faiss.IndexFlatIP(dimension) # 内积相似度

# 归一化向量(用于余弦相似度)
faiss.normalize_L2(self.document_embeddings)
self.index.add(self.document_embeddings.astype('float32'))

print(f"构建索引完成,包含 {len(documents)} 个文档")

def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""检索相关文档"""
if self.index is None:
raise ValueError("索引未构建,请先调用 build_index()")

# 生成查询向量
query_embedding = self.embedding_model.encode([query])
faiss.normalize_L2(query_embedding)

# 检索
scores, indices = self.index.search(query_embedding.astype('float32'), top_k)

# 返回结果
results = []
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
if idx < len(self.documents):
result = self.documents[idx].copy()
result['score'] = float(score)
result['rank'] = i + 1
results.append(result)

return results

class KnowledgeGraphRetriever(Retriever):
"""知识图谱检索器"""

def __init__(self, kg_db: KnowledgeGraphDB, embedding_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
self.kg_db = kg_db
self.embedding_model = SentenceTransformer(embedding_model_name)

def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""从知识图谱检索相关信息"""
results = []

# 1. 实体匹配检索
entities = self.kg_db.search_entities(name_pattern=query, limit=top_k)
for entity in entities:
# 获取实体的详细信息和关系
relations = self.kg_db.get_relations(source_id=entity.id)

content = f"实体: {entity.name} (类型: {entity.type})\n"
content += f"属性: {entity.properties}\n"

if relations:
content += "关系:\n"
for rel in relations[:3]: # 限制关系数量
target_entity = self.kg_db.get_entity(rel.target)
if target_entity:
content += f" - {rel.name}: {target_entity.name}\n"

results.append({
'content': content,
'source': 'knowledge_graph',
'entity_id': entity.id,
'entity_name': entity.name,
'entity_type': entity.type
})

# 2. 关系检索
sparql_results = self.kg_db.execute_sparql_like_query(predicate=query)
for result in sparql_results[:top_k]:
content = f"关系: {result['subject_name']} {result['predicate_name']} {result['object_name']}"
results.append({
'content': content,
'source': 'knowledge_graph_relation',
'subject': result['subject_name'],
'predicate': result['predicate_name'],
'object': result['object_name']
})

# 计算相关性分数
if results:
query_embedding = self.embedding_model.encode([query])
contents = [r['content'] for r in results]
content_embeddings = self.embedding_model.encode(contents)

# 计算余弦相似度
similarities = np.dot(query_embedding, content_embeddings.T)[0]

for i, result in enumerate(results):
result['score'] = float(similarities[i])

# 按分数排序
results.sort(key=lambda x: x['score'], reverse=True)

return results[:top_k]

class HybridRetriever(Retriever):
"""混合检索器:结合向量检索和知识图谱检索"""

def __init__(self, vector_retriever: VectorRetriever, kg_retriever: KnowledgeGraphRetriever,
vector_weight: float = 0.6, kg_weight: float = 0.4):
self.vector_retriever = vector_retriever
self.kg_retriever = kg_retriever
self.vector_weight = vector_weight
self.kg_weight = kg_weight

def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""混合检索"""
# 分别从两个检索器获取结果
vector_results = self.vector_retriever.retrieve(query, top_k)
kg_results = self.kg_retriever.retrieve(query, top_k)

# 合并结果并重新计算分数
all_results = []

# 处理向量检索结果
for result in vector_results:
result['retrieval_type'] = 'vector'
result['final_score'] = result['score'] * self.vector_weight
all_results.append(result)

# 处理知识图谱检索结果
for result in kg_results:
result['retrieval_type'] = 'knowledge_graph'
result['final_score'] = result['score'] * self.kg_weight
all_results.append(result)

# 按最终分数排序
all_results.sort(key=lambda x: x['final_score'], reverse=True)

return all_results[:top_k]

class Generator:
"""生成器:基于检索结果生成回答"""

def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)

# 设置pad_token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

def generate_response(self, query: str, retrieved_docs: List[Dict[str, Any]],
max_length: int = 512, temperature: float = 0.7) -> str:
"""基于检索结果生成回答"""
# 构建上下文
context = self._build_context(query, retrieved_docs)

# 编码输入
inputs = self.tokenizer.encode(context, return_tensors="pt", max_length=max_length, truncation=True)

# 生成回答
with torch.no_grad():
outputs = self.model.generate(
inputs,
max_length=inputs.shape[1] + 150,
temperature=temperature,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)

# 解码输出
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# 提取新生成的部分
original_length = len(self.tokenizer.decode(inputs[0], skip_special_tokens=True))
generated_response = response[original_length:].strip()

return generated_response

def _build_context(self, query: str, retrieved_docs: List[Dict[str, Any]]) -> str:
"""构建上下文"""
context = "基于以下信息回答问题:\n\n"

for i, doc in enumerate(retrieved_docs[:3]): # 限制上下文长度
context += f"信息{i+1}: {doc['content']}\n"
if doc.get('source'):
context += f"来源: {doc['source']}\n"
context += "\n"

context += f"问题: {query}\n"
context += "回答: "

return context

class RAGSystem:
"""RAG系统主类"""

def __init__(self, retriever: Retriever, generator: Generator):
self.retriever = retriever
self.generator = generator

def query(self, question: str, top_k: int = 5, **generation_kwargs) -> Dict[str, Any]:
"""处理查询"""
# 检索相关文档
retrieved_docs = self.retriever.retrieve(question, top_k)

# 生成回答
response = self.generator.generate_response(question, retrieved_docs, **generation_kwargs)

return {
'question': question,
'answer': response,
'retrieved_documents': retrieved_docs,
'num_retrieved': len(retrieved_docs)
}

def batch_query(self, questions: List[str], **kwargs) -> List[Dict[str, Any]]:
"""批量处理查询"""
results = []
for question in questions:
result = self.query(question, **kwargs)
results.append(result)
return results

# 示例使用
if __name__ == "__main__":
# 准备示例文档
documents = [
{
'id': 'doc1',
'content': '人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。',
'source': 'AI教科书',
'title': 'AI简介'
},
{
'id': 'doc2',
'content': '机器学习是人工智能的一个子领域,通过算法让计算机从数据中学习模式。',
'source': 'ML指南',
'title': '机器学习基础'
},
{
'id': 'doc3',
'content': '深度学习使用多层神经网络来模拟人脑的学习过程,在图像识别和自然语言处理方面取得了突破。',
'source': 'DL论文',
'title': '深度学习原理'
}
]

# 创建向量检索器
vector_retriever = VectorRetriever()
vector_retriever.build_index(documents)

# 创建生成器
generator = Generator()

# 创建RAG系统
rag_system = RAGSystem(vector_retriever, generator)

# 测试查询
result = rag_system.query("什么是机器学习?")

print("问题:", result['question'])
print("回答:", result['answer'])
print("\n检索到的文档:")
for doc in result['retrieved_documents']:
print(f"- {doc['title']}: {doc['content'][:50]}... (分数: {doc['score']:.3f})")

3.2 RAG系统工作流程

RAG系统的工作流程可以分为以下几个阶段:

  1. 查询预处理:清理并标准化用户输入
  2. 意图识别:判断查询类型、领域与复杂度
  3. 检索策略选择:根据意图决定检索方式、top_k以及是否重排序、扩展查询
  4. 检索执行与结果后处理:获取候选文档并进行过滤、加权
  5. 回答生成与后处理:按意图调整生成参数,清理与去重输出
  6. 记录与评估:记录耗时、意图等信息,便于监控与优化

下面的RAGPipeline按这一流程组织了完整的处理逻辑:

import logging
from datetime import datetime
from typing import Optional
import time

class RAGPipeline:
"""RAG处理流水线"""

def __init__(self, rag_system: RAGSystem, enable_logging: bool = True):
self.rag_system = rag_system
self.enable_logging = enable_logging

if enable_logging:
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)

def process_query(self, query: str, user_id: str = None, session_id: str = None) -> Dict[str, Any]:
"""处理单个查询的完整流程"""
start_time = time.time()

if self.enable_logging:
self.logger.info(f"开始处理查询: {query[:50]}...")

try:
# 1. 查询预处理
processed_query = self._preprocess_query(query)

# 2. 查询理解与意图识别
query_intent = self._analyze_query_intent(processed_query)

# 3. 检索策略选择
retrieval_strategy = self._select_retrieval_strategy(query_intent)

# 4. 执行检索
retrieved_docs = self._execute_retrieval(processed_query, retrieval_strategy)

# 5. 检索结果后处理
processed_docs = self._postprocess_retrieval_results(retrieved_docs, query_intent)

# 6. 生成回答
response = self._generate_response(processed_query, processed_docs, query_intent)

# 7. 回答后处理
final_response = self._postprocess_response(response, query_intent)

# 8. 记录和评估
processing_time = time.time() - start_time
result = self._create_result_object(
query, final_response, processed_docs,
processing_time, user_id, session_id, query_intent
)

if self.enable_logging:
self.logger.info(f"查询处理完成,耗时: {processing_time:.2f}秒")

return result

except Exception as e:
if self.enable_logging:
self.logger.error(f"查询处理失败: {str(e)}")

return {
'query': query,
'answer': '抱歉,处理您的查询时出现了错误。',
'error': str(e),
'success': False,
'timestamp': datetime.now().isoformat()
}

def _preprocess_query(self, query: str) -> str:
"""查询预处理"""
# 清理和标准化查询文本
processed = query.strip()

# 移除多余的空格
processed = ' '.join(processed.split())

# 简单的拼写纠错(这里可以集成更复杂的纠错算法)
# processed = self._spell_check(processed)

return processed

def _analyze_query_intent(self, query: str) -> Dict[str, Any]:
"""分析查询意图"""
intent = {
'type': 'general', # general, factual, procedural, comparative
'domain': 'general', # technology, science, business, etc.
'complexity': 'simple', # simple, medium, complex
'requires_reasoning': False,
'requires_calculation': False,
'temporal_aspect': None # past, present, future
}

# 简单的意图识别规则
query_lower = query.lower()

# 识别查询类型
if any(word in query_lower for word in ['什么是', '定义', '含义']):
intent['type'] = 'factual'
elif any(word in query_lower for word in ['如何', '怎么', '步骤', '方法']):
intent['type'] = 'procedural'
elif any(word in query_lower for word in ['比较', '区别', '差异', '对比']):
intent['type'] = 'comparative'

# 识别领域
if any(word in query_lower for word in ['ai', '人工智能', '机器学习', '深度学习']):
intent['domain'] = 'technology'
elif any(word in query_lower for word in ['科学', '物理', '化学', '生物']):
intent['domain'] = 'science'

# 识别复杂度
if len(query.split()) > 10 or '为什么' in query_lower:
intent['complexity'] = 'complex'
intent['requires_reasoning'] = True

return intent

def _select_retrieval_strategy(self, query_intent: Dict[str, Any]) -> Dict[str, Any]:
"""选择检索策略"""
strategy = {
'retrieval_type': 'hybrid', # vector, knowledge_graph, hybrid
'top_k': 5,
'rerank': False,
'expand_query': False
}

# 根据意图调整策略
if query_intent['type'] == 'factual':
strategy['retrieval_type'] = 'knowledge_graph'
strategy['top_k'] = 3
elif query_intent['type'] == 'procedural':
strategy['retrieval_type'] = 'vector'
strategy['top_k'] = 7
elif query_intent['complexity'] == 'complex':
strategy['top_k'] = 10
strategy['rerank'] = True
strategy['expand_query'] = True

return strategy

def _execute_retrieval(self, query: str, strategy: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行检索"""
# 查询扩展
if strategy.get('expand_query', False):
expanded_query = self._expand_query(query)
else:
expanded_query = query

# 执行检索
retrieved_docs = self.rag_system.retriever.retrieve(expanded_query, strategy['top_k'])

# 重排序
if strategy.get('rerank', False):
retrieved_docs = self._rerank_documents(query, retrieved_docs)

return retrieved_docs

def _expand_query(self, query: str) -> str:
"""查询扩展"""
# 简单的同义词扩展(实际应用中可以使用更复杂的方法)
synonyms = {
'AI': ['人工智能', '机器智能'],
'机器学习': ['ML', '机器学习算法'],
'深度学习': ['DL', '神经网络']
}

expanded = query
for term, syns in synonyms.items():
if term in query:
expanded += ' ' + ' '.join(syns)

return expanded

def _rerank_documents(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""重排序文档"""
# 这里可以实现更复杂的重排序算法
# 例如基于BERT的交叉编码器
return documents # 暂时返回原始排序

def _postprocess_retrieval_results(self, documents: List[Dict[str, Any]],
query_intent: Dict[str, Any]) -> List[Dict[str, Any]]:
"""检索结果后处理"""
processed_docs = []

for doc in documents:
# 过滤低质量文档
if doc.get('score', 0) < 0.1:
continue

# 根据意图调整文档内容
if query_intent['type'] == 'factual':
# 对于事实性查询,优先选择定义性内容
doc['relevance_boost'] = 1.2 if '定义' in doc.get('content', '') else 1.0

processed_docs.append(doc)

return processed_docs

def _generate_response(self, query: str, documents: List[Dict[str, Any]],
query_intent: Dict[str, Any]) -> str:
"""生成回答"""
# 根据意图调整生成参数
generation_params = {
'temperature': 0.7,
'max_length': 512
}

if query_intent['type'] == 'factual':
generation_params['temperature'] = 0.3 # 更确定性的回答
elif query_intent['complexity'] == 'complex':
generation_params['max_length'] = 1024 # 更长的回答

return self.rag_system.generator.generate_response(query, documents, **generation_params)

def _postprocess_response(self, response: str, query_intent: Dict[str, Any]) -> str:
"""回答后处理"""
# 清理回答
processed = response.strip()

# 移除重复内容
sentences = processed.split('。')
unique_sentences = []
seen = set()

for sentence in sentences:
sentence = sentence.strip()
if sentence and sentence not in seen:
unique_sentences.append(sentence)
seen.add(sentence)

processed = '。'.join(unique_sentences)
if processed and not processed.endswith('。'):
processed += '。'

return processed

def _create_result_object(self, query: str, response: str, documents: List[Dict[str, Any]],
processing_time: float, user_id: str, session_id: str,
query_intent: Dict[str, Any]) -> Dict[str, Any]:
"""创建结果对象"""
return {
'query': query,
'answer': response,
'retrieved_documents': documents,
'processing_time': processing_time,
'user_id': user_id,
'session_id': session_id,
'query_intent': query_intent,
'success': True,
'timestamp': datetime.now().isoformat(),
'num_retrieved': len(documents)
}

# 示例使用
if __name__ == "__main__":
# 创建RAG系统(使用之前定义的组件)
vector_retriever = VectorRetriever()
vector_retriever.build_index(documents)

generator = Generator()
rag_system = RAGSystem(vector_retriever, generator)

# 创建处理流水线
pipeline = RAGPipeline(rag_system)

# 测试查询
result = pipeline.process_query("什么是深度学习?", user_id="user123", session_id="session456")

print("查询结果:")
print(f"问题: {result['query']}")
print(f"回答: {result['answer']}")
print(f"处理时间: {result['processing_time']:.2f}秒")
print(f"查询意图: {result['query_intent']}")

4. RAG系统优化策略

4.1 检索优化

检索质量直接影响RAG系统的最终效果,常用的优化手段包括:查询缓存、语义查询扩展(同义词替换与查询重写)、多查询检索与结果去重、基于长度/来源/新鲜度等因子的重排序,以及查询清理、分类与关键词提取。下面的AdvancedRetriever与QueryOptimizer给出了这些策略的简化实现:

class AdvancedRetriever(Retriever):
"""高级检索器:集成多种优化技术"""

def __init__(self, base_retriever: Retriever,
query_expansion_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
self.base_retriever = base_retriever
self.query_expansion_model = SentenceTransformer(query_expansion_model)
# self.reranker = CrossEncoder(reranker_model) # 需要安装sentence-transformers[cross-encoder]
self.query_cache = {}

def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""优化的检索流程"""
# 1. 查询缓存检查
cache_key = f"{query}_{top_k}"
if cache_key in self.query_cache:
return self.query_cache[cache_key]

# 2. 查询扩展
expanded_queries = self._expand_query_semantic(query)

# 3. 多查询检索
all_results = []
for exp_query in expanded_queries:
results = self.base_retriever.retrieve(exp_query, top_k * 2)
all_results.extend(results)

# 4. 去重和合并
unique_results = self._deduplicate_results(all_results)

# 5. 重排序
reranked_results = self._rerank_results(query, unique_results)

# 6. 缓存结果
final_results = reranked_results[:top_k]
self.query_cache[cache_key] = final_results

return final_results

def _expand_query_semantic(self, query: str) -> List[str]:
"""语义查询扩展"""
# 生成查询的语义变体
expanded_queries = [query]

# 方法1:同义词替换
synonyms = self._get_synonyms(query)
for synonym_query in synonyms:
expanded_queries.append(synonym_query)

# 方法2:查询重写
rewritten_queries = self._rewrite_query(query)
expanded_queries.extend(rewritten_queries)

return expanded_queries[:3] # 限制扩展查询数量

def _get_synonyms(self, query: str) -> List[str]:
"""获取同义词查询"""
# 简化的同义词映射
synonym_map = {
'人工智能': ['AI', '机器智能', '智能系统'],
'机器学习': ['ML', '机器学习算法', '自动学习'],
'深度学习': ['DL', '神经网络', '深层神经网络'],
'自然语言处理': ['NLP', '文本处理', '语言理解']
}

synonyms = []
for term, syns in synonym_map.items():
if term in query:
for syn in syns:
synonyms.append(query.replace(term, syn))

return synonyms

def _rewrite_query(self, query: str) -> List[str]:
"""查询重写"""
rewritten = []

# 添加上下文词汇
if '是什么' in query:
base_term = query.replace('是什么', '').strip()
rewritten.append(f"{base_term}的定义")
rewritten.append(f"{base_term}的含义")

if '如何' in query:
base_term = query.replace('如何', '').strip()
rewritten.append(f"{base_term}的方法")
rewritten.append(f"{base_term}的步骤")

return rewritten

def _deduplicate_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""去重检索结果"""
seen_content = set()
unique_results = []

for result in results:
content_hash = hash(result.get('content', ''))
if content_hash not in seen_content:
seen_content.add(content_hash)
unique_results.append(result)

return unique_results

def _rerank_results(self, query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""重排序结果"""
# 简化的重排序:基于多个因子
for result in results:
# 计算综合分数
base_score = result.get('score', 0.0)

# 长度惩罚:过短或过长的文档降权
content_length = len(result.get('content', ''))
length_penalty = 1.0
if content_length < 50:
length_penalty = 0.8
elif content_length > 1000:
length_penalty = 0.9

# 来源权重
source_weight = 1.0
if result.get('source') == 'knowledge_graph':
source_weight = 1.2 # 知识图谱来源加权

# 新鲜度权重(如果有时间戳)
freshness_weight = 1.0
# if 'timestamp' in result:
# freshness_weight = self._calculate_freshness_weight(result['timestamp'])

# 计算最终分数
final_score = base_score * length_penalty * source_weight * freshness_weight
result['rerank_score'] = final_score

# 按重排序分数排序
results.sort(key=lambda x: x.get('rerank_score', 0), reverse=True)
return results

class QueryOptimizer:
"""查询优化器"""

def __init__(self):
self.stop_words = {'的', '了', '在', '是', '有', '和', '与', '或', '但是', '然而'}
self.query_patterns = {
'definition': ['什么是', '定义', '含义', '意思'],
'how_to': ['如何', '怎么', '怎样', '方法'],
'comparison': ['区别', '差异', '对比', '比较'],
'causation': ['为什么', '原因', '导致', '影响']
}

def optimize_query(self, query: str) -> Dict[str, Any]:
"""优化查询"""
# 1. 查询清理
cleaned_query = self._clean_query(query)

# 2. 查询分类
query_type = self._classify_query(cleaned_query)

# 3. 关键词提取
keywords = self._extract_keywords(cleaned_query)

# 4. 查询扩展建议
expansion_suggestions = self._suggest_expansions(cleaned_query, query_type)

return {
'original_query': query,
'cleaned_query': cleaned_query,
'query_type': query_type,
'keywords': keywords,
'expansion_suggestions': expansion_suggestions
}

def _clean_query(self, query: str) -> str:
"""清理查询"""
# 移除标点符号和多余空格
import re
cleaned = re.sub(r'[^\w\s]', ' ', query)
cleaned = ' '.join(cleaned.split())
return cleaned.strip()

def _classify_query(self, query: str) -> str:
"""分类查询"""
query_lower = query.lower()

for query_type, patterns in self.query_patterns.items():
if any(pattern in query_lower for pattern in patterns):
return query_type

return 'general'

def _extract_keywords(self, query: str) -> List[str]:
"""提取关键词"""
words = query.split()
keywords = [word for word in words if word not in self.stop_words and len(word) > 1]
return keywords

def _suggest_expansions(self, query: str, query_type: str) -> List[str]:
"""建议查询扩展"""
suggestions = []

if query_type == 'definition':
suggestions.extend(['概念', '特点', '应用'])
elif query_type == 'how_to':
suggestions.extend(['步骤', '流程', '实现'])
elif query_type == 'comparison':
suggestions.extend(['优缺点', '特性', '适用场景'])

return suggestions
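
上面AdvancedRetriever中基于交叉编码器的重排序被注释掉了,_rerank_results也只是多因子启发式打分。如果安装了sentence-transformers,可以用CrossEncoder对"查询-文档"对直接打分来实现更强的重排序。以下是一个示意,模型名沿用前文提到的cross-encoder/ms-marco-MiniLM-L-12-v2,中文场景可替换为相应的中文交叉编码器:

from sentence_transformers import CrossEncoder

def rerank_with_cross_encoder(query: str, results: List[Dict[str, Any]],
                              model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2") -> List[Dict[str, Any]]:
    """用交叉编码器对(查询, 文档)对打分并按分数重排序(示意实现)"""
    reranker = CrossEncoder(model_name)
    pairs = [(query, r.get("content", "")) for r in results]
    scores = reranker.predict(pairs)  # 每个查询-文档对对应一个相关性分数
    for r, s in zip(results, scores):
        r["rerank_score"] = float(s)
    return sorted(results, key=lambda r: r["rerank_score"], reverse=True)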

4.2 生成优化

生成质量的优化涉及多个方面:上下文优化(按相关性排序并控制长度)、生成参数自适应调优、多候选生成与最优选择、回答后处理(去重、补充引导语),以及从完整性、准确性、相关性、流畅性等维度进行质量评估。下面的AdvancedGenerator给出了相应的简化实现:

class AdvancedGenerator:
"""高级生成器:集成多种优化技术"""

def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)

if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

self.response_cache = {}
self.generation_history = []

def generate_response(self, query: str, retrieved_docs: List[Dict[str, Any]],
**kwargs) -> Dict[str, Any]:
"""生成优化的回答"""
# 1. 上下文优化
optimized_context = self._optimize_context(query, retrieved_docs)

# 2. 生成参数调优
generation_params = self._optimize_generation_params(query, retrieved_docs, **kwargs)

# 3. 多候选生成
candidates = self._generate_multiple_candidates(optimized_context, generation_params)

# 4. 候选选择
best_candidate = self._select_best_candidate(query, candidates, retrieved_docs)

# 5. 后处理
final_response = self._postprocess_response(best_candidate, query)

# 6. 质量评估
quality_score = self._evaluate_response_quality(query, final_response, retrieved_docs)

return {
'response': final_response,
'quality_score': quality_score,
'context_length': len(optimized_context),
'generation_params': generation_params
}

def _optimize_context(self, query: str, retrieved_docs: List[Dict[str, Any]]) -> str:
"""优化上下文构建"""
# 1. 文档相关性排序
sorted_docs = sorted(retrieved_docs, key=lambda x: x.get('score', 0), reverse=True)

# 2. 上下文长度控制
max_context_length = 1000 # 字符数限制
context_parts = []
current_length = 0

# 3. 智能文档选择
for doc in sorted_docs:
content = doc.get('content', '')

# 检查内容相关性
if self._is_content_relevant(query, content):
if current_length + len(content) <= max_context_length:
context_parts.append(content)
current_length += len(content)
else:
# 截取部分内容
remaining_length = max_context_length - current_length
if remaining_length > 100: # 至少保留100字符
truncated_content = content[:remaining_length] + "..."
context_parts.append(truncated_content)
break

# 4. 构建结构化上下文
context = f"问题:{query}\n\n相关信息:\n"
for i, part in enumerate(context_parts, 1):
context += f"{i}. {part}\n\n"

context += "请基于以上信息回答问题:"

return context

def _is_content_relevant(self, query: str, content: str) -> bool:
"""判断内容相关性"""
query_words = set(query.lower().split())
content_words = set(content.lower().split())

# 计算词汇重叠度
overlap = len(query_words.intersection(content_words))
relevance_ratio = overlap / len(query_words) if query_words else 0

return relevance_ratio > 0.2 # 20%的词汇重叠阈值

def _optimize_generation_params(self, query: str, retrieved_docs: List[Dict[str, Any]],
**kwargs) -> Dict[str, Any]:
"""优化生成参数"""
params = {
'max_length': 512,
'temperature': 0.7,
'top_p': 0.9,
'top_k': 50,
'repetition_penalty': 1.1,
'do_sample': True
}

# 根据查询类型调整参数
if '定义' in query or '什么是' in query:
params['temperature'] = 0.3 # 更确定性的回答
params['top_p'] = 0.8
elif '如何' in query or '怎么' in query:
params['max_length'] = 768 # 更长的回答
params['temperature'] = 0.5

# 根据检索质量调整
avg_score = np.mean([doc.get('score', 0) for doc in retrieved_docs]) if retrieved_docs else 0
if avg_score < 0.5:
params['temperature'] = 0.8 # 低质量检索时增加创造性

# 应用用户自定义参数
params.update(kwargs)

return params

def _generate_multiple_candidates(self, context: str, params: Dict[str, Any]) -> List[str]:
"""生成多个候选回答"""
candidates = []
num_candidates = 3

inputs = self.tokenizer.encode(context, return_tensors="pt", max_length=1024, truncation=True)

for i in range(num_candidates):
# 为每个候选使用略微不同的参数
candidate_params = params.copy()
candidate_params['temperature'] = params['temperature'] + (i * 0.1)

with torch.no_grad():
outputs = self.model.generate(
inputs,
max_length=inputs.shape[1] + candidate_params['max_length'],
temperature=candidate_params['temperature'],
top_p=candidate_params['top_p'],
top_k=candidate_params['top_k'],
repetition_penalty=candidate_params['repetition_penalty'],
do_sample=candidate_params['do_sample'],
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)

response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
original_length = len(self.tokenizer.decode(inputs[0], skip_special_tokens=True))
generated_response = response[original_length:].strip()

candidates.append(generated_response)

return candidates

def _select_best_candidate(self, query: str, candidates: List[str],
retrieved_docs: List[Dict[str, Any]]) -> str:
"""选择最佳候选回答"""
if not candidates:
return "抱歉,无法生成合适的回答。"

best_candidate = candidates[0]
best_score = 0

for candidate in candidates:
score = self._score_candidate(query, candidate, retrieved_docs)
if score > best_score:
best_score = score
best_candidate = candidate

return best_candidate

def _score_candidate(self, query: str, candidate: str,
retrieved_docs: List[Dict[str, Any]]) -> float:
"""评分候选回答"""
score = 0.0

# 1. 长度合理性
length_score = min(len(candidate) / 200, 1.0) # 理想长度200字符
score += length_score * 0.2

# 2. 与查询的相关性
query_words = set(query.lower().split())
candidate_words = set(candidate.lower().split())
relevance_score = len(query_words.intersection(candidate_words)) / len(query_words)
score += relevance_score * 0.3

# 3. 与检索文档的一致性
consistency_score = 0
for doc in retrieved_docs[:3]: # 只考虑前3个文档
doc_words = set(doc.get('content', '').lower().split())
consistency = len(candidate_words.intersection(doc_words)) / len(candidate_words)
consistency_score += consistency
consistency_score = consistency_score / min(len(retrieved_docs), 3) if retrieved_docs else 0
score += consistency_score * 0.3

# 4. 流畅性(简单检查)
fluency_score = 1.0 if len(candidate.split('。')) > 1 else 0.5 # 多句子更流畅
score += fluency_score * 0.2

return score

def _postprocess_response(self, response: str, query: str) -> str:
"""后处理回答"""
# 1. 清理格式
processed = response.strip()

# 2. 移除重复句子
sentences = [s.strip() for s in processed.split('。') if s.strip()]
unique_sentences = []
seen = set()

for sentence in sentences:
if sentence not in seen and len(sentence) > 5:
unique_sentences.append(sentence)
seen.add(sentence)

processed = '。'.join(unique_sentences)
if processed and not processed.endswith('。'):
processed += '。'

# 3. 添加适当的开头(如果需要)
if not any(processed.startswith(prefix) for prefix in ['根据', '基于', '从']):
if '什么是' in query or '定义' in query:
processed = f"根据相关资料,{processed}"

return processed

def _evaluate_response_quality(self, query: str, response: str,
retrieved_docs: List[Dict[str, Any]]) -> float:
"""评估回答质量"""
# 综合多个维度评估质量
quality_factors = {
'completeness': self._evaluate_completeness(query, response),
'accuracy': self._evaluate_accuracy(response, retrieved_docs),
'relevance': self._evaluate_relevance(query, response),
'fluency': self._evaluate_fluency(response)
}

# 加权平均
weights = {'completeness': 0.3, 'accuracy': 0.3, 'relevance': 0.25, 'fluency': 0.15}
quality_score = sum(quality_factors[factor] * weights[factor]
for factor in quality_factors)

return quality_score

def _evaluate_completeness(self, query: str, response: str) -> float:
"""评估回答完整性"""
# 简单的完整性检查
if len(response) < 20:
return 0.2
elif len(response) < 50:
return 0.6
else:
return 1.0

def _evaluate_accuracy(self, response: str, retrieved_docs: List[Dict[str, Any]]) -> float:
"""评估回答准确性"""
if not retrieved_docs:
return 0.5 # 无法验证时给中等分数

# 检查回答是否与检索文档一致
response_words = set(response.lower().split())
doc_words = set()

for doc in retrieved_docs[:3]:
doc_words.update(doc.get('content', '').lower().split())

if not doc_words:
return 0.5

consistency = len(response_words.intersection(doc_words)) / len(response_words)
return min(consistency * 2, 1.0) # 放大一致性分数

def _evaluate_relevance(self, query: str, response: str) -> float:
"""评估回答相关性"""
query_words = set(query.lower().split())
response_words = set(response.lower().split())

if not query_words:
return 0.5

relevance = len(query_words.intersection(response_words)) / len(query_words)
return relevance

def _evaluate_fluency(self, response: str) -> float:
"""评估回答流畅性"""
# 简单的流畅性检查
sentences = response.split('。')

if len(sentences) < 2:
return 0.6 # 单句回答

# 检查句子长度分布
sentence_lengths = [len(s.strip()) for s in sentences if s.strip()]
if not sentence_lengths:
return 0.3

avg_length = np.mean(sentence_lengths)
if 10 <= avg_length <= 50: # 理想句子长度
return 1.0
else:
return 0.7
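
与基础Generator直接返回字符串不同,AdvancedGenerator.generate_response返回的是包含回答文本与质量分数等信息的字典,因此不能不加修改地直接替换进RAGSystem。下面是一个结合前文VectorRetriever的使用示意(documents沿用3.1节中的示例文档列表,首次运行需要下载相应模型):

# 用法示意:单独组合VectorRetriever与AdvancedGenerator
retriever = VectorRetriever()
retriever.build_index(documents)  # documents 为 3.1 节中的示例文档列表

advanced_generator = AdvancedGenerator()
retrieved = retriever.retrieve("什么是深度学习?", top_k=3)
result = advanced_generator.generate_response("什么是深度学习?", retrieved)

print("回答:", result["response"])
print("质量分数:", round(result["quality_score"], 3))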

4.3 系统评估与监控

RAG系统的持续优化需要完善的评估和监控机制,通常涵盖四类指标:检索指标(Precision@K、Recall@K、MRR、NDCG)、生成指标(BLEU、语义相似度)、端到端指标(回答的相关性、完整性与事实准确性)以及性能指标(响应时间),并辅以在线查询日志与告警。下面的RAGEvaluator与RAGMonitor给出了相应的简化实现:

class RAGEvaluator:
"""RAG系统评估器"""

def __init__(self):
self.metrics_history = []
self.evaluation_cache = {}

def evaluate_system(self, rag_system: RAGSystem, test_queries: List[Dict[str, Any]],
ground_truth: List[str] = None) -> Dict[str, Any]:
"""全面评估RAG系统"""
results = {
'retrieval_metrics': {},
'generation_metrics': {},
'end_to_end_metrics': {},
'performance_metrics': {}
}

# 1. 检索评估
retrieval_results = self._evaluate_retrieval(rag_system.retriever, test_queries)
results['retrieval_metrics'] = retrieval_results

# 2. 生成评估
generation_results = self._evaluate_generation(rag_system, test_queries, ground_truth)
results['generation_metrics'] = generation_results

# 3. 端到端评估
e2e_results = self._evaluate_end_to_end(rag_system, test_queries, ground_truth)
results['end_to_end_metrics'] = e2e_results

# 4. 性能评估
performance_results = self._evaluate_performance(rag_system, test_queries)
results['performance_metrics'] = performance_results

# 记录评估历史
self.metrics_history.append({
'timestamp': datetime.now().isoformat(),
'results': results
})

return results

def _evaluate_retrieval(self, retriever: Retriever, test_queries: List[Dict[str, Any]]) -> Dict[str, float]:
"""评估检索性能"""
metrics = {
'precision_at_k': [],
'recall_at_k': [],
'mrr': [], # Mean Reciprocal Rank
'ndcg': [] # Normalized Discounted Cumulative Gain
}

for query_data in test_queries:
query = query_data['query']
relevant_docs = query_data.get('relevant_docs', [])

# 执行检索
retrieved_docs = retriever.retrieve(query, top_k=10)
retrieved_ids = [doc.get('id', '') for doc in retrieved_docs]

# 计算指标
if relevant_docs:
precision = self._calculate_precision_at_k(retrieved_ids, relevant_docs, k=5)
recall = self._calculate_recall_at_k(retrieved_ids, relevant_docs, k=5)
mrr = self._calculate_mrr(retrieved_ids, relevant_docs)
ndcg = self._calculate_ndcg(retrieved_ids, relevant_docs, k=5)

metrics['precision_at_k'].append(precision)
metrics['recall_at_k'].append(recall)
metrics['mrr'].append(mrr)
metrics['ndcg'].append(ndcg)

# 计算平均值
return {
'avg_precision_at_5': np.mean(metrics['precision_at_k']) if metrics['precision_at_k'] else 0,
'avg_recall_at_5': np.mean(metrics['recall_at_k']) if metrics['recall_at_k'] else 0,
'mean_reciprocal_rank': np.mean(metrics['mrr']) if metrics['mrr'] else 0,
'avg_ndcg_at_5': np.mean(metrics['ndcg']) if metrics['ndcg'] else 0
}

def _calculate_precision_at_k(self, retrieved: List[str], relevant: List[str], k: int) -> float:
"""计算Precision@K"""
retrieved_k = retrieved[:k]
relevant_retrieved = len(set(retrieved_k).intersection(set(relevant)))
return relevant_retrieved / k if k > 0 else 0

def _calculate_recall_at_k(self, retrieved: List[str], relevant: List[str], k: int) -> float:
"""计算Recall@K"""
retrieved_k = retrieved[:k]
relevant_retrieved = len(set(retrieved_k).intersection(set(relevant)))
return relevant_retrieved / len(relevant) if relevant else 0

def _calculate_mrr(self, retrieved: List[str], relevant: List[str]) -> float:
"""计算Mean Reciprocal Rank"""
for i, doc_id in enumerate(retrieved):
if doc_id in relevant:
return 1.0 / (i + 1)
return 0.0

def _calculate_ndcg(self, retrieved: List[str], relevant: List[str], k: int) -> float:
"""计算NDCG@K"""
# 简化的NDCG计算
dcg = 0.0
for i, doc_id in enumerate(retrieved[:k]):
if doc_id in relevant:
dcg += 1.0 / np.log2(i + 2)

# 理想DCG
idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(relevant), k)))

return dcg / idcg if idcg > 0 else 0

def _evaluate_generation(self, rag_system: RAGSystem, test_queries: List[Dict[str, Any]],
ground_truth: List[str] = None) -> Dict[str, float]:
"""评估生成质量"""
if not ground_truth:
return {'note': 'No ground truth provided for generation evaluation'}

metrics = {
'bleu_scores': [],
'rouge_scores': [],
'semantic_similarity': []
}

for i, query_data in enumerate(test_queries):
if i >= len(ground_truth):
break

query = query_data['query']
expected_answer = ground_truth[i]

# 生成回答
result = rag_system.query(query)
generated_answer = result['answer']

# 计算BLEU分数(简化版)
bleu = self._calculate_simple_bleu(generated_answer, expected_answer)
metrics['bleu_scores'].append(bleu)

# 计算语义相似度
semantic_sim = self._calculate_semantic_similarity(generated_answer, expected_answer)
metrics['semantic_similarity'].append(semantic_sim)

return {
'avg_bleu': np.mean(metrics['bleu_scores']) if metrics['bleu_scores'] else 0,
'avg_semantic_similarity': np.mean(metrics['semantic_similarity']) if metrics['semantic_similarity'] else 0
}

def _calculate_simple_bleu(self, generated: str, reference: str) -> float:
"""简化的BLEU分数计算"""
gen_words = set(generated.lower().split())
ref_words = set(reference.lower().split())

if not ref_words:
return 0.0

overlap = len(gen_words.intersection(ref_words))
return overlap / len(ref_words)

def _calculate_semantic_similarity(self, text1: str, text2: str) -> float:
"""计算语义相似度"""
# 这里可以使用更复杂的语义相似度计算方法
# 简化版本:基于词汇重叠
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())

if not words1 and not words2:
return 1.0
if not words1 or not words2:
return 0.0

intersection = len(words1.intersection(words2))
union = len(words1.union(words2))

return intersection / union if union > 0 else 0
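# 注意:以上按空格切分词,中文文本实际使用时通常需要先分词(如使用jieba),此处仅作示意
# 示例(示意):text1 = "退款 政策 说明",text2 = "退款 流程",
# 交集 {退款} 大小为1,并集大小为4,相似度 = 1/4 = 0.25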

def _evaluate_end_to_end(self, rag_system: RAGSystem, test_queries: List[Dict[str, Any]],
ground_truth: List[str] = None) -> Dict[str, float]:
"""端到端评估"""
metrics = {
'answer_relevance': [],
'answer_completeness': [],
'factual_accuracy': []
}

for query_data in test_queries:
query = query_data['query']
result = rag_system.query(query)
answer = result['answer']

# 评估回答相关性
relevance = self._evaluate_answer_relevance(query, answer)
metrics['answer_relevance'].append(relevance)

# 评估回答完整性
completeness = self._evaluate_answer_completeness(query, answer)
metrics['answer_completeness'].append(completeness)

# 评估事实准确性(基于检索文档)
accuracy = self._evaluate_factual_accuracy(answer, result['retrieved_documents'])
metrics['factual_accuracy'].append(accuracy)

return {
'avg_relevance': np.mean(metrics['answer_relevance']),
'avg_completeness': np.mean(metrics['answer_completeness']),
'avg_accuracy': np.mean(metrics['factual_accuracy'])
}

def _evaluate_answer_relevance(self, query: str, answer: str) -> float:
"""评估回答相关性"""
query_words = set(query.lower().split())
answer_words = set(answer.lower().split())

if not query_words:
return 0.5

overlap = len(query_words.intersection(answer_words))
return overlap / len(query_words)

def _evaluate_answer_completeness(self, query: str, answer: str) -> float:
"""评估回答完整性"""
# 基于长度和结构的简单评估
if len(answer) < 20:
return 0.3
elif len(answer) < 100:
return 0.7
else:
return 1.0

def _evaluate_factual_accuracy(self, answer: str, retrieved_docs: List[Dict[str, Any]]) -> float:
"""评估事实准确性"""
if not retrieved_docs:
return 0.5

answer_words = set(answer.lower().split())
doc_words = set()

for doc in retrieved_docs[:3]: # 只考虑前3个文档
doc_words.update(doc.get('content', '').lower().split())

if not doc_words:
return 0.5

consistency = len(answer_words.intersection(doc_words)) / len(answer_words)
return min(consistency * 1.5, 1.0) # 放大一致性分数

def _evaluate_performance(self, rag_system: RAGSystem, test_queries: List[Dict[str, Any]]) -> Dict[str, float]:
"""评估系统性能"""
response_times = []

for query_data in test_queries[:10]: # 限制测试数量
query = query_data['query']

start_time = time.time()
result = rag_system.query(query)
end_time = time.time()

response_times.append(end_time - start_time)

return {
'avg_response_time': np.mean(response_times),
'max_response_time': np.max(response_times),
'min_response_time': np.min(response_times)
}

class RAGMonitor:
"""RAG系统监控器"""

def __init__(self, rag_system: RAGSystem):
self.rag_system = rag_system
self.query_logs = []
self.performance_metrics = []
self.alert_thresholds = {
'response_time': 5.0, # 秒
'error_rate': 0.1, # 10%
'quality_score': 0.6 # 最低质量分数
}

def log_query(self, query: str, result: Dict[str, Any], user_id: str = None):
"""记录查询日志"""
log_entry = {
'timestamp': datetime.now().isoformat(),
'query': query,
'user_id': user_id,
'response_time': result.get('processing_time', 0),
'success': result.get('success', True),
'num_retrieved': result.get('num_retrieved', 0),
'quality_score': result.get('quality_score', 0)
}

self.query_logs.append(log_entry)

# 检查告警条件
self._check_alerts(log_entry)

def _check_alerts(self, log_entry: Dict[str, Any]):
"""检查告警条件"""
# 响应时间告警
if log_entry['response_time'] > self.alert_thresholds['response_time']:
self._send_alert(f"响应时间过长: {log_entry['response_time']:.2f}秒")

# 质量分数告警
if log_entry['quality_score'] < self.alert_thresholds['quality_score']:
self._send_alert(f"回答质量过低: {log_entry['quality_score']:.2f}")

# 错误率告警
recent_logs = self.query_logs[-100:] # 最近100条记录
if len(recent_logs) >= 10:
error_rate = sum(1 for log in recent_logs if not log['success']) / len(recent_logs)
if error_rate > self.alert_thresholds['error_rate']:
self._send_alert(f"错误率过高: {error_rate:.2%}")

def _send_alert(self, message: str):
"""发送告警"""
print(f"[ALERT] {datetime.now().isoformat()}: {message}")
# 这里可以集成实际的告警系统,如邮件、短信、Slack等

def get_performance_summary(self, hours: int = 24) -> Dict[str, Any]:
"""获取性能摘要"""
cutoff_time = datetime.now() - timedelta(hours=hours)
recent_logs = [
log for log in self.query_logs
if datetime.fromisoformat(log['timestamp']) > cutoff_time
]

if not recent_logs:
return {'message': 'No recent data available'}

response_times = [log['response_time'] for log in recent_logs]
quality_scores = [log['quality_score'] for log in recent_logs if log['quality_score'] > 0]
success_rate = sum(1 for log in recent_logs if log['success']) / len(recent_logs)

return {
'total_queries': len(recent_logs),
'success_rate': success_rate,
'avg_response_time': np.mean(response_times),
'p95_response_time': np.percentile(response_times, 95),
'avg_quality_score': np.mean(quality_scores) if quality_scores else 0,
'unique_users': len(set(log['user_id'] for log in recent_logs if log['user_id']))
}
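
下面给出监控器的最小使用示意(假设 rag_system 已按前文构建;查询内容与 user_id 均为虚构示例,仅用于说明调用方式):

monitor = RAGMonitor(rag_system)

# 记录一次查询并自动触发告警检查
result = rag_system.query("如何申请退款?")
monitor.log_query("如何申请退款?", result, user_id="user_001")

# 查看最近24小时的性能摘要
summary = monitor.get_performance_summary(hours=24)
print(summary['total_queries'], summary['avg_response_time'])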

5. 实际应用案例

5.1 智能客服系统

基于RAG的智能客服系统能够结合企业知识库提供准确的客户服务:

class CustomerServiceRAG:
"""智能客服RAG系统"""

def __init__(self, knowledge_base_path: str, faq_data: List[Dict[str, Any]]):
# 初始化知识库
self.kb = self._build_knowledge_base(knowledge_base_path, faq_data)

# 创建检索器
self.retriever = VectorRetriever()
self.retriever.build_index(self.kb)

# 创建生成器
self.generator = AdvancedGenerator()

# 创建RAG系统
self.rag_system = RAGSystem(self.retriever, self.generator)

# 对话历史
self.conversation_history = {}

def _build_knowledge_base(self, kb_path: str, faq_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""构建知识库"""
documents = []

# 添加FAQ数据
for faq in faq_data:
documents.append({
'id': f"faq_{faq['id']}",
'content': f"问题: {faq['question']}\n答案: {faq['answer']}",
'source': 'faq',
'category': faq.get('category', 'general'),
'title': faq['question']
})

# 添加产品文档(如果有)
# documents.extend(self._load_product_docs(kb_path))

return documents

def handle_customer_query(self, query: str, customer_id: str,
session_id: str = None) -> Dict[str, Any]:
"""处理客户查询"""
# 获取对话历史
history = self.conversation_history.get(customer_id, [])

# 上下文增强查询
enhanced_query = self._enhance_query_with_context(query, history)

# 执行RAG查询
result = self.rag_system.query(enhanced_query)

# 后处理回答
processed_answer = self._postprocess_customer_answer(result['answer'], query)

# 更新对话历史
history.append({
'timestamp': datetime.now().isoformat(),
'query': query,
'answer': processed_answer,
'session_id': session_id
})
self.conversation_history[customer_id] = history[-10:] # 保留最近10轮对话

# 生成建议操作
suggested_actions = self._generate_suggested_actions(query, result)

return {
'answer': processed_answer,
'confidence': result.get('quality_score', 0),
'suggested_actions': suggested_actions,
'retrieved_sources': [doc.get('source', '') for doc in result['retrieved_documents']],
'escalate_to_human': self._should_escalate(result, query)
}

def _enhance_query_with_context(self, query: str, history: List[Dict[str, Any]]) -> str:
"""使用对话历史增强查询"""
if not history:
return query

# 获取最近的对话上下文
recent_context = []
for item in history[-3:]: # 最近3轮对话
recent_context.append(f"用户: {item['query']}")
recent_context.append(f"客服: {item['answer']}")

context_str = "\n".join(recent_context)
enhanced_query = f"对话历史:\n{context_str}\n\n当前问题: {query}"

return enhanced_query

def _postprocess_customer_answer(self, answer: str, query: str) -> str:
"""后处理客服回答"""
# 添加礼貌用语
if not answer.startswith(('您好', '感谢', '很高兴')):
answer = f"您好!{answer}"

# 添加结尾
if not answer.endswith(('。', '!', '?')):
answer += "。"

# 添加后续服务提示
answer += "\n\n如果您还有其他问题,请随时告诉我。"

return answer

def _generate_suggested_actions(self, query: str, result: Dict[str, Any]) -> List[str]:
"""生成建议操作"""
suggestions = []

# 基于查询类型生成建议
if '退款' in query or '退货' in query:
suggestions.append('查看退款政策')
suggestions.append('联系售后服务')
elif '订单' in query:
suggestions.append('查询订单状态')
suggestions.append('修改订单信息')
elif '产品' in query or '功能' in query:
suggestions.append('查看产品详情')
suggestions.append('观看使用教程')

# 基于检索结果生成建议
for doc in result.get('retrieved_documents', [])[:2]:
if doc.get('category') == 'tutorial':
suggestions.append('查看相关教程')
elif doc.get('category') == 'policy':
suggestions.append('了解相关政策')

return list(set(suggestions)) # 去重

def _should_escalate(self, result: Dict[str, Any], query: str) -> bool:
"""判断是否需要转人工"""
# 低置信度回答
if result.get('quality_score', 0) < 0.5:
return True

# 复杂查询关键词
escalation_keywords = ['投诉', '不满意', '经理', '人工', '转接']
if any(keyword in query for keyword in escalation_keywords):
return True

# 检索结果不足
if result.get('num_retrieved', 0) < 2:
return True

return False
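
下面是该客服系统的一个简化调用示意(faq_data 为虚构的示例数据,知识库路径仅作占位;依赖前文定义的 VectorRetriever、AdvancedGenerator 与 RAGSystem):

faq_data = [
    {"id": 1, "question": "如何申请退款?", "answer": "请在订单详情页点击申请退款并填写原因。", "category": "policy"}
]
service_rag = CustomerServiceRAG(knowledge_base_path="./kb", faq_data=faq_data)

reply = service_rag.handle_customer_query("我想退款怎么操作?", customer_id="c_001", session_id="s_001")
print(reply["answer"])
print("是否建议转人工:", reply["escalate_to_human"])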

5.2 技术文档问答系统

针对技术文档的RAG系统需要处理复杂的技术概念和代码示例:

class TechnicalDocRAG:
"""技术文档问答系统"""

def __init__(self, doc_sources: List[str]):
self.documents = self._load_technical_docs(doc_sources)

# 创建混合检索器
vector_retriever = VectorRetriever()
vector_retriever.build_index(self.documents)

# 创建知识图谱(技术概念关系)
self.tech_kg = self._build_technical_kg()
kg_retriever = KnowledgeGraphRetriever(self.tech_kg)

# 混合检索器
self.retriever = HybridRetriever(vector_retriever, kg_retriever)

# 专门的技术文档生成器
self.generator = TechnicalGenerator()

self.rag_system = RAGSystem(self.retriever, self.generator)

def _load_technical_docs(self, sources: List[str]) -> List[Dict[str, Any]]:
"""加载技术文档"""
documents = []

for source in sources:
    # 以加载本地Markdown/文本文件为例,实际可扩展为API文档、代码注释、README等来源
    try:
        with open(source, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError):
        continue
    documents.append({
        'id': f"doc_{len(documents)}",
        'content': content,
        'source': source,
        'title': source
    })

return documents

def _build_technical_kg(self) -> KnowledgeGraphDB:
"""构建技术概念知识图谱"""
kg_db = KnowledgeGraphDB()

# 添加技术概念实体和关系
# 这里可以从技术文档中自动提取概念关系

return kg_db

def answer_technical_question(self, question: str,
context: Dict[str, Any] = None) -> Dict[str, Any]:
"""回答技术问题"""
# 技术问题预处理
processed_question = self._preprocess_technical_query(question)

# 执行检索和生成
result = self.rag_system.query(processed_question)

# 技术回答后处理
enhanced_result = self._enhance_technical_answer(result, question, context)

return enhanced_result

def _preprocess_technical_query(self, query: str) -> str:
"""预处理技术查询"""
# 识别代码片段
# 提取技术术语
# 标准化API名称
return query

def _enhance_technical_answer(self, result: Dict[str, Any],
question: str, context: Dict[str, Any]) -> Dict[str, Any]:
"""增强技术回答"""
enhanced = result.copy()

# 添加代码示例
code_examples = self._extract_code_examples(result['retrieved_documents'])
enhanced['code_examples'] = code_examples

# 添加相关API链接
api_links = self._extract_api_links(result['retrieved_documents'])
enhanced['api_references'] = api_links

# 添加相关概念
related_concepts = self._find_related_concepts(question)
enhanced['related_concepts'] = related_concepts

return enhanced

def _extract_code_examples(self, documents: List[Dict[str, Any]]) -> List[str]:
"""提取代码示例"""
code_examples = []

for doc in documents:
content = doc.get('content', '')
# 使用正则表达式提取代码块
import re
code_blocks = re.findall(r'```[\s\S]*?```', content)
code_examples.extend(code_blocks)

return code_examples[:3] # 限制数量

def _extract_api_links(self, documents: List[Dict[str, Any]]) -> List[str]:
"""提取API链接"""
api_links = []

for doc in documents:
if 'api_url' in doc:
api_links.append(doc['api_url'])

return api_links

def _find_related_concepts(self, question: str) -> List[str]:
"""查找相关概念"""
# 从知识图谱中查找相关概念
related = []

# 简化实现
tech_terms = ['API', '函数', '类', '模块', '库', '框架']
for term in tech_terms:
if term in question:
related.append(term)

return related

class TechnicalGenerator(AdvancedGenerator):
"""技术文档专用生成器"""

def _postprocess_response(self, response: str, query: str) -> str:
"""技术回答后处理"""
processed = super()._postprocess_response(response, query)

# 格式化代码片段
processed = self._format_code_snippets(processed)

# 添加技术术语解释
processed = self._add_term_explanations(processed)

return processed

def _format_code_snippets(self, text: str) -> str:
"""格式化代码片段"""
# 简单的代码格式化
import re

# 识别可能的代码片段并添加格式
code_pattern = r'([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\))' # 函数调用
text = re.sub(code_pattern, r'`\1`', text)

return text

def _add_term_explanations(self, text: str) -> str:
"""添加术语解释"""
# 这里可以添加技术术语的简短解释
return text
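
技术文档问答系统的调用方式与上面的客服系统类似,以下为示意(文档路径为假设值,_load_technical_docs 与 _build_technical_kg 需按实际数据源补全后才能返回有效内容):

doc_rag = TechnicalDocRAG(doc_sources=["./docs/api_reference.md", "./docs/README.md"])

result = doc_rag.answer_technical_question("如何调用用户查询接口?")
print(result["answer"])
print("相关代码示例:", result["code_examples"])
print("相关概念:", result["related_concepts"])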

6. 挑战与解决方案

6.1 主要技术挑战

  1. 检索质量问题

    • 语义鸿沟:查询与文档之间的语义差异
    • 长尾查询:罕见或特定领域的查询
    • 多模态检索:文本、图像、表格等多种模态的统一检索
  2. 生成质量问题

    • 幻觉现象:生成与检索内容不一致的信息
    • 上下文长度限制:大模型输入长度的限制
    • 一致性保持:多轮对话中的一致性维护
  3. 系统性能问题

    • 延迟优化:实时响应的需求
    • 可扩展性:大规模部署的挑战
    • 资源消耗:计算和存储资源的优化

6.2 解决方案

class RAGOptimizationSuite:
"""RAG系统优化套件"""

def __init__(self):
self.optimization_strategies = {
'retrieval': self._get_retrieval_optimizations(),
'generation': self._get_generation_optimizations(),
'performance': self._get_performance_optimizations()
}

def _get_retrieval_optimizations(self) -> Dict[str, Any]:
"""检索优化策略"""
return {
'dense_sparse_hybrid': {
'description': '密集和稀疏检索的混合方法',
'implementation': self._implement_hybrid_retrieval
},
'query_expansion': {
'description': '查询扩展技术',
'implementation': self._implement_query_expansion
},
'reranking': {
'description': '检索结果重排序',
'implementation': self._implement_reranking
},
'multi_vector': {
'description': '多向量表示',
'implementation': self._implement_multi_vector
}
}

def _get_generation_optimizations(self) -> Dict[str, Any]:
"""生成优化策略"""
return {
'context_compression': {
'description': '上下文压缩技术',
'implementation': self._implement_context_compression
},
'iterative_refinement': {
'description': '迭代优化生成',
'implementation': self._implement_iterative_refinement
},
'fact_checking': {
'description': '事实核查机制',
'implementation': self._implement_fact_checking
},
'response_fusion': {
'description': '多候选回答融合',
'implementation': self._implement_response_fusion
}
}

def _get_performance_optimizations(self) -> Dict[str, Any]:
"""性能优化策略"""
return {
'caching': {
'description': '多层缓存策略',
'implementation': self._implement_caching
},
'async_processing': {
'description': '异步处理',
'implementation': self._implement_async_processing
},
'model_compression': {
'description': '模型压缩',
'implementation': self._implement_model_compression
},
'batch_processing': {
'description': '批处理优化',
'implementation': self._implement_batch_processing
}
}

def _implement_hybrid_retrieval(self) -> Any:
"""实现混合检索"""
class HybridDenseSparseRetriever(Retriever):
def __init__(self, dense_retriever, sparse_retriever, alpha=0.7):
self.dense_retriever = dense_retriever
self.sparse_retriever = sparse_retriever
self.alpha = alpha # 密集检索权重

def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
# 密集检索
dense_results = self.dense_retriever.retrieve(query, top_k * 2)

# 稀疏检索(如BM25)
sparse_results = self.sparse_retriever.retrieve(query, top_k * 2)

# 分数融合
combined_results = self._combine_scores(dense_results, sparse_results)

return combined_results[:top_k]

def _combine_scores(self, dense_results, sparse_results):
    # 加权融合:alpha 加权密集检索分数,(1 - alpha) 加权稀疏分数
    # 假设每个检索结果字典带有 'id' 与 'score' 字段(缺失时按0处理)
    combined = {}
    for weight, results in ((self.alpha, dense_results), (1 - self.alpha, sparse_results)):
        for doc in results:
            doc_id = doc.get('id', '')
            entry = combined.setdefault(doc_id, {'doc': doc, 'score': 0.0})
            entry['score'] += weight * doc.get('score', 0.0)
    ranked = sorted(combined.values(), key=lambda x: x['score'], reverse=True)
    return [item['doc'] for item in ranked]

return HybridDenseSparseRetriever

def _implement_context_compression(self) -> Any:
"""实现上下文压缩"""
class ContextCompressor:
def __init__(self, compression_ratio=0.5):
self.compression_ratio = compression_ratio

def compress_context(self, context: str, query: str) -> str:
# 提取关键句子
sentences = context.split('。')

# 计算句子与查询的相关性
sentence_scores = []
for sentence in sentences:
score = self._calculate_relevance(sentence, query)
sentence_scores.append((sentence, score))

# 选择高相关性句子
sorted_sentences = sorted(sentence_scores, key=lambda x: x[1], reverse=True)
num_keep = int(len(sentences) * self.compression_ratio)

compressed_sentences = [s[0] for s in sorted_sentences[:num_keep]]
return '。'.join(compressed_sentences)

def _calculate_relevance(self, sentence: str, query: str) -> float:
# 简单的相关性计算
sentence_words = set(sentence.lower().split())
query_words = set(query.lower().split())

if not query_words:
return 0

overlap = len(sentence_words.intersection(query_words))
return overlap / len(query_words)

return ContextCompressor

def _implement_fact_checking(self) -> Any:
"""实现事实核查"""
class FactChecker:
def __init__(self, knowledge_base):
self.knowledge_base = knowledge_base

def check_facts(self, generated_text: str, source_docs: List[Dict[str, Any]]) -> Dict[str, Any]:
# 提取声明
claims = self._extract_claims(generated_text)

# 验证每个声明
verification_results = []
for claim in claims:
verification = self._verify_claim(claim, source_docs)
verification_results.append({
'claim': claim,
'verified': verification['verified'],
'confidence': verification['confidence'],
'source': verification.get('source')
})

return {
'overall_accuracy': np.mean([r['confidence'] for r in verification_results]),
'claim_verifications': verification_results
}

def _extract_claims(self, text: str) -> List[str]:
# 简单的声明提取
sentences = [s.strip() for s in text.split('。') if s.strip()]
return sentences

def _verify_claim(self, claim: str, source_docs: List[Dict[str, Any]]) -> Dict[str, Any]:
# 在源文档中查找支持证据
best_match_score = 0
best_source = None

claim_words = set(claim.lower().split())

for doc in source_docs:
doc_words = set(doc.get('content', '').lower().split())
overlap = len(claim_words.intersection(doc_words))
score = overlap / len(claim_words) if claim_words else 0

if score > best_match_score:
best_match_score = score
best_source = doc.get('source', 'unknown')

return {
'verified': best_match_score > 0.3,
'confidence': best_match_score,
'source': best_source
}

return FactChecker

def _implement_caching(self) -> Any:
"""实现缓存策略"""
class MultiLevelCache:
def __init__(self):
self.query_cache = {} # 查询结果缓存
self.embedding_cache = {} # 嵌入向量缓存
self.retrieval_cache = {} # 检索结果缓存

# 缓存配置
self.cache_config = {
'query_ttl': 3600, # 1小时
'embedding_ttl': 86400, # 24小时
'retrieval_ttl': 1800, # 30分钟
'max_cache_size': 10000
}

def get_cached_result(self, cache_type: str, key: str) -> Any:
cache = getattr(self, f"{cache_type}_cache")

if key in cache:
entry = cache[key]
if self._is_cache_valid(entry):
return entry['data']
else:
del cache[key]

return None

def set_cache(self, cache_type: str, key: str, data: Any):
cache = getattr(self, f"{cache_type}_cache")

# 检查缓存大小限制
if len(cache) >= self.cache_config['max_cache_size']:
self._evict_oldest(cache)

cache[key] = {
'data': data,
'timestamp': time.time(),
'ttl': self.cache_config[f"{cache_type}_ttl"]
}

def _is_cache_valid(self, entry: Dict[str, Any]) -> bool:
return time.time() - entry['timestamp'] < entry['ttl']

def _evict_oldest(self, cache: Dict[str, Any]):
# 移除最旧的缓存项
oldest_key = min(cache.keys(), key=lambda k: cache[k]['timestamp'])
del cache[oldest_key]

return MultiLevelCache
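
上述优化套件中 query_expansion、reranking 等策略只给出了描述而未展开实现。下面补充查询扩展的一个最小示意:基于一个假设的同义词表在原查询后追加相关词;实际系统中可替换为词向量近邻、伪相关反馈或由大模型改写查询,中文查询通常还需先分词,这里为演示直接按空格切分:

class SimpleQueryExpander:
    """查询扩展的简化示意(同义词表为假设数据)"""

    def __init__(self, synonym_map: Dict[str, List[str]]):
        self.synonym_map = synonym_map

    def expand(self, query: str, max_extra_terms: int = 3) -> str:
        extra_terms = []
        for word in query.split():
            for syn in self.synonym_map.get(word, []):
                if syn not in query and syn not in extra_terms:
                    extra_terms.append(syn)
        extra_terms = extra_terms[:max_extra_terms]
        return f"{query} {' '.join(extra_terms)}" if extra_terms else query

# 使用示意
expander = SimpleQueryExpander({"退款": ["退货", "售后"]})
print(expander.expand("如何 申请 退款"))  # 输出: 如何 申请 退款 退货 售后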

7. 总结与展望

7.1 核心贡献

本文深入探讨了知识图谱与大语言模型结合的RAG系统设计与优化,主要贡献包括:

  1. 系统架构设计:提出了完整的RAG系统架构,包括检索器、生成器和知识图谱的有机结合
  2. 优化策略:详细阐述了检索优化、生成优化和系统性能优化的具体方法
  3. 实际应用:展示了RAG系统在智能客服、技术文档问答等领域的应用案例
  4. 评估体系:建立了全面的RAG系统评估和监控框架

7.2 技术发展趋势

  1. 多模态RAG:支持文本、图像、音频等多种模态的统一检索和生成
  2. 实时更新:知识库的实时更新和增量学习能力
  3. 个性化定制:基于用户偏好和历史的个性化RAG系统
  4. 跨语言支持:多语言知识库的统一检索和跨语言生成

7.3 应用前景

RAG系统在以下领域具有广阔的应用前景:

  • 企业知识管理:构建智能化的企业知识库和问答系统
  • 教育培训:个性化的学习助手和智能答疑系统
  • 医疗健康:基于医学知识库的诊断辅助系统
  • 法律服务:法律条文检索和案例分析系统
  • 科研助手:学术文献检索和研究问题解答

7.4 未来挑战

  1. 知识一致性:确保检索知识与生成内容的一致性
  2. 可解释性:提高系统决策过程的透明度和可解释性
  3. 安全性:防范恶意查询和知识污染攻击
  4. 伦理考量:处理偏见、隐私和公平性问题

通过持续的技术创新和优化,RAG系统将在构建更加智能、可靠和有用的AI应用方面发挥重要作用,为人工智能技术的实际应用提供强有力的支撑。
