First, a few screenshots showing why I built this:

aws-ddos-attack-waf.png
aws-ddos-attack-waf-traffic.png
aws-ddos-attack-bill.png

I. Prerequisites

The setup uses the following services:

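  1. AWS WAF: a global rate limit, a global blacklist, and logging enabled
  2. AWS CloudWatch: stores the logs, detects attacks, and sends alarms to SNS
  3. AWS SNS: triggers the Lambda functions
  4. AWS Lambda: sends notifications and updates the AWS WAF global blacklist
  5. Amazon EventBridge: runs a Lambda function on a schedule to update the WAF blacklist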

1.1 AWS WAF configuration

Configure the following in the WAF web ACL:

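  1. Create an IP set to serve as the global blacklist
  2. Create a global blacklist rule that blocks any IP in that IP set, and give it the highest priority (the lowest priority number)
  3. Create a global rate-limit rule sized to your traffic, with a priority just below the blacklist rule
  4. Enable WAF logging with CloudWatch Logs as the logging destination (the log group name must start with aws-waf-logs-), and add a filter that keeps a request in the logs only when the rule action on the request is BLOCK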
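The two rules and the logging filter from the list above can be expressed as plain data in the shape the boto3 `wafv2` client expects. This is a minimal sketch, assuming a rate limit of 2000 requests per IP; the rule names, metric names and helper function names are hypothetical, and only the dictionary structure follows the WAFV2 API.

```python
def build_waf_rules(blacklist_ip_set_arn, rate_limit=2000):
    """Return the blacklist rule and the rate-limit rule as wafv2 Rule dicts.
    A lower Priority number is evaluated first."""

    def visibility(metric_name):
        # Metrics are needed later for the CloudWatch alarm
        return {
            "SampledRequestsEnabled": True,
            "CloudWatchMetricsEnabled": True,
            "MetricName": metric_name,
        }

    blacklist_rule = {
        "Name": "Global-Blacklist",  # hypothetical name
        "Priority": 0,  # evaluated before every other rule
        "Statement": {"IPSetReferenceStatement": {"ARN": blacklist_ip_set_arn}},
        "Action": {"Block": {}},
        "VisibilityConfig": visibility("GlobalBlacklist"),
    }
    rate_rule = {
        "Name": "Global-RateLimit",  # hypothetical name
        "Priority": 1,  # right after the blacklist
        "Statement": {
            "RateBasedStatement": {"Limit": rate_limit, "AggregateKeyType": "IP"}
        },
        "Action": {"Block": {}},
        "VisibilityConfig": visibility("GlobalRateLimit"),
    }
    return [blacklist_rule, rate_rule]


def build_logging_filter():
    """Keep only requests whose action is BLOCK; drop everything else."""
    return {
        "Filters": [
            {
                "Behavior": "KEEP",
                "Requirement": "MEETS_ANY",
                "Conditions": [{"ActionCondition": {"Action": "BLOCK"}}],
            }
        ],
        "DefaultBehavior": "DROP",
    }
```

The rules would be passed as `Rules=` to `update_web_acl`, which replaces the web ACL's entire rule list, so any existing rules must be included too. The filter goes into `put_logging_configuration` as `LoggingConfiguration["LoggingFilter"]`, next to `ResourceArn` and `LogDestinationConfigs`.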

II. Attack Detection and Alerting

2.1 CloudWatch configuration

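The CloudWatch alarm is based on the rate-limit rule in the WAF web ACL. Create an alarm whose metric is the Sum of BlockedRequests (namespace AWS/WAFV2) over a 5-minute period. If more than 3,000 requests are blocked within 5 minutes, the alarm enters the In alarm state and publishes to the SNS Alert topic; when it returns to the OK state, it publishes to the SNS OK topic.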
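The alarm described above maps onto the keyword arguments of CloudWatch's `put_metric_alarm`. This sketch assumes a CloudFront web ACL, whose WAF metrics are reported in us-east-1 with `WebACL` and `Rule` dimensions, and assumes `rule_metric_name` is the `MetricName` set in the rate-limit rule's visibility config. The alarm name and the helper name are hypothetical.

```python
def build_blocked_requests_alarm(web_acl_name, rule_metric_name,
                                 alert_topic_arn, ok_topic_arn,
                                 threshold=3000.0):
    """Keyword arguments for CloudWatch put_metric_alarm: Sum of BlockedRequests
    over one 5-minute period, ALARM when it exceeds the threshold."""
    return {
        "AlarmName": f"{web_acl_name}-ddos-blocked-requests",  # hypothetical name
        "Namespace": "AWS/WAFV2",
        "MetricName": "BlockedRequests",
        "Dimensions": [
            {"Name": "WebACL", "Value": web_acl_name},
            {"Name": "Rule", "Value": rule_metric_name},
        ],
        "Statistic": "Sum",
        "Period": 300,  # 5 minutes
        "EvaluationPeriods": 1,
        "Threshold": threshold,
        "ComparisonOperator": "GreaterThanThreshold",
        "TreatMissingData": "notBreaching",  # no blocked requests means OK
        "AlarmActions": [alert_topic_arn],  # -> WAF-Ddos-Alert
        "OKActions": [ok_topic_arn],  # -> WAF-Ddos-Ok
    }
```

Usage would be `boto3.client("cloudwatch", region_name="us-east-1").put_metric_alarm(**kwargs)`. `TreatMissingData="notBreaching"` is a judgment call so that quiet periods without any blocked requests do not leave the alarm in INSUFFICIENT_DATA.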

2.2 Amazon SNS configuration

Create two Standard SNS topics, WAF-Ddos-Alert and WAF-Ddos-Ok. They trigger the Lambda functions that send the attack alert and the attack-ended notification.
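When CloudWatch publishes an alarm to SNS, the Lambda event carries the alarm as a JSON string in `Records[0].Sns.Message`, with fields such as `AlarmName` and `NewStateValue` (`ALARM` or `OK`). A minimal sketch of reading it, using a hypothetical helper name; a single function subscribed to both topics could branch on the returned state instead of relying on which topic fired.

```python
import json


def parse_alarm_event(event):
    """Return (alarm_name, new_state) from an SNS-triggered Lambda event that
    carries a CloudWatch alarm notification."""
    record = event["Records"][0]
    # The SNS message body is the alarm notification serialized as JSON
    message = json.loads(record["Sns"]["Message"])
    return message["AlarmName"], message["NewStateValue"]
```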

2.3 Lambda function for attack detection and alerting

Create a Python Lambda function with the code below, and configure it as follows:

  1. Set the trigger to the SNS topic WAF-Ddos-Alert
  2. Attach the CloudWatchReadOnlyAccess and CloudWatchLogsReadOnlyAccess policies to the Lambda execution role
  3. Set the Telegram environment variables BOT_TOKEN (the bot token) and USER_WILL (the chat ID)
  4. Set the timeout to 5 minutes

Code:

import os
import json
import urllib3
import boto3
from datetime import datetime, timedelta
import time

def lambda_handler(event, context):
    
    # Read the Telegram bot token and chat ID from environment variables
    bot_token = os.environ['BOT_TOKEN']
    user_will_chat_id = os.environ['USER_WILL']
    
    send_telegram_message(bot_token, user_will_chat_id, "Warning: DDoS attack detected!!!")
    
    # Change to match your setup
    log_group_name = 'aws-waf-logs-All-Blocked-Logs'
    
    # Current time and the start of the 15-minute lookback window
    current_time = datetime.now()
    period = current_time - timedelta(minutes=15)
    
    # Create the CloudWatch Logs client (used for Logs Insights queries)
    logs_client = boto3.client('logs')
    queries = [
        {
            'name': 'Top5_Host',
            'query': "fields @timestamp, @message | parse @message '{\"name\":\"Host\",\"value\":\"*\"}' as host | stats count(*) as requestCount by host | sort requestCount desc | limit 5"
        },
        {
            'name': 'Top5_IP',
            'query': "fields httpRequest.clientIp | stats count(*) as requestCount by httpRequest.clientIp | sort requestCount desc | limit 5"
        },
        {
            'name': 'Top5_Country',
            'query': "fields httpRequest.country | stats count(*) as requestCount by httpRequest.country | sort requestCount desc | limit 5"
        },
        {
            'name': 'Top5_Rule',
            'query': "fields terminatingRuleId | stats count(*) as requestCount by terminatingRuleId | sort requestCount desc | limit 5"
        }
    ]

    msg = '\n'
    for query in queries:
        response = logs_client.start_query(
            logGroupName=log_group_name,
            startTime=int(period.timestamp()),
            endTime=int(current_time.timestamp()),
            queryString=query['query']
        )
        query_id = response['queryId']
        print(f"Started query {query['name']}. Query ID: {query_id}")
    
        # Wait for the query to finish. Check the status on every poll: a query
        # that completes with zero rows must still count as finished.
        retries = 0
        max_retries = 5
        while True:
            query_response = logs_client.get_query_results(queryId=query_id)
            status = query_response['status']
            print(f"Query Status: {status}")
            if status == 'Complete':
                break
            if status in ('Failed', 'Cancelled', 'Timeout'):
                raise Exception(f"Query {query['name']} ended with status {status}")
            retries += 1
            if retries >= max_retries:
                raise Exception("Query did not complete within the allowed time")
            print(f"Retrying {retries}...")
            time.sleep(10)
                
        msg += f"{query['name']}: \n"
        for result in query_response['results']:
            request_item = [item['value'] for item in result ][0]
            request_count = [item['value'] for item in result if item['field'] == 'requestCount'][0]
            
            msg += f"{request_item} --> {request_count}\n"


    print(msg)
    
    send_telegram_message(bot_token,user_will_chat_id, msg)
    
def send_telegram_message(bot_token, chat_id, message):
    bot_api_url = f"https://api.telegram.org/bot{bot_token}/sendMessage"
    http = urllib3.PoolManager()

    # Build the request payload
    data = {
        'chat_id': chat_id,
        'text': message
    }

    # Send a POST request to the Telegram Bot API
    msg_response = http.request('POST', bot_api_url, fields=data)
    
    # Check the response to see whether the message was sent
    if json.loads(msg_response.data.decode('utf-8'))['ok']:
        print(f"Message sent to chat {chat_id}: {message}")
    else:
        print(f"Failed to send message to chat {chat_id}")

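When this function receives the SNS notification, it first sends an attack alert. It then queries the blocked requests logged in CloudWatch over the last 15 minutes for the top 5 hosts, top 5 client IPs, top 5 countries, and top 5 matched WAF rules, and finally sends those results as a second message.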
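The result-formatting loop at the end of the function above can also be written as a standalone function that looks fields up by name instead of by position. `get_query_results` returns each row as a list of `{"field": ..., "value": ...}` dicts. This is a sketch with a hypothetical helper name; it assumes each row has exactly one group-by field besides `requestCount`, and skips any field starting with `@` defensively.

```python
def format_top_n(name, results):
    """Render one get_query_results 'results' list as 'value --> count' lines."""
    lines = [f"{name}:"]
    for row in results:
        # Map field name -> value, ignoring internal fields such as @ptr
        fields = {item["field"]: item["value"]
                  for item in row if not item["field"].startswith("@")}
        count = fields.pop("requestCount", "?")
        # The remaining field is the group-by key (host, IP, country or rule ID)
        key = next(iter(fields.values()), "-")
        lines.append(f"{key} --> {count}")
    return "\n".join(lines)
```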

III. Attack Mitigation

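With the above in place, mitigating an attack only requires counting the attacking IPs and adding them to the WAF blacklist. To pick up attacking IPs quickly, set up a schedule in Amazon EventBridge that triggers a Lambda function every five minutes to update the WAF blacklist.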
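The counting step boils down to a small selection, sketched here as a pure function. It assumes the query rows have already been reduced to `(ip, count)` pairs; the function name `to_blacklist_cidrs` is hypothetical. It applies the same `> 100` threshold as the query in 3.2 and separates IPv4 from IPv6, because a WAF IP set holds a single address version and IPv6 hosts need a /128 suffix rather than /32.

```python
import ipaddress


def to_blacklist_cidrs(rows, threshold=100):
    """From (ip, blocked_count) pairs, return sorted, de-duplicated CIDRs for
    IPs blocked more than `threshold` times, as (ipv4_list, ipv6_list)."""
    v4, v6 = set(), set()
    for ip, count in rows:
        if int(count) <= threshold:
            continue
        addr = ipaddress.ip_address(ip)  # raises ValueError on malformed input
        if addr.version == 4:
            v4.add(f"{addr}/32")
        else:
            v6.add(f"{addr}/128")
    return sorted(v4), sorted(v6)
```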

3.1 EventBridge Scheduler setup

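In EventBridge -> Scheduler, create a new schedule with the occurrence set to Recurring schedule and the schedule type set to Rate-based schedule with a rate of 5 minutes, and set the target to the Lambda function described below.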
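The same schedule can be created through the EventBridge Scheduler API. This sketch builds the keyword arguments for `create_schedule`; it assumes you already have the Lambda function's ARN and an IAM role that EventBridge Scheduler can assume with `lambda:InvokeFunction` permission (the console normally creates that role for you). The schedule name and helper name are hypothetical.

```python
def build_schedule(lambda_arn, scheduler_role_arn, name="waf-blacklist-updater"):
    """Keyword arguments for EventBridge Scheduler create_schedule: invoke the
    Lambda function every 5 minutes."""
    return {
        "Name": name,  # hypothetical schedule name
        "ScheduleExpression": "rate(5 minutes)",
        "FlexibleTimeWindow": {"Mode": "OFF"},  # run exactly on the rate
        "Target": {"Arn": lambda_arn, "RoleArn": scheduler_role_arn},
        "State": "ENABLED",
    }
```

Usage would be `boto3.client("scheduler").create_schedule(**build_schedule(fn_arn, role_arn))`.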

3.2 Updating the WAF blacklist with Lambda

Create another Lambda function with the following settings:

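  1. The trigger can be left unset; if you want to trigger blacklisting manually, you can add a new SNS topic as a trigger
  2. Attach the CloudWatchReadOnlyAccess, CloudWatchLogsReadOnlyAccess, and AWSWAFFullAccess policies to the Lambda execution role
  3. Set the timeout to 5 minutes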
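If AWSWAFFullAccess feels too broad, a narrower inline policy covering only the calls this function makes is one option. This is a sketch under assumptions: the ARNs are supplied by you, `logs:GetQueryResults` and `wafv2:ListIPSets` are granted on `*` because they are not scoped to a single log group or IP set, and the exact resource formats should be checked against the IAM service authorization reference. The function still needs its usual basic execution permissions to write its own logs.

```python
import json


def build_blacklist_updater_policy(log_group_arn, ip_set_arn):
    """Return an IAM policy document (JSON string) limited to the Logs Insights
    and WAF IP-set calls made by the blacklist Lambda."""
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            # Run Logs Insights queries against the WAF log group only
            {"Effect": "Allow", "Action": ["logs:StartQuery"],
             "Resource": log_group_arn},
            # Query results are fetched by query ID, not by log group
            {"Effect": "Allow", "Action": ["logs:GetQueryResults"],
             "Resource": "*"},
            {"Effect": "Allow", "Action": ["wafv2:ListIPSets"],
             "Resource": "*"},
            # Read and update only the blacklist IP set
            {"Effect": "Allow",
             "Action": ["wafv2:GetIPSet", "wafv2:UpdateIPSet"],
             "Resource": ip_set_arn},
        ],
    }
    return json.dumps(policy)
```

The returned string fits the `PolicyDocument` parameter of IAM `put_role_policy`.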

Lambda code:

import boto3
from datetime import datetime, timedelta
import time

def lambda_handler(event, context):
    log_group_name = 'aws-waf-logs-All-Blocked-Logs'
    ip_set_name = 'Global_blacklist-auto'
    
    # Current time and the start of the 10-minute attack-counting window
    current_time = datetime.now()
    period = current_time - timedelta(minutes=10)
    
    # Create the CloudWatch Logs client (used for Logs Insights queries)
    logs_client = boto3.client('logs')
    
    print("Step 1: Executing query to filter blocked IP addresses...")
    
    # Query for IP addresses blocked more than 100 times within the window
    query = "fields httpRequest.clientIp | stats count(*) as requestCount by httpRequest.clientIp | filter requestCount > 100 | sort requestCount desc"

    response = logs_client.start_query(
        logGroupName=log_group_name,
        startTime=int(period.timestamp()),
        endTime=int(current_time.timestamp()),
        queryString=query
    )
    query_id = response['queryId']
    print(f"Started with Query ID: {query_id}")
    
    print("Step 2: Waiting for query to complete...")
    
    # Wait for the query to finish. Check the status on every poll: when no IP
    # exceeds the threshold, the query completes with zero rows, which must not
    # be treated as a timeout.
    retries = 0
    max_retries = 10
    while True:
        query_status = logs_client.get_query_results(queryId=query_id)
        status = query_status['status']
        print(f"Query Status: {status}")
        if status == 'Complete':
            break
        if status in ('Failed', 'Cancelled', 'Timeout'):
            raise Exception(f"Query ended with status {status}")
        retries += 1
        if retries >= max_retries:
            raise Exception("Query did not complete within the allowed time")
        print(f"Retrying {retries}...")
        time.sleep(10)
            
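    print("Step 3: Getting all blocked IPs from CloudWatch...")
    # Extract the client IP from each result row, looking the field up by name
    ips = [
        item['value']
        for result in query_status['results']
        for item in result
        if item['field'] == 'httpRequest.clientIp'
    ]
    
    # De-duplicate the IP addresses
    ips = list(set(ips))
    print(f"total ips: {len(ips)}")
    print(ips)
    if not ips:
        print("No IP exceeded the threshold; the IP set is left unchanged.")
        return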
    
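    print("Step 4: Getting IP Addresses from ip_set in CloudFront WAF...")
    # CLOUDFRONT-scoped WAF resources must be accessed through us-east-1
    wafv2_client = boto3.client('wafv2', region_name='us-east-1')
    response = wafv2_client.list_ip_sets(
        Scope='CLOUDFRONT'
    )
    ip_sets = response['IPSets']
    
    ip_set_id = None

    for ip_set in ip_sets:
        if ip_set['Name'] == ip_set_name:
            ip_set_id = ip_set['Id']
            break
    if ip_set_id is None:
        raise Exception(f"IP set {ip_set_name} not found")
    
    # Fetch the IP set's current addresses; get_ip_set also returns the latest
    # LockToken, which update_ip_set requires
    response = wafv2_client.get_ip_set(
        Name=ip_set_name,
        Id=ip_set_id,
        Scope='CLOUDFRONT'
    )
    existing_ip_addresses = response['IPSet']['Addresses']
    ip_version = response['IPSet']['IPAddressVersion']
    lock_token = response['LockToken']
    print(f"total existing ips: {len(existing_ip_addresses)}")
    print(existing_ip_addresses)
    
    print(f"Step 5: Updating IP Set {ip_set_name} in CloudFront WAF...")
    
    # Convert the new IPs to CIDR notation (/32 for IPv4, /128 for IPv6), keep
    # only those matching the IP set's address version, and merge them in
    cidr_addresses = []
    for ip in ips:
        is_ipv6 = ':' in ip
        if is_ipv6 and ip_version == 'IPV6':
            cidr_addresses.append(ip + '/128')
        elif not is_ipv6 and ip_version == 'IPV4':
            cidr_addresses.append(ip + '/32')
    merged_ip_addresses = list(set(existing_ip_addresses + cidr_addresses))
    print(merged_ip_addresses)
    
    response = wafv2_client.update_ip_set(
        Name=ip_set_name,
        Scope='CLOUDFRONT',
        Id=ip_set_id,
        LockToken=lock_token,
        Description='IP blacklist for Global WAF',
        Addresses=merged_ip_addresses
    )
    print(f"total blocked ips: {len(merged_ip_addresses)} ")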
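Because the blacklist only ever grows, it will eventually reach the per-IP-set address quota (10,000 addresses by default; check the current AWS WAF quotas for your account), at which point `update_ip_set` fails. A hypothetical helper like the one below could replace the plain set union: it keeps existing entries first, appends new ones in order, and reports how many new entries did not fit.

```python
def merge_ip_set(existing, new, max_size=10000):
    """Merge new CIDRs into an IP set's existing addresses without exceeding
    max_size entries. Returns (merged_list, number_of_new_entries_dropped)."""
    merged = list(dict.fromkeys(existing))  # de-duplicate, keep order
    seen = set(merged)
    dropped = 0
    for cidr in new:
        if cidr in seen:
            continue
        seen.add(cidr)
        if len(merged) >= max_size:
            dropped += 1  # no room left in the IP set
        else:
            merged.append(cidr)
    return merged, dropped
```

A non-zero `dropped` count is a signal to prune old entries or spread addresses across a second IP set.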

3.3 Other settings

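Further measures include refining the WAF rules, extending the Lambda scripts, and splitting traffic at the DNS level; there is a lot more you can configure depending on your situation.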

IV. Alert Notifications and Results

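During an attack, keep an eye on WAF, CloudFront, and CloudWatch costs, and be prepared for the bill.
Both Lambda functions write detailed logs to their CloudWatch log groups, which you can review in CloudWatch.
The alert notification looks like this: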
aws-ddos-attack-tg-msg.png