Apache Airflow - 工作流调度平台

项目简介

Apache Airflow是一个开源的工作流调度和监控平台,最初由Airbnb开发,后来成为Apache软件基金会的顶级项目。Airflow使用Python编写,允许用户以代码形式定义工作流,这些工作流以有向无环图(DAG)的形式组织任务。

Airflow的核心理念是”工作流即代码”,通过Python脚本定义任务依赖关系、调度规则和执行逻辑。这种方式使得工作流的版本控制、测试和维护变得更加容易。

主要特性

  • 动态工作流:通过Python代码动态生成工作流
  • 可扩展性:支持插件和自定义操作符
  • 丰富的UI:提供直观的Web界面管理和监控
  • 强大的调度:支持复杂的时间调度和依赖关系
  • 可靠性:任务失败重试和告警机制
  • 多种执行器:支持本地、集群和云端执行

项目原理

核心概念

DAG(有向无环图)

  • 工作流的定义,描述任务及其依赖关系
  • 包含调度信息和执行参数
  • Python文件形式存储

Task(任务)

  • DAG中的基本执行单元
  • 由Operator定义具体操作
  • 可以是Python函数、Bash命令、SQL查询等

Operator(操作符)

  • 任务的模板,定义要执行的操作
  • 内置多种操作符:BashOperator、PythonOperator、SQLOperator等
  • 支持自定义操作符

Scheduler(调度器)

  • 解析DAG文件,创建任务实例
  • 根据调度规则触发任务执行
  • 管理任务状态和依赖关系

架构组件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Airflow架构
├── Web Server (Web服务器)
│ ├── 用户界面
│ ├── REST API
│ └── 认证授权
├── Scheduler (调度器)
│ ├── DAG解析
│ ├── 任务调度
│ └── 状态管理
├── Executor (执行器)
│ ├── LocalExecutor
│ ├── CeleryExecutor
│ └── KubernetesExecutor
├── Metadata Database
│ ├── DAG定义
│ ├── 任务状态
│ └── 执行历史
└── Worker Nodes
├── 任务执行
└── 结果返回

使用场景

1. 数据管道ETL

构建复杂的数据提取、转换和加载流程。

2. 机器学习流水线

协调模型训练、验证和部署的完整流程。

3. 业务流程自动化

自动化日常业务操作和报告生成。

4. 系统监控和维护

定期执行系统检查、备份和清理任务。

5. 多系统集成

协调不同系统间的数据同步和业务流程。

具体案例

案例1:基本DAG定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.operators.email import EmailOperator

# 默认参数
default_args = {
'owner': 'data_team',
'depends_on_past': False,
'start_date': datetime(2023, 1, 1),
'email_on_failure': True,
'email_on_retry': False,
'retries': 1,
'retry_delay': timedelta(minutes=5),
'email': ['admin@example.com']
}

# 创建DAG
dag = DAG(
'daily_data_pipeline',
default_args=default_args,
description='每日数据处理流水线',
schedule_interval=timedelta(days=1), # 每天执行一次
catchup=False, # 不执行历史任务
tags=['数据处理', 'ETL']
)

# Python函数任务
def extract_data(**context):
"""数据提取函数"""
import pandas as pd
from sqlalchemy import create_engine

# 连接数据库
engine = create_engine('postgresql://user:pass@localhost/db')

# 提取数据
query = """
SELECT * FROM sales
WHERE DATE(created_at) = '{{ ds }}'
"""

df = pd.read_csv(query)

# 保存到临时文件
output_path = f"/tmp/sales_data_{{ ds }}.csv"
df.to_csv(output_path, index=False)

print(f"提取了 {len(df)} 条记录到 {output_path}")
return output_path

def transform_data(**context):
"""数据转换函数"""
import pandas as pd

# 从上游任务获取文件路径
input_path = context['task_instance'].xcom_pull(task_ids='extract_data')

# 读取数据
df = pd.read_csv(input_path)

# 数据清洗和转换
df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
df = df.dropna()
df['date'] = pd.to_datetime(df['date'])

# 数据聚合
daily_summary = df.groupby('category').agg({
'amount': ['sum', 'mean', 'count'],
'date': 'first'
}).round(2)

# 保存处理后的数据
output_path = f"/tmp/processed_sales_{{ ds }}.csv"
daily_summary.to_csv(output_path)

print(f"数据转换完成,保存到 {output_path}")
return output_path

def load_data(**context):
"""数据加载函数"""
import pandas as pd
from sqlalchemy import create_engine

# 获取处理后的数据
input_path = context['task_instance'].xcom_pull(task_ids='transform_data')
df = pd.read_csv(input_path)

# 连接目标数据库
engine = create_engine('postgresql://user:pass@localhost/warehouse')

# 加载数据
df.to_sql('daily_sales_summary', engine, if_exists='append', index=False)

print(f"成功加载 {len(df)} 条记录到数据仓库")

# 定义任务
extract_task = PythonOperator(
task_id='extract_data',
python_callable=extract_data,
dag=dag
)

transform_task = PythonOperator(
task_id='transform_data',
python_callable=transform_data,
dag=dag
)

load_task = PythonOperator(
task_id='load_data',
python_callable=load_data,
dag=dag
)

# 数据质量检查
quality_check = BashOperator(
task_id='quality_check',
bash_command="""
python /scripts/data_quality_check.py \
--date {{ ds }} \
--table daily_sales_summary
""",
dag=dag
)

# 生成报告
generate_report = BashOperator(
task_id='generate_report',
bash_command="""
python /scripts/generate_daily_report.py \
--date {{ ds }} \
--output /reports/daily_report_{{ ds }}.pdf
""",
dag=dag
)

# 发送邮件通知
send_notification = EmailOperator(
task_id='send_notification',
to=['team@example.com'],
subject='每日数据处理完成 - {{ ds }}',
html_content="""
<h3>每日数据处理完成</h3>
<p>处理日期: {{ ds }}</p>
<p>状态: 成功</p>
<p>报告已生成并发送。</p>
""",
dag=dag
)

# 设置任务依赖关系
extract_task >> transform_task >> load_task >> quality_check >> generate_report >> send_notification

案例2:机器学习流水线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.dummy import DummyOperator
from datetime import datetime, timedelta

def prepare_data(**context):
"""数据准备"""
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib

# 加载数据
df = pd.read_csv('/data/customer_data.csv')

# 特征工程
features = ['age', 'income', 'spending_score', 'membership_years']
target = 'customer_segment'

X = df[features]
y = df[target]

# 数据分割
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)

# 特征缩放
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# 保存预处理后的数据
joblib.dump(X_train_scaled, '/tmp/X_train.pkl')
joblib.dump(X_test_scaled, '/tmp/X_test.pkl')
joblib.dump(y_train, '/tmp/y_train.pkl')
joblib.dump(y_test, '/tmp/y_test.pkl')
joblib.dump(scaler, '/tmp/scaler.pkl')

print(f"数据准备完成: 训练集 {X_train_scaled.shape}, 测试集 {X_test_scaled.shape}")

def train_model(**context):
"""模型训练"""
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report

# 加载数据
X_train = joblib.load('/tmp/X_train.pkl')
y_train = joblib.load('/tmp/y_train.pkl')

# 训练模型
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# 保存模型
model_path = f'/models/customer_segment_model_{{ ds }}.pkl'
joblib.dump(model, model_path)

print(f"模型训练完成,保存到: {model_path}")
return model_path

def evaluate_model(**context):
"""模型评估"""
import joblib
from sklearn.metrics import accuracy_score, classification_report
import json

# 加载模型和测试数据
model_path = context['task_instance'].xcom_pull(task_ids='train_model')
model = joblib.load(model_path)

X_test = joblib.load('/tmp/X_test.pkl')
y_test = joblib.load('/tmp/y_test.pkl')

# 预测和评估
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred, output_dict=True)

# 保存评估结果
evaluation_result = {
'accuracy': accuracy,
'classification_report': report,
'model_path': model_path
}

with open(f'/tmp/evaluation_{{ ds }}.json', 'w') as f:
json.dump(evaluation_result, f, indent=2)

print(f"模型准确率: {accuracy:.4f}")

# 返回评估结果用于决策
return accuracy

def check_model_performance(**context):
"""检查模型性能"""
accuracy = context['task_instance'].xcom_pull(task_ids='evaluate_model')

# 设置性能阈值
threshold = 0.85

if accuracy >= threshold:
return 'deploy_model'
else:
return 'retrain_required'

def deploy_model(**context):
"""部署模型"""
import shutil

model_path = context['task_instance'].xcom_pull(task_ids='train_model')

# 复制模型到生产环境
production_path = '/production/models/customer_segment_model.pkl'
shutil.copy2(model_path, production_path)

# 更新模型版本信息
version_info = {
'model_path': production_path,
'deployment_date': '{{ ds }}',
'accuracy': context['task_instance'].xcom_pull(task_ids='evaluate_model')
}

with open('/production/models/version_info.json', 'w') as f:
json.dump(version_info, f, indent=2)

print(f"模型已部署到生产环境: {production_path}")

# 创建ML DAG
ml_dag = DAG(
'ml_pipeline',
default_args=default_args,
description='机器学习模型训练和部署流水线',
schedule_interval='@weekly',
catchup=False
)

# 定义任务
prepare_task = PythonOperator(
task_id='prepare_data',
python_callable=prepare_data,
dag=ml_dag
)

train_task = PythonOperator(
task_id='train_model',
python_callable=train_model,
dag=ml_dag
)

evaluate_task = PythonOperator(
task_id='evaluate_model',
python_callable=evaluate_model,
dag=ml_dag
)

check_performance = BranchPythonOperator(
task_id='check_model_performance',
python_callable=check_model_performance,
dag=ml_dag
)

deploy_task = PythonOperator(
task_id='deploy_model',
python_callable=deploy_model,
dag=ml_dag
)

retrain_task = DummyOperator(
task_id='retrain_required',
dag=ml_dag
)

# 设置依赖关系
prepare_task >> train_task >> evaluate_task >> check_performance
check_performance >> deploy_task
check_performance >> retrain_task

案例3:自定义操作符

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
import requests
import json

class SlackNotificationOperator(BaseOperator):
"""自定义Slack通知操作符"""

@apply_defaults
def __init__(self,
slack_webhook_url,
message,
channel='#general',
username='Airflow',
*args, **kwargs):
super().__init__(*args, **kwargs)
self.slack_webhook_url = slack_webhook_url
self.message = message
self.channel = channel
self.username = username

def execute(self, context):
# 构建Slack消息
slack_message = {
'channel': self.channel,
'username': self.username,
'text': self.message,
'attachments': [
{
'color': 'good' if context.get('task_instance').state == 'success' else 'danger',
'fields': [
{
'title': 'DAG',
'value': context['dag'].dag_id,
'short': True
},
{
'title': 'Task',
'value': context['task'].task_id,
'short': True
},
{
'title': 'Execution Date',
'value': str(context['execution_date']),
'short': True
}
]
}
]
}

# 发送到Slack
response = requests.post(
self.slack_webhook_url,
data=json.dumps(slack_message),
headers={'Content-Type': 'application/json'}
)

if response.status_code != 200:
raise Exception(f"Slack通知失败: {response.text}")

self.log.info("Slack通知发送成功")

class DatabaseBackupOperator(BaseOperator):
"""数据库备份操作符"""

@apply_defaults
def __init__(self,
connection_id,
backup_path,
tables=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.connection_id = connection_id
self.backup_path = backup_path
self.tables = tables or []

def execute(self, context):
from airflow.hooks.postgres_hook import PostgresHook
import subprocess
import os

# 获取数据库连接
hook = PostgresHook(postgres_conn_id=self.connection_id)
connection = hook.get_connection(self.connection_id)

# 构建备份命令
backup_file = f"{self.backup_path}/backup_{context['ds']}.sql"

cmd = [
'pg_dump',
'-h', connection.host,
'-p', str(connection.port),
'-U', connection.login,
'-d', connection.schema,
'-f', backup_file,
'--verbose'
]

# 如果指定了表,只备份这些表
if self.tables:
for table in self.tables:
cmd.extend(['-t', table])

# 设置密码环境变量
env = os.environ.copy()
env['PGPASSWORD'] = connection.password

# 执行备份
result = subprocess.run(cmd, env=env, capture_output=True, text=True)

if result.returncode != 0:
raise Exception(f"数据库备份失败: {result.stderr}")

self.log.info(f"数据库备份成功: {backup_file}")
return backup_file

# 使用自定义操作符
backup_task = DatabaseBackupOperator(
task_id='backup_database',
connection_id='postgres_default',
backup_path='/backups',
tables=['users', 'orders', 'products'],
dag=dag
)

notify_task = SlackNotificationOperator(
task_id='notify_completion',
slack_webhook_url='https://hooks.slack.com/services/YOUR/SLACK/WEBHOOK',
message='数据处理流水线执行完成!',
channel='#data-team',
dag=dag
)

案例4:动态DAG生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import yaml

# 读取配置文件
with open('/config/data_sources.yaml', 'r') as f:
data_sources = yaml.safe_load(f)

def create_etl_dag(source_config):
"""为每个数据源创建ETL DAG"""

dag_id = f"etl_{source_config['name']}"

dag = DAG(
dag_id,
default_args={
'owner': 'data_team',
'start_date': datetime(2023, 1, 1),
'retries': 1,
'retry_delay': timedelta(minutes=5)
},
description=f"ETL pipeline for {source_config['name']}",
schedule_interval=source_config['schedule'],
catchup=False
)

def extract_data(**context):
# 根据配置提取数据
source_type = source_config['type']
if source_type == 'database':
# 数据库提取逻辑
pass
elif source_type == 'api':
# API提取逻辑
pass
elif source_type == 'file':
# 文件提取逻辑
pass

def transform_data(**context):
# 根据配置转换数据
transformations = source_config.get('transformations', [])
for transform in transformations:
# 执行转换逻辑
pass

def load_data(**context):
# 加载数据到目标系统
target = source_config['target']
# 执行加载逻辑
pass

# 创建任务
extract_task = PythonOperator(
task_id='extract',
python_callable=extract_data,
dag=dag
)

transform_task = PythonOperator(
task_id='transform',
python_callable=transform_data,
dag=dag
)

load_task = PythonOperator(
task_id='load',
python_callable=load_data,
dag=dag
)

# 设置依赖
extract_task >> transform_task >> load_task

return dag

# 为每个数据源创建DAG
for source in data_sources['sources']:
dag_id = f"etl_{source['name']}"
globals()[dag_id] = create_etl_dag(source)

最佳实践

1. DAG设计原则

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 使用合理的默认参数
default_args = {
'owner': 'team_name',
'depends_on_past': False,
'start_date': datetime(2023, 1, 1),
'retries': 2,
'retry_delay': timedelta(minutes=5),
'email_on_failure': True,
'email_on_retry': False
}

# 设置合理的调度间隔
dag = DAG(
'my_dag',
default_args=default_args,
schedule_interval='@daily', # 使用cron表达式或预定义间隔
catchup=False, # 避免回填历史任务
max_active_runs=1 # 限制并发运行数
)

2. 错误处理和监控

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def task_failure_callback(context):
"""任务失败回调函数"""
task_instance = context['task_instance']
dag_id = context['dag'].dag_id

# 发送告警
send_alert(f"任务失败: {dag_id}.{task_instance.task_id}")

# 记录详细错误信息
log_error(context)

# 在DAG中使用回调
dag = DAG(
'monitored_dag',
default_args=default_args,
on_failure_callback=task_failure_callback
)

3. 资源配置和优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 使用资源池限制并发
task = PythonOperator(
task_id='heavy_task',
python_callable=heavy_computation,
pool='cpu_intensive_pool',
dag=dag
)

# 设置任务超时
task = BashOperator(
task_id='bash_task',
bash_command='long_running_script.sh',
execution_timeout=timedelta(hours=2),
dag=dag
)

Apache Airflow作为现代化的工作流调度平台,其”工作流即代码”的理念和丰富的功能特性使其成为构建复杂数据管道和自动化流程的理想选择。通过合理的设计和配置,Airflow可以为企业提供可靠、可扩展的工作流调度服务。

版权所有,如有侵权请联系我