构建基于Nomad与ClickHouse的高吞吐可观测TensorFlow模型服务


最初的尝试是一场灾难。一个简单的 Flask 应用,通过 tensorflow.keras.models.load_model 加载模型,然后用 Gunicorn 部署了几个 worker。当流量只有每秒几次请求时,一切看起来都还不错。但当请求量上升到每秒数百次,问题就暴露无遗:延迟抖动剧烈,CPU 负载忽高忽低,最糟糕的是,我们对服务内部发生的事情几乎一无所知。日志是散乱的文本行,无法进行有效的聚合分析,我们甚至分不清是哪个版本的模型导致了性能下降。我们需要一个真正的生产级解决方案。

我们的目标很明确:构建一个高吞吞吐、低延迟、并且具备深度可观测性的模型推理服务。这意味着每一次预测请求——它的输入、模型的输出、执行耗时、所用模型版本——都必须被可靠、高效地记录下来,以便进行近乎实时的性能监控和业务分析。

技术选型决策

在复盘时,我们对整个技术栈进行了重新评估。

  1. 服务框架:Django vs. Flask
    虽然 Flask 轻量,但在生产环境中,我们需要更结构化的东西。Django 提供了强大的中间件系统、ORM(尽管我们在这里不会重度使用它连接主业务数据库)以及更成熟的生态用于异步任务处理。对于构建一个可观测的服务,Django 的中间件是实现请求级别日志记录的完美切入点。

  2. 调度与部署:Nomad vs. Kubernetes
    Kubernetes 是事实上的标准,但它的复杂性对于我们当前规模的团队而言是一种负担。Nomad 提供了一个更简单、更轻量级的替代方案。它是一个单一的二进制文件,易于部署和管理,并且其作业规约(HCL)直观清晰。对于部署无状态的 API 服务,Nomad 的能力绰绰有余,且运维成本远低于一个完整的 K8s 集群。

  3. 日志与分析存储:ClickHouse vs. Elasticsearch/PostgreSQL
    这是整个架构的核心决策。我们需要一个能承受极高写入吞吐量,并能对海量数据进行快速分析查询的数据库。

    • PostgreSQL: 优秀的事务型数据库,但不适合这种写入密集型的分析场景。高并发写入会迅速成为瓶颈。
    • Elasticsearch: 强大的全文搜索引擎,常用于日志存储。但在结构化数据的聚合分析性能上,尤其是在涉及大数据集的高基数维度查询时,通常不如列式数据库。
    • ClickHouse: 一个为在线分析处理(OLAP)而生的列式数据库管理系统。它的数据压缩率极高,写入速度惊人,聚合查询性能卓越。这正是我们需要的:一个专门用来存储和分析数以万亿计的模型推理日志的地方。

最终的技术栈组合拳是:TensorFlow 作为模型核心,Django 构筑 API 服务,Nomad 负责部署和生命周期管理,ClickHouse 作为可观测性数据的基石。

架构设计

整个系统的数据流非常清晰。

graph TD
    subgraph "Client Side"
        Client[Clients]
    end

    subgraph "Infrastructure"
        LB[Load Balancer]
    end

    subgraph "Nomad Cluster"
        Service[Django/TensorFlow Service]
    end
    
    subgraph "Data Pipeline"
        Queue[Async Task Queue]
        Writer[Log Batch Writer]
    end

    subgraph "Data Storage"
        CH[ClickHouse]
    end
    
    Client --> LB
    LB --> Service
    Service --"Log Payload (Non-blocking)"--> Queue
    Queue --> Writer
    Writer --"Bulk Insert"--> CH

关键在于服务层与数据记录层的解耦。Django 服务在处理完推理请求后,不是直接同步写入 ClickHouse,而是将日志载荷(包含请求元数据、输入、输出、耗时等)推送到一个异步任务队列中。一个独立的后台工作进程负责从队列中批量拉取日志,聚合成批次后一次性写入 ClickHouse,从而将对推理服务主流程的性能影响降至最低。

核心实现

1. ClickHouse 表结构设计

这是地基。一个好的表结构决定了查询性能的上限。我们为推理日志设计了如下的 MergeTree 表。

-- file: schema.sql
-- 在 ClickHouse 中执行
CREATE TABLE default.inference_logs (
    -- 时间与分区键
    `event_date` Date DEFAULT toDate(event_timestamp),
    `event_timestamp` DateTime64(3, 'Asia/Shanghai'),
    
    -- 请求元数据
    `request_id` UUID,
    `client_ip` IPv4,
    `endpoint` String,
    
    -- 模型元数据
    `model_name` String,
    `model_version` String,
    
    -- 性能指标
    `latency_ms` Float64,
    `cpu_usage_start` Float32,
    `cpu_usage_end` Float32,
    `memory_rss_mb` Float32,

    -- 业务数据 (输入与输出)
    -- 使用 JSON 字符串存储,灵活性高
    `request_body` String,
    `response_body` String,
    
    -- 状态
    `status_code` UInt16,
    `error_message` String
) ENGINE = MergeTree()
PARTITION BY toYYYYMM(event_date)
ORDER BY (model_name, model_version, event_timestamp)
SETTINGS index_granularity = 8192;

设计考量:

  • ENGINE = MergeTree(): ClickHouse 的核心引擎,适用于高负载任务。
  • PARTITION BY toYYYYMM(event_date): 按月分区。这使得删除旧数据或对特定月份进行查询极为高效。
  • ORDER BY (model_name, model_version, event_timestamp): 排序键是性能优化的关键。将查询中频繁作为过滤条件的字段放在前面(如模型名称和版本),ClickHouse 可以利用其稀疏索引快速跳过不相关的数据块。

2. Django 推理服务与异步日志中间件

我们将日志记录逻辑封装在一个 Django 中间件中,使其对业务视图代码完全透明。

predictor/middleware.py

import time
import uuid
import threading
from typing import Callable
from django.http import HttpRequest, JsonResponse

from . import logging_client # 我们稍后会实现这个

# 使用一个简单的线程本地存储来在请求生命周期内传递数据
_local = threading.local()

class ObservabilityMiddleware:
    def __init__(self, get_response: Callable):
        self.get_response = get_response

    def __call__(self, request: HttpRequest):
        # 请求处理前
        start_time = time.perf_counter()
        _local.request_id = uuid.uuid4()

        response = self.get_response(request)

        # 请求处理后
        latency_ms = (time.perf_counter() - start_time) * 1000
        
        # 从视图或模型逻辑中获取模型信息
        model_name = getattr(_local, 'model_name', 'unknown')
        model_version = getattr(_local, 'model_version', 'unknown')
        
        try:
            # 异步记录日志,不阻塞响应
            logging_client.log_inference_event(
                request_id=_local.request_id,
                request=request,
                response=response,
                latency_ms=latency_ms,
                model_name=model_name,
                model_version=model_version
            )
        except Exception as e:
            # 在真实项目中,这里应该有更健壮的错误处理,例如记录到备用日志文件
            print(f"Failed to log inference event: {e}")

        # 将 request_id 添加到响应头,方便客户端追踪
        response['X-Request-ID'] = str(_local.request_id)
        
        # 清理线程本地存储
        del _local.request_id
        if hasattr(_local, 'model_name'): del _local.model_name
        if hasattr(_local, 'model_version'): del _local.model_version

        return response

    @staticmethod
    def set_model_info(name: str, version: str):
        """
        一个静态方法,允许视图在内部设置模型信息
        """
        _local.model_name = name
        _local.model_version = version

predictor/views.py

import os
import json
import tensorflow as tf
from django.http import HttpRequest, JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_http_methods

from .middleware import ObservabilityMiddleware

# 从环境变量中读取模型路径,这是由 Nomad job file 注入的
MODEL_PATH = os.environ.get("MODEL_PATH", "models/default")
MODEL_NAME = os.environ.get("MODEL_NAME", "default_model")
MODEL_VERSION = os.environ.get("MODEL_VERSION", "v1.0")

# 全局加载模型,避免每次请求都重新加载
# 这是一个简化的示例,生产环境中需要更复杂的模型管理和热更新机制
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    # 如果模型加载失败,服务应该无法启动
    raise RuntimeError(f"Failed to load model from {MODEL_PATH}: {e}")

@csrf_exempt
@require_http_methods(["POST"])
def predict(request: HttpRequest):
    try:
        # 在请求处理的早期设置模型信息
        ObservabilityMiddleware.set_model_info(name=MODEL_NAME, version=MODEL_VERSION)

        data = json.loads(request.body)
        
        # 假设模型需要一个名为 'features' 的 numpy 数组
        features = np.array(data['features'])
        
        # 模型推理
        prediction = model.predict(features)
        
        response_data = {
            'prediction': prediction.tolist(),
            'model_version': MODEL_VERSION
        }
        return JsonResponse(response_data, status=200)

    except json.JSONDecodeError:
        return JsonResponse({'error': 'Invalid JSON'}, status=400)
    except KeyError:
        return JsonResponse({'error': 'Missing "features" key in request'}, status=400)
    except Exception as e:
        # 捕获所有其他异常,并记录
        return JsonResponse({'error': str(e)}, status=500)

3. 高效的 ClickHouse 日志客户端

这是解耦的关键。我们使用 ThreadPoolExecutor 和一个 queue 来实现一个简单的后台批量写入器。

predictor/logging_client.py

import atexit
import json
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any

from clickhouse_driver import Client
from django.http import HttpRequest, JsonResponse

# ClickHouse 连接配置,从环境变量获取
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT", 9000))

# 日志批处理配置
MAX_BATCH_SIZE = 1000
MAX_BATCH_INTERVAL_SECONDS = 5.0

# 使用队列在主线程和工作线程之间传递日志
log_queue = queue.Queue(maxsize=MAX_BATCH_SIZE * 10) # 设置一个合理的缓冲区大小

def get_client_ip(request: HttpRequest) -> str:
    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
    if x_forwarded_for:
        return x_forwarded_for.split(',')[0]
    return request.META.get('REMOTE_ADDR', '0.0.0.0')

class ClickHouseLogWriter:
    def __init__(self):
        self.client = Client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT)
        self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ch_writer")
        self.running = True
        self.worker_thread = threading.Thread(target=self._batch_writer_loop, daemon=True)
        self.worker_thread.start()
        atexit.register(self.shutdown)

    def _batch_writer_loop(self):
        """
        后台工作线程,循环地从队列中获取日志并批量写入
        """
        while self.running:
            batch = []
            start_time = time.time()
            
            # 收集一个批次
            while len(batch) < MAX_BATCH_SIZE and (time.time() - start_time) < MAX_BATCH_INTERVAL_SECONDS:
                try:
                    # 使用超时 get,避免永久阻塞
                    log_entry = log_queue.get(timeout=0.1)
                    batch.append(log_entry)
                    log_queue.task_done()
                except queue.Empty:
                    break # 队列为空,跳出内部循环
            
            if not batch:
                time.sleep(0.1) # 队列为空时,稍作休息
                continue

            try:
                self.client.execute("INSERT INTO default.inference_logs VALUES", batch)
            except Exception as e:
                # 生产环境中,失败的批次应该被持久化到磁盘以便重试
                print(f"Error writing batch to ClickHouse: {e}")

    def log_inference_event(self, request_id, request, response, latency_ms, model_name, model_version):
        # 构造日志字典
        log_entry = {
            'event_timestamp': time.time(),
            'request_id': request_id,
            'client_ip': get_client_ip(request),
            'endpoint': request.path,
            'model_name': model_name,
            'model_version': model_version,
            'latency_ms': latency_ms,
            'cpu_usage_start': 0.0, # 简化,实际中需要用 psutil 等库获取
            'cpu_usage_end': 0.0,
            'memory_rss_mb': 0.0,
            'request_body': request.body.decode('utf-8', errors='ignore'),
            'response_body': response.content.decode('utf-8', errors='ignore') if isinstance(response, JsonResponse) else '',
            'status_code': response.status_code,
            'error_message': ''
        }
        
        # 如果是错误响应,尝试提取错误信息
        if response.status_code >= 400 and isinstance(response, JsonResponse):
            try:
                error_content = json.loads(response.content)
                log_entry['error_message'] = str(error_content.get('error', ''))
            except:
                pass

        try:
            # 非阻塞地放入队列
            log_queue.put_nowait(log_entry)
        except queue.Full:
            print("Warning: Log queue is full. Dropping inference log.")

    def shutdown(self):
        print("Shutting down ClickHouseLogWriter...")
        self.running = False
        # 等待队列中的所有任务完成
        log_queue.join()
        self.executor.shutdown(wait=True)
        print("Shutdown complete.")

# 创建一个全局实例
_writer_instance = ClickHouseLogWriter()

def log_inference_event(**kwargs):
    _writer_instance.log_inference_event(**kwargs)

4. Dockerfile

一个用于生产的、多阶段构建的 Dockerfile。

# ---- Base Stage ----
FROM python:3.9-slim as base
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

WORKDIR /app

# Install system dependencies if any
# RUN apt-get update && apt-get install -y ...

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# ---- Final Stage ----
FROM base as final

COPY . .

# Django Gunicorn production server command
# NUM_WORKERS 可以通过环境变量在 Nomad job 中动态设置
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "${NUM_WORKERS:-3}", "core.wsgi:application"]

5. Nomad 作业规约

这是将所有部分粘合在一起的蓝图。

inference-service.nomad

job "inference-service" {
  datacenters = ["dc1"]
  type = "service"

  group "api" {
    count = 2 # 运行 2 个实例实现冗余

    network {
      port "http" {
        to = 8000
      }
    }
    
    # 使用 Consul 进行服务发现
    service {
      name     = "inference-api"
      port     = "http"
      provider = "consul"

      check {
        type     = "http"
        path     = "/healthz/" # Django 中需要实现这个简单的健康检查接口
        interval = "10s"
        timeout  = "2s"
      }
    }

    task "server" {
      driver = "docker"

      config {
        image = "your-repo/inference-service:latest"
        ports = ["http"]
      }
      
      # 使用 artifact 下载模型文件,而不是将其打包到 Docker 镜像中
      # 这使得模型更新只需要更新 artifact 源并重启任务,无需重建镜像
      artifact {
        source      = "https://your-model-registry/models/anomaly-detector/v1.2.zip"
        destination = "local/models"
        options {
          checksum = "sha256:..."
        }
      }

      # 注入环境变量,将服务与基础设施连接起来
      env {
        MODEL_PATH      = "${NOMAD_TASK_DIR}/models/anomaly-detector-v1.2"
        MODEL_NAME      = "anomaly-detector"
        MODEL_VERSION   = "v1.2"
        NUM_WORKERS     = "4"
        CLICKHOUSE_HOST = "clickhouse.service.consul" # 通过 Consul DNS 解析
        CLICKHOUSE_PORT = "9000"
      }

      resources {
        cpu    = 1000 # 1 GHz
        memory = 2048 # 2 GB
      }
    }
  }
}

局限性与未来展望

这套架构解决了高吞吐下的可观测性问题,但它并非银弹。当前的模型加载机制是服务启动时的一次性加载,不支持热更新。在真实场景中,需要一套更成熟的模型管理方案,例如通过一个控制平面 API 通知服务实例从指定位置拉取新模型并平滑切换。

此外,异步日志队列虽然高效,但在服务崩溃时,内存中的日志队列会丢失数据。对于更严格的场景,可以考虑将日志先写入一个持久化的消息队列(如 Kafka 或 RabbitMQ),再由消费者写入 ClickHouse,从而实现更高的可靠性。

最后,数据的价值在于分析。基于 ClickHouse 中积累的丰富日志,我们可以构建实时监控仪表盘,计算模型的漂移指标,甚至可以训练一个元模型来预测推理服务的性能瓶颈。这条路才刚刚开始。


  目录