Production Operations and Distributed Deployment: Monitoring, Versioning, and Maintaining Edge AI at Scale

Production edge AI deployments require comprehensive operational infrastructure: managing distributed device fleets, ensuring consistent performance, detecting degradation, and enabling safe model updates. This final post in the series covers production operations at scale: Prometheus and Jaeger integration for distributed monitoring, data drift detection that identifies model degradation, model versioning with semantic versioning and registry management, canary deployment patterns for safe rollouts, OTA update procedures for zero-downtime upgrades, health checking across device fleets, feedback loops for continuous improvement, and orchestration patterns for managing 100+ distributed edge devices.

Part 5 covered advanced optimization for maximizing edge performance. This concluding post addresses operational maturity: implementing observability for distributed systems, detecting when models require retraining, managing model lifecycles across device fleets, safely deploying updates without service interruption, monitoring device health proactively, and establishing feedback mechanisms enabling continuous model improvement based on production data.

Distributed Monitoring with Prometheus and Jaeger

Observability across distributed edge deployments requires standardized metrics collection, distributed tracing, and centralized aggregation enabling fleet-wide visibility.

flowchart TD
    A[Edge Device 1] --> B[Prometheus Exporter]
    C[Edge Device 2] --> D[Prometheus Exporter]
    E[Edge Device N] --> F[Prometheus Exporter]
    
    B --> G[Prometheus Server]
    D --> G
    F --> G
    
    G --> H[Grafana Dashboard]
    
    A --> I[Jaeger Agent]
    C --> J[Jaeger Agent]
    E --> K[Jaeger Agent]
    
    I --> L[Jaeger Collector]
    J --> L
    K --> L
    
    L --> M[Jaeger Storage]
    M --> N[Jaeger UI]
    
    G --> O[Alertmanager]
    O --> P[Notifications]
    
    H --> Q[Operations Team]
    N --> Q
    P --> Q
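
Before diving into the exporter, it helps to see how the central Prometheus server learns about the per-device exporters. A minimal sketch, assuming each device exposes metrics on port 8000 and that the server uses file-based service discovery (file_sd_configs) pointed at a generated targets file; the device IDs and IP addresses are placeholders:

#!/usr/bin/env python3
"""
Generate a Prometheus file_sd targets file for the edge fleet (sketch)
"""

import json

# Hypothetical fleet inventory: device_id -> IP address
FLEET = {
    "jetson-001": "192.168.1.101",
    "jetson-002": "192.168.1.102",
}

def write_targets(path="edge_targets.json", port=8000):
    """Write targets in the JSON format expected by file_sd_configs"""
    targets = [
        {
            "targets": [f"{ip}:{port}"],
            "labels": {"device_id": device_id, "role": "edge-inference"},
        }
        for device_id, ip in FLEET.items()
    ]
    with open(path, "w") as f:
        json.dump(targets, f, indent=2)

if __name__ == "__main__":
    write_targets()

Pointing a file_sd_configs entry in prometheus.yml at this file lets the server pick up newly registered devices without a restart.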

Prometheus Metrics Exporter (Python):

#!/usr/bin/env python3
"""
Prometheus metrics exporter for edge inference service
"""

from prometheus_client import start_http_server, Gauge, Counter, Histogram, Info
import time
import psutil
import subprocess
import torch

class EdgeMetricsExporter:
    """
    Export comprehensive edge device metrics to Prometheus
    """
    
    def __init__(self, device_id, port=8000):
        self.device_id = device_id
        self.port = port
        
        # Device information
        self.device_info = Info('edge_device', 'Edge device information')
        self.device_info.info({
            'device_id': device_id,
            'platform': self.get_platform_info(),
            'cuda_version': torch.version.cuda or 'N/A'
        })
        
        # Inference metrics
        self.inference_latency = Histogram(
            'inference_latency_seconds',
            'Inference latency in seconds',
            buckets=[0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.5, 1.0]
        )
        
        self.inference_requests = Counter(
            'inference_requests_total',
            'Total inference requests',
            ['model', 'status']
        )
        
        self.active_requests = Gauge(
            'inference_active_requests',
            'Currently active inference requests'
        )
        
        self.queue_depth = Gauge(
            'inference_queue_depth',
            'Inference request queue depth'
        )
        
        # GPU metrics
        self.gpu_utilization = Gauge(
            'gpu_utilization_percent',
            'GPU utilization percentage'
        )
        
        self.gpu_memory_used = Gauge(
            'gpu_memory_used_bytes',
            'GPU memory used in bytes'
        )
        
        self.gpu_memory_total = Gauge(
            'gpu_memory_total_bytes',
            'Total GPU memory in bytes'
        )
        
        self.gpu_temperature = Gauge(
            'gpu_temperature_celsius',
            'GPU temperature in Celsius'
        )
        
        self.gpu_power = Gauge(
            'gpu_power_watts',
            'GPU power consumption in watts'
        )
        
        # System metrics
        self.cpu_utilization = Gauge(
            'cpu_utilization_percent',
            'CPU utilization percentage'
        )
        
        self.memory_used = Gauge(
            'memory_used_bytes',
            'System memory used in bytes'
        )
        
        self.memory_total = Gauge(
            'memory_total_bytes',
            'Total system memory in bytes'
        )
        
        self.disk_used = Gauge(
            'disk_used_bytes',
            'Disk space used in bytes'
        )
        
        # Model metrics
        self.model_loaded = Gauge(
            'model_loaded',
            'Model load status (1=loaded, 0=not loaded)',
            ['model_name', 'version']
        )
        
        self.model_load_time = Histogram(
            'model_load_time_seconds',
            'Model loading time in seconds'
        )
        
        # Health metrics
        self.device_health = Gauge(
            'device_health_status',
            'Device health status (1=healthy, 0=unhealthy)'
        )
    
    def get_platform_info(self):
        """Get Jetson platform information"""
        try:
            result = subprocess.run(
                ['cat', '/proc/device-tree/model'],
                capture_output=True,
                text=True
            )
            return result.stdout.strip()
        except Exception:
            return 'Unknown'
    
    def update_gpu_metrics(self):
        """Update GPU metrics"""
        try:
            if torch.cuda.is_available():
                # Memory
                free_memory, total_memory = torch.cuda.mem_get_info()
                used_memory = total_memory - free_memory
                
                self.gpu_memory_used.set(used_memory)
                self.gpu_memory_total.set(total_memory)
                
                # Utilization (requires nvidia-smi or jetson_stats)
                try:
                    result = subprocess.run(
                        ['nvidia-smi', '--query-gpu=utilization.gpu,temperature.gpu,power.draw',
                         '--format=csv,noheader,nounits'],
                        capture_output=True,
                        text=True
                    )
                    
                    if result.returncode == 0:
                        util, temp, power = result.stdout.strip().split(',')
                        self.gpu_utilization.set(float(util))
                        self.gpu_temperature.set(float(temp))
                        self.gpu_power.set(float(power))
                except Exception:
                    pass  # nvidia-smi may be unavailable on some platforms; skip these metrics
        except Exception as e:
            print(f"Error updating GPU metrics: {e}")
    
    def update_system_metrics(self):
        """Update system metrics"""
        try:
            # CPU
            cpu_percent = psutil.cpu_percent(interval=1)
            self.cpu_utilization.set(cpu_percent)
            
            # Memory
            memory = psutil.virtual_memory()
            self.memory_used.set(memory.used)
            self.memory_total.set(memory.total)
            
            # Disk
            disk = psutil.disk_usage('/')
            self.disk_used.set(disk.used)
            
        except Exception as e:
            print(f"Error updating system metrics: {e}")
    
    def record_inference(self, model_name, latency, success):
        """Record inference metrics"""
        self.inference_latency.observe(latency)
        status = 'success' if success else 'error'
        self.inference_requests.labels(model=model_name, status=status).inc()
    
    def set_queue_depth(self, depth):
        """Set current queue depth"""
        self.queue_depth.set(depth)
    
    def set_active_requests(self, count):
        """Set active request count"""
        self.active_requests.set(count)
    
    def set_model_status(self, model_name, version, loaded):
        """Set model load status"""
        self.model_loaded.labels(model_name=model_name, version=version).set(1 if loaded else 0)
    
    def set_health_status(self, healthy):
        """Set device health status"""
        self.device_health.set(1 if healthy else 0)
    
    def start(self):
        """Start metrics HTTP server"""
        start_http_server(self.port)
        print(f"Metrics server started on port {self.port}")
        
        # Start background update thread
        import threading
        self.running = True
        self.update_thread = threading.Thread(target=self.update_loop)
        self.update_thread.daemon = True
        self.update_thread.start()
    
    def update_loop(self):
        """Background metrics update loop"""
        while self.running:
            self.update_gpu_metrics()
            self.update_system_metrics()
            time.sleep(5)  # Update every 5 seconds
    
    def stop(self):
        """Stop metrics exporter"""
        self.running = False

# Example usage
if __name__ == '__main__':
    exporter = EdgeMetricsExporter(device_id='jetson-001', port=8000)
    exporter.start()
    
    # Simulate some metrics
    exporter.set_model_status('yolov8n', 'v1.0.0', True)
    exporter.set_health_status(True)
    
    for i in range(100):
        # Simulate inference
        latency = 0.02 + (0.01 * (i % 10) / 10)
        success = i % 20 != 0
        exporter.record_inference('yolov8n', latency, success)
        
        exporter.set_queue_depth(i % 5)
        exporter.set_active_requests(min(i % 8, 4))
        
        time.sleep(0.1)
    
    # Keep running
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        exporter.stop()
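
In a real service, the exporter is wired into the inference path so that every request updates the latency histogram and request counters. A minimal sketch, assuming a hypothetical run_model function and the EdgeMetricsExporter defined above:

import time

def instrumented_inference(exporter, model_name, run_model, frame):
    """Run one inference and record latency and outcome in the exporter"""
    exporter.set_active_requests(1)  # single in-flight request in this sketch
    start = time.time()
    try:
        result = run_model(frame)  # hypothetical model call
        exporter.record_inference(model_name, time.time() - start, success=True)
        return result
    except Exception:
        exporter.record_inference(model_name, time.time() - start, success=False)
        raise
    finally:
        exporter.set_active_requests(0)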

Jaeger Distributed Tracing Integration (Python):

#!/usr/bin/env python3
"""
Jaeger distributed tracing for edge inference pipeline
"""

from jaeger_client import Config
from opentracing.ext import tags
from opentracing.propagation import Format
import time

class EdgeTracingService:
    """
    Distributed tracing service for edge inference
    """
    
    def __init__(self, service_name, agent_host='localhost', agent_port=6831):
        self.service_name = service_name
        
        # Configure Jaeger
        config = Config(
            config={
                'sampler': {
                    'type': 'const',
                    'param': 1,  # Sample all traces
                },
                'local_agent': {
                    'reporting_host': agent_host,
                    'reporting_port': agent_port,
                },
                'logging': True,
            },
            service_name=service_name,
            validate=True
        )
        
        self.tracer = config.initialize_tracer()
    
    def trace_inference(self, model_name, input_data):
        """Trace complete inference pipeline"""
        
        with self.tracer.start_active_span('inference_request') as scope:
            span = scope.span
            span.set_tag('model', model_name)
            span.set_tag('input_size', len(input_data))
            span.set_tag(tags.SPAN_KIND, tags.SPAN_KIND_RPC_SERVER)
            
            try:
                # Preprocessing
                preprocessed = self.trace_preprocessing(input_data, span)
                
                # Model inference
                result = self.trace_model_inference(model_name, preprocessed, span)
                
                # Postprocessing
                final_result = self.trace_postprocessing(result, span)
                
                span.set_tag('status', 'success')
                span.set_tag('detections', len(final_result))
                
                return final_result
                
            except Exception as e:
                span.set_tag('error', True)
                span.set_tag('error.message', str(e))
                span.log_kv({'event': 'error', 'message': str(e)})
                raise
    
    def trace_preprocessing(self, input_data, parent_span):
        """Trace preprocessing step"""
        
        with self.tracer.start_span('preprocessing', child_of=parent_span) as span:
            span.set_tag('operation', 'image_preprocessing')
            
            start = time.time()
            
            # Simulate preprocessing
            time.sleep(0.002)  # 2ms
            preprocessed = input_data
            
            duration = time.time() - start
            span.set_tag('duration_ms', duration * 1000)
            
            return preprocessed
    
    def trace_model_inference(self, model_name, input_data, parent_span):
        """Trace model inference step"""
        
        with self.tracer.start_span('model_inference', child_of=parent_span) as span:
            span.set_tag('model', model_name)
            span.set_tag('operation', 'tensorrt_inference')
            
            start = time.time()
            
            # Simulate inference
            time.sleep(0.015)  # 15ms
            result = {'detections': []}
            
            duration = time.time() - start
            span.set_tag('duration_ms', duration * 1000)
            
            return result
    
    def trace_postprocessing(self, inference_result, parent_span):
        """Trace postprocessing step"""
        
        with self.tracer.start_span('postprocessing', child_of=parent_span) as span:
            span.set_tag('operation', 'nms_filtering')
            
            start = time.time()
            
            # Simulate postprocessing
            time.sleep(0.003)  # 3ms
            final_result = inference_result
            
            duration = time.time() - start
            span.set_tag('duration_ms', duration * 1000)
            
            return final_result
    
    def close(self):
        """Close tracer"""
        self.tracer.close()

# Example usage
if __name__ == '__main__':
    tracing = EdgeTracingService('edge-inference-jetson-001')
    
    for i in range(10):
        dummy_input = b'dummy_image_data'
        result = tracing.trace_inference('yolov8n', dummy_input)
        print(f"Request {i}: {result}")
        time.sleep(0.1)
    
    tracing.close()
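
When a device forwards detections to an upstream aggregator, the trace context can be propagated so Jaeger stitches the device-side and server-side spans into one end-to-end trace. A minimal sketch using the opentracing inject API; the aggregator URL and payload are placeholders:

import requests
from opentracing.propagation import Format

def forward_with_trace(tracer, span, payload, url='http://aggregator.local/results'):
    """Inject the active span context into HTTP headers before forwarding"""
    headers = {}
    tracer.inject(span.context, Format.HTTP_HEADERS, headers)
    return requests.post(url, json=payload, headers=headers, timeout=2.0)

On the receiving side, tracer.extract(Format.HTTP_HEADERS, headers) recovers the context so the aggregator can start its spans as children of the device trace.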

Data Drift Detection

Production models degrade as input data distributions shift from training distributions. Data drift detection identifies when retraining becomes necessary.

Data Drift Monitor (Python):

#!/usr/bin/env python3
"""
Data drift detection for production edge models
"""

import numpy as np
from collections import deque
from scipy.stats import ks_2samp, wasserstein_distance
import time

class DataDriftMonitor:
    """
    Monitor input data distribution for drift detection
    """
    
    def __init__(self, baseline_samples=1000, window_size=500, 
                 drift_threshold=0.1):
        self.baseline_samples = baseline_samples
        self.window_size = window_size
        self.drift_threshold = drift_threshold
        
        # Baseline distribution
        self.baseline_features = None
        self.baseline_established = False
        
        # Current window
        self.current_window = deque(maxlen=window_size)
        
        # Drift history
        self.drift_scores = deque(maxlen=1000)
        self.drift_alerts = []
        
        # Feature extraction
        self.feature_extractors = {
            'mean_intensity': self.extract_mean_intensity,
            'std_intensity': self.extract_std_intensity,
            'brightness': self.extract_brightness,
            'contrast': self.extract_contrast,
        }
    
    def extract_features(self, image):
        """Extract statistical features from image"""
        features = {}
        
        for name, extractor in self.feature_extractors.items():
            features[name] = extractor(image)
        
        return features
    
    def extract_mean_intensity(self, image):
        """Mean pixel intensity"""
        return np.mean(image)
    
    def extract_std_intensity(self, image):
        """Standard deviation of intensity"""
        return np.std(image)
    
    def extract_brightness(self, image):
        """Average brightness"""
        return np.mean(image) / 255.0
    
    def extract_contrast(self, image):
        """Contrast measure"""
        return np.std(image) / (np.mean(image) + 1e-7)
    
    def establish_baseline(self, images):
        """Establish baseline distribution from training/validation data"""
        
        print(f"Establishing baseline from {len(images)} images...")
        
        baseline_features = {name: [] for name in self.feature_extractors.keys()}
        
        for image in images[:self.baseline_samples]:
            features = self.extract_features(image)
            for name, value in features.items():
                baseline_features[name].append(value)
        
        # Convert to numpy arrays
        self.baseline_features = {
            name: np.array(values) 
            for name, values in baseline_features.items()
        }
        
        self.baseline_established = True
        print("Baseline established")
    
    def add_sample(self, image):
        """Add new sample to current window"""
        
        features = self.extract_features(image)
        self.current_window.append(features)
    
    def detect_drift(self):
        """Detect data drift using statistical tests"""
        
        if not self.baseline_established:
            return None, "Baseline not established"
        
        if len(self.current_window) < self.window_size:
            return None, f"Insufficient samples: {len(self.current_window)}/{self.window_size}"
        
        # Compute drift scores for each feature
        drift_results = {}
        
        for feature_name in self.feature_extractors.keys():
            baseline = self.baseline_features[feature_name]
            current = np.array([sample[feature_name] for sample in self.current_window])
            
            # Kolmogorov-Smirnov test
            ks_statistic, ks_pvalue = ks_2samp(baseline, current)
            
            # Wasserstein distance
            wasserstein = wasserstein_distance(baseline, current)
            
            drift_results[feature_name] = {
                'ks_statistic': ks_statistic,
                'ks_pvalue': ks_pvalue,
                'wasserstein': wasserstein,
                'drifted': ks_statistic > self.drift_threshold
            }
        
        # Overall drift score (max KS statistic)
        max_ks = max(result['ks_statistic'] for result in drift_results.values())
        overall_drift = max_ks > self.drift_threshold
        
        # Record drift score
        self.drift_scores.append({
            'timestamp': time.time(),
            'max_ks': max_ks,
            'drifted': overall_drift,
            'details': drift_results
        })
        
        # Generate alert if drifted
        if overall_drift:
            alert = {
                'timestamp': time.time(),
                'severity': 'warning' if max_ks < self.drift_threshold * 1.5 else 'critical',
                'max_ks': max_ks,
                'affected_features': [
                    name for name, result in drift_results.items() 
                    if result['drifted']
                ]
            }
            self.drift_alerts.append(alert)
            
            return True, drift_results
        
        return False, drift_results
    
    def get_drift_report(self):
        """Generate drift monitoring report"""
        
        if not self.drift_scores:
            return {'status': 'No data'}
        
        recent_scores = list(self.drift_scores)[-100:]
        max_scores = [score['max_ks'] for score in recent_scores]
        
        return {
            'baseline_samples': self.baseline_samples,
            'current_window_size': len(self.current_window),
            'drift_checks_performed': len(self.drift_scores),
            'recent_max_ks': {
                'mean': np.mean(max_scores),
                'max': np.max(max_scores),
                'current': recent_scores[-1]['max_ks']
            },
            'drift_threshold': self.drift_threshold,
            'total_alerts': len(self.drift_alerts),
            'recent_alerts': self.drift_alerts[-5:]
        }

# Example usage
if __name__ == '__main__':
    monitor = DataDriftMonitor(
        baseline_samples=1000,
        window_size=500,
        drift_threshold=0.1
    )
    
    # Establish baseline with synthetic data
    baseline_images = [
        np.random.normal(120, 30, (640, 640, 3)).clip(0, 255).astype(np.uint8)
        for _ in range(1000)
    ]
    monitor.establish_baseline(baseline_images)
    
    # Simulate production samples (gradually shifting distribution)
    for i in range(2000):
        # Introduce gradual drift after 1000 samples
        if i < 1000:
            image = np.random.normal(120, 30, (640, 640, 3)).clip(0, 255).astype(np.uint8)
        else:
            # Shifted distribution
            shift = (i - 1000) / 1000 * 40  # Gradually shift mean
            image = np.random.normal(120 + shift, 30, (640, 640, 3)).clip(0, 255).astype(np.uint8)
        
        monitor.add_sample(image)
        
        # Check drift every 500 samples
        if i % 500 == 0 and i > 0:
            drifted, results = monitor.detect_drift()
            
            if drifted:
                print(f"\n⚠ DRIFT DETECTED at sample {i}")
                for feature, result in results.items():
                    if result['drifted']:
                        print(f"  {feature}: KS={result['ks_statistic']:.3f}")
            else:
                print(f"\n✓ No drift detected at sample {i}")
    
    # Final report
    report = monitor.get_drift_report()
    print(f"\n{'='*60}")
    print("Drift Monitoring Report:")
    print(f"Total samples: {report['total_samples_processed']}")
    print(f"Total alerts: {report['total_alerts']}")
    print(f"Recent max KS: {report['recent_max_ks']['current']:.3f}")

Model Versioning and Registry

Managing model versions across distributed deployments requires centralized registry with semantic versioning and metadata tracking.

Model Registry (Python):

#!/usr/bin/env python3
"""
Model registry for version management
"""

import json
import hashlib
import os
from dataclasses import dataclass, asdict
from typing import List, Optional
from datetime import datetime

@dataclass
class ModelMetadata:
    name: str
    version: str
    architecture: str
    framework: str
    precision: str
    input_shape: tuple
    mAP: float
    latency_ms: float
    model_size_mb: float
    training_date: str
    checksum: str
    tags: List[str]
    description: str

class ModelRegistry:
    """
    Centralized model registry with versioning
    """
    
    def __init__(self, registry_path='model_registry.json'):
        self.registry_path = registry_path
        self.models = {}
        self.load_registry()
    
    def load_registry(self):
        """Load registry from disk"""
        if os.path.exists(self.registry_path):
            with open(self.registry_path, 'r') as f:
                data = json.load(f)
                
                for model_key, model_data in data.items():
                    metadata = ModelMetadata(**model_data)
                    self.models[model_key] = metadata
    
    def save_registry(self):
        """Save registry to disk"""
        data = {
            key: asdict(metadata)
            for key, metadata in self.models.items()
        }
        
        with open(self.registry_path, 'w') as f:
            json.dump(data, f, indent=2)
    
    def compute_checksum(self, model_path):
        """Compute SHA256 checksum of model file"""
        sha256 = hashlib.sha256()
        
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                sha256.update(chunk)
        
        return sha256.hexdigest()
    
    def register_model(self, model_path: str, metadata: ModelMetadata):
        """Register new model version"""
        
        # Compute checksum
        metadata.checksum = self.compute_checksum(model_path)
        
        # Create unique key
        model_key = f"{metadata.name}:{metadata.version}"
        
        # Check if version already exists
        if model_key in self.models:
            raise ValueError(f"Model {model_key} already registered")
        
        # Register
        self.models[model_key] = metadata
        self.save_registry()
        
        print(f"Registered model: {model_key}")
        return model_key
    
    def get_model(self, name: str, version: Optional[str] = None):
        """Get model metadata"""
        
        if version:
            model_key = f"{name}:{version}"
            return self.models.get(model_key)
        else:
            # Get latest version
            versions = self.list_versions(name)
            if not versions:
                return None
            
            latest_version = self.get_latest_version(versions)
            model_key = f"{name}:{latest_version}"
            return self.models.get(model_key)
    
    def list_models(self):
        """List all registered models"""
        models = {}
        
        for key, metadata in self.models.items():
            if metadata.name not in models:
                models[metadata.name] = []
            models[metadata.name].append(metadata.version)
        
        return models
    
    def list_versions(self, name: str):
        """List all versions of a model"""
        return [
            metadata.version
            for key, metadata in self.models.items()
            if metadata.name == name
        ]
    
    def get_latest_version(self, versions: List[str]):
        """Get latest semantic version"""
        
        def parse_version(v):
            parts = v.lstrip('v').split('.')
            return tuple(int(p) for p in parts)
        
        return sorted(versions, key=parse_version, reverse=True)[0]
    
    def compare_models(self, name: str, version1: str, version2: str):
        """Compare two model versions"""
        
        model1 = self.get_model(name, version1)
        model2 = self.get_model(name, version2)
        
        if not model1 or not model2:
            return None
        
        comparison = {
            'version1': version1,
            'version2': version2,
            'mAP_diff': model2.mAP - model1.mAP,
            'latency_diff': model2.latency_ms - model1.latency_ms,
            'size_diff': model2.model_size_mb - model1.model_size_mb
        }
        
        return comparison
    
    def search_by_tag(self, tag: str):
        """Search models by tag"""
        results = []
        
        for key, metadata in self.models.items():
            if tag in metadata.tags:
                results.append((key, metadata))
        
        return results

# Example usage
if __name__ == '__main__':
    registry = ModelRegistry()
    
    # Register models
    models = [
        ModelMetadata(
            name='yolov8n',
            version='v1.0.0',
            architecture='YOLOv8',
            framework='TensorRT',
            precision='INT8',
            input_shape=(1, 3, 640, 640),
            mAP=37.3,
            latency_ms=18.5,
            model_size_mb=6.2,
            training_date='2025-01-10',
            checksum='',
            tags=['production', 'jetson-orin'],
            description='Initial production model'
        ),
        ModelMetadata(
            name='yolov8n',
            version='v1.1.0',
            architecture='YOLOv8',
            framework='TensorRT',
            precision='INT8',
            input_shape=(1, 3, 640, 640),
            mAP=38.1,
            latency_ms=19.2,
            model_size_mb=6.2,
            training_date='2025-01-15',
            checksum='',
            tags=['production', 'jetson-orin', 'improved-accuracy'],
            description='Retrained with additional data'
        )
    ]
    
    for metadata in models:
        # Assuming model files exist
        model_path = f"{metadata.name}_{metadata.version}.engine"
        if not os.path.exists(model_path):
            # Create dummy file for example
            with open(model_path, 'wb') as f:
                f.write(b'dummy_model_data')
        
        registry.register_model(model_path, metadata)
    
    # Query registry
    print("\nRegistered models:")
    for name, versions in registry.list_models().items():
        print(f"{name}: {versions}")
    
    # Get latest
    latest = registry.get_model('yolov8n')
    print(f"\nLatest yolov8n: {latest.version} (mAP: {latest.mAP})")
    
    # Compare versions
    comparison = registry.compare_models('yolov8n', 'v1.0.0', 'v1.1.0')
    print(f"\nComparison v1.0.0 vs v1.1.0:")
    print(f"  mAP improvement: {comparison['mAP_diff']:.2f}")
    print(f"  Latency change: {comparison['latency_diff']:.2f}ms")

Canary Deployment and Rollback

Safe model deployment requires gradual rollout with automated rollback on performance degradation.

Canary Deployment Manager (Python):

#!/usr/bin/env python3
"""
Canary deployment for safe model rollouts
"""

import time
import random
from dataclasses import dataclass
from typing import Optional
from collections import deque
import numpy as np

@dataclass
class CanaryConfig:
    initial_traffic_percent: float = 5.0
    increment_percent: float = 10.0
    increment_interval_seconds: float = 300  # 5 minutes
    max_error_rate: float = 0.02  # 2%
    max_latency_increase: float = 0.15  # 15%
    min_samples: int = 100

class CanaryDeploymentManager:
    """
    Manage canary deployments with automatic rollback
    """
    
    def __init__(self, config: CanaryConfig):
        self.config = config
        self.current_traffic_percent = 0.0
        self.deployment_active = False
        
        # Model versions
        self.stable_model = None
        self.canary_model = None
        
        # Metrics collection
        self.stable_metrics = deque(maxlen=1000)
        self.canary_metrics = deque(maxlen=1000)
        
        # Deployment tracking
        self.deployment_start = None
        self.last_increment = None
        self.rollback_triggered = False
    
    def start_deployment(self, stable_version, canary_version):
        """Start canary deployment"""
        
        print(f"\n{'='*60}")
        print(f"Starting canary deployment")
        print(f"Stable: {stable_version}")
        print(f"Canary: {canary_version}")
        print(f"Initial traffic: {self.config.initial_traffic_percent}%")
        print(f"{'='*60}\n")
        
        self.stable_model = stable_version
        self.canary_model = canary_version
        self.current_traffic_percent = self.config.initial_traffic_percent
        self.deployment_active = True
        self.deployment_start = time.time()
        self.last_increment = time.time()
        self.rollback_triggered = False
    
    def route_request(self):
        """Route request to stable or canary model"""
        
        if not self.deployment_active:
            return 'stable'
        
        # Random selection based on traffic percentage
        if random.random() * 100 < self.current_traffic_percent:
            return 'canary'
        else:
            return 'stable'
    
    def record_inference(self, model_type, latency, success):
        """Record inference metrics"""
        
        metric = {
            'latency': latency,
            'success': success,
            'timestamp': time.time()
        }
        
        if model_type == 'stable':
            self.stable_metrics.append(metric)
        elif model_type == 'canary':
            self.canary_metrics.append(metric)
    
    def evaluate_canary_health(self):
        """Evaluate canary model health"""
        
        if len(self.canary_metrics) < self.config.min_samples:
            return True, "Insufficient samples"
        
        # Get recent metrics
        recent_stable = list(self.stable_metrics)[-self.config.min_samples:]
        recent_canary = list(self.canary_metrics)[-self.config.min_samples:]
        
        if not recent_stable or not recent_canary:
            return True, "Insufficient metrics"
        
        # Calculate error rates
        stable_errors = sum(1 for m in recent_stable if not m['success'])
        canary_errors = sum(1 for m in recent_canary if not m['success'])
        
        stable_error_rate = stable_errors / len(recent_stable)
        canary_error_rate = canary_errors / len(recent_canary)
        
        # Calculate latencies
        stable_latencies = [m['latency'] for m in recent_stable if m['success']]
        canary_latencies = [m['latency'] for m in recent_canary if m['success']]
        
        if not stable_latencies or not canary_latencies:
            return True, "No successful inferences"
        
        stable_p95 = np.percentile(stable_latencies, 95)
        canary_p95 = np.percentile(canary_latencies, 95)
        
        # Check error rate
        if canary_error_rate > self.config.max_error_rate:
            return False, f"Error rate too high: {canary_error_rate*100:.2f}%"
        
        # Check latency regression
        latency_increase = (canary_p95 - stable_p95) / stable_p95
        if latency_increase > self.config.max_latency_increase:
            return False, f"Latency regression: {latency_increase*100:.1f}%"
        
        return True, f"Healthy (errors: {canary_error_rate*100:.2f}%, " \
                     f"latency: {latency_increase*100:.1f}%)"
    
    def update_deployment(self):
        """Update deployment progress"""
        
        if not self.deployment_active or self.rollback_triggered:
            return
        
        # Check if time for next increment
        if time.time() - self.last_increment < self.config.increment_interval_seconds:
            return
        
        # Evaluate canary health
        healthy, message = self.evaluate_canary_health()
        
        if not healthy:
            print(f"\n⚠ ROLLBACK TRIGGERED: {message}")
            self.rollback()
            return
        
        # Increment traffic if healthy
        if self.current_traffic_percent < 100.0:
            self.current_traffic_percent = min(
                100.0,
                self.current_traffic_percent + self.config.increment_percent
            )
            self.last_increment = time.time()
            
            print(f"\n✓ Canary healthy: {message}")
            print(f"  Incrementing traffic to {self.current_traffic_percent}%")
        
        # Complete deployment at 100%
        if self.current_traffic_percent >= 100.0:
            self.complete_deployment()
    
    def rollback(self):
        """Rollback to stable version"""
        
        self.rollback_triggered = True
        self.current_traffic_percent = 0.0
        
        print(f"Rolling back to stable version: {self.stable_model}")
        print("Canary deployment aborted")
        
        self.deployment_active = False
    
    def complete_deployment(self):
        """Complete canary deployment"""
        
        duration = time.time() - self.deployment_start
        
        print(f"\n{'='*60}")
        print(f"Canary deployment COMPLETED")
        print(f"Canary version promoted to stable: {self.canary_model}")
        print(f"Deployment duration: {duration/60:.1f} minutes")
        print(f"{'='*60}\n")
        
        self.stable_model = self.canary_model
        self.canary_model = None
        self.deployment_active = False
    
    def get_status(self):
        """Get deployment status"""
        
        if not self.deployment_active:
            return {
                'active': False,
                'stable_model': self.stable_model
            }
        
        healthy, message = self.evaluate_canary_health()
        
        return {
            'active': True,
            'stable_model': self.stable_model,
            'canary_model': self.canary_model,
            'traffic_percent': self.current_traffic_percent,
            'canary_healthy': healthy,
            'health_message': message,
            'rollback_triggered': self.rollback_triggered
        }

# Example usage
if __name__ == '__main__':
    config = CanaryConfig(
        initial_traffic_percent=5.0,
        increment_percent=10.0,
        increment_interval_seconds=10,  # 10 seconds for demo
        max_error_rate=0.02,
        max_latency_increase=0.15,
        min_samples=50
    )
    
    manager = CanaryDeploymentManager(config)
    manager.start_deployment('yolov8n-v1.0.0', 'yolov8n-v1.1.0')
    
    # Simulate production traffic
    for i in range(500):
        model_type = manager.route_request()
        
        # Simulate inference
        if model_type == 'stable':
            latency = np.random.normal(0.020, 0.003)  # 20ms ± 3ms
            success = random.random() > 0.005  # 0.5% error rate
        else:
            # Canary slightly better initially
            if i < 200:
                latency = np.random.normal(0.019, 0.003)  # 19ms ± 3ms
                success = random.random() > 0.004  # 0.4% error rate
            else:
                # Introduce regression for demo
                latency = np.random.normal(0.035, 0.005)  # 35ms (regression!)
                success = random.random() > 0.025  # 2.5% error rate (high!)
        
        manager.record_inference(model_type, latency, success)
        
        # Update deployment
        manager.update_deployment()
        
        # Print status periodically
        if i % 100 == 0:
            status = manager.get_status()
            print(f"\nIteration {i}: Traffic to canary: {status['traffic_percent']:.1f}%")
        
        time.sleep(0.02)
    
    # Final status
    final_status = manager.get_status()
    print(f"\nFinal status: {final_status}")

OTA Updates and Device Fleet Management

Managing 100+ distributed edge devices requires automated OTA update mechanisms with coordinated rollout strategies.

Fleet Management System (Python):

#!/usr/bin/env python3
"""
Fleet management for distributed edge devices
"""

import requests
import hashlib
import os
import time
from dataclasses import dataclass
from typing import List, Dict
from enum import Enum

class DeviceStatus(Enum):
    ONLINE = "online"
    OFFLINE = "offline"
    UPDATING = "updating"
    ERROR = "error"

@dataclass
class Device:
    device_id: str
    ip_address: str
    current_version: str
    target_version: str
    status: DeviceStatus
    last_seen: float
    update_progress: float = 0.0

class FleetManager:
    """
    Manage fleet of edge devices with OTA updates
    """
    
    def __init__(self, update_server_url):
        self.update_server_url = update_server_url
        self.devices: Dict[str, Device] = {}
    
    def register_device(self, device_id, ip_address, current_version):
        """Register device with fleet"""
        
        device = Device(
            device_id=device_id,
            ip_address=ip_address,
            current_version=current_version,
            target_version=current_version,
            status=DeviceStatus.ONLINE,
            last_seen=time.time()
        )
        
        self.devices[device_id] = device
        print(f"Registered device: {device_id} ({current_version})")
    
    def plan_rollout(self, target_version, strategy='rolling', batch_size=10):
        """Plan update rollout strategy"""
        
        devices_to_update = [
            d for d in self.devices.values()
            if d.current_version != target_version
        ]
        
        print(f"\nPlanning rollout to {target_version}")
        print(f"Devices to update: {len(devices_to_update)}")
        print(f"Strategy: {strategy}")
        
        if strategy == 'rolling':
            # Update in batches
            batches = [
                devices_to_update[i:i+batch_size]
                for i in range(0, len(devices_to_update), batch_size)
            ]
            
            print(f"Batches: {len(batches)}")
            return batches
        
        elif strategy == 'canary':
            # First 5% as canary
            canary_size = max(1, len(devices_to_update) // 20)
            canary = devices_to_update[:canary_size]
            remaining = devices_to_update[canary_size:]
            
            print(f"Canary devices: {len(canary)}")
            print(f"Remaining: {len(remaining)}")
            
            return [canary] + [
                remaining[i:i+batch_size]
                for i in range(0, len(remaining), batch_size)
            ]
        
        else:
            # All at once
            return [devices_to_update]
    
    def update_device(self, device: Device, target_version, model_url):
        """Trigger OTA update on device"""
        
        print(f"Updating {device.device_id} to {target_version}...")
        
        device.status = DeviceStatus.UPDATING
        device.target_version = target_version
        
        try:
            # Send update command to device
            response = requests.post(
                f"http://{device.ip_address}:8080/update",
                json={
                    'version': target_version,
                    'model_url': model_url,
                    'checksum': self.compute_checksum(model_url)
                },
                timeout=5.0
            )
            
            if response.status_code == 200:
                print(f"  ✓ Update initiated on {device.device_id}")
                return True
            else:
                print(f"  ✗ Update failed on {device.device_id}: {response.status_code}")
                device.status = DeviceStatus.ERROR
                return False
        
        except Exception as e:
            print(f"  ✗ Update failed on {device.device_id}: {e}")
            device.status = DeviceStatus.ERROR
            return False
    
    def compute_checksum(self, url):
        """Compute checksum of model file"""
        # Placeholder - would download and compute actual checksum
        return "dummy_checksum"
    
    def execute_rollout(self, target_version, model_url, strategy='rolling'):
        """Execute update rollout"""
        
        batches = self.plan_rollout(target_version, strategy=strategy, batch_size=10)
        
        for batch_idx, batch in enumerate(batches):
            print(f"\nBatch {batch_idx + 1}/{len(batches)}")
            
            # Update all devices in batch
            for device in batch:
                self.update_device(device, target_version, model_url)
            
            # Wait for batch to complete
            self.wait_for_batch_completion(batch)
            
            # Verify batch health
            if not self.verify_batch_health(batch):
                print(f"\n⚠ Batch {batch_idx + 1} failed health check!")
                print("Aborting rollout")
                return False
            
            print(f"✓ Batch {batch_idx + 1} completed successfully")
        
        print(f"\n{'='*60}")
        print(f"Rollout to {target_version} COMPLETED")
        print(f"{'='*60}")
        return True
    
    def wait_for_batch_completion(self, batch, timeout=300):
        """Wait for batch updates to complete"""
        
        start = time.time()
        
        while time.time() - start < timeout:
            # Check device statuses
            statuses = [self.get_device_status(d.device_id) for d in batch]
            
            updating = sum(1 for s in statuses if s == DeviceStatus.UPDATING)
            completed = sum(1 for s in statuses if s == DeviceStatus.ONLINE)
            errors = sum(1 for s in statuses if s == DeviceStatus.ERROR)
            
            print(f"  Progress: {completed}/{len(batch)} completed, "
                  f"{updating} updating, {errors} errors", end='\r')
            
            if completed + errors == len(batch):
                print()
                break
            
            time.sleep(5)
    
    def get_device_status(self, device_id):
        """Get device status (would query actual device)"""
        
        device = self.devices[device_id]
        
        # Simulate status update
        if device.status == DeviceStatus.UPDATING:
            # Simulate completion
            if time.time() - device.last_seen > 10:
                device.status = DeviceStatus.ONLINE
                device.current_version = device.target_version
        
        return device.status
    
    def verify_batch_health(self, batch):
        """Verify health of updated batch"""
        
        errors = sum(
            1 for d in batch 
            if self.devices[d.device_id].status == DeviceStatus.ERROR
        )
        
        error_rate = errors / len(batch)
        
        return error_rate < 0.05  # Allow 5% error rate
    
    def get_fleet_status(self):
        """Get overall fleet status"""
        
        version_counts = {}
        status_counts = {status: 0 for status in DeviceStatus}
        
        for device in self.devices.values():
            # Count versions
            if device.current_version not in version_counts:
                version_counts[device.current_version] = 0
            version_counts[device.current_version] += 1
            
            # Count statuses
            status_counts[device.status] += 1
        
        return {
            'total_devices': len(self.devices),
            'versions': version_counts,
            'statuses': {
                status.value: count 
                for status, count in status_counts.items()
            }
        }

# Example usage
if __name__ == '__main__':
    manager = FleetManager(update_server_url='https://updates.example.com')
    
    # Register devices
    for i in range(50):
        manager.register_device(
            device_id=f"jetson-{i:03d}",
            ip_address=f"192.168.1.{i+100}",
            current_version='v1.0.0'
        )
    
    # Execute rollout
    manager.execute_rollout(
        target_version='v1.1.0',
        model_url='https://models.example.com/yolov8n-v1.1.0.engine',
        strategy='rolling'
    )
    
    # Fleet status
    status = manager.get_fleet_status()
    print(f"\nFleet Status:")
    print(f"Total devices: {status['total_devices']}")
    print(f"Versions: {status['versions']}")
    print(f"Statuses: {status['statuses']}")

Key Takeaways

Production edge AI deployments demand comprehensive operational infrastructure beyond initial model deployment. Distributed monitoring through Prometheus and Jaeger provides essential visibility into fleet-wide performance, enabling proactive issue detection and rapid troubleshooting. Standardized metrics collection across devices supports aggregated analysis that distinguishes systemic issues from device-specific anomalies.

Data drift detection identifies model degradation requiring retraining before accuracy impacts become severe. Statistical monitoring of input distributions using KS tests and Wasserstein distance provides quantitative drift assessment. Establishing baselines from training data enables ongoing comparison; shifts that push the KS statistic past the configured threshold (0.10-0.15 in the examples above) trigger retraining workflows.

Model versioning with semantic versioning and centralized registry enables consistent deployment tracking across distributed fleets. Metadata preservation including accuracy metrics, latency characteristics, and training provenance supports informed rollback decisions. Canary deployments with automated health monitoring and rollback capabilities enable safe model updates with minimal risk. Gradual traffic shifting from 5% through 100% with continuous health evaluation prevents widespread failures from problematic models.

OTA update mechanisms with coordinated rollout strategies manage fleets of 100+ devices efficiently. Batch-based rolling updates limit blast radius while maintaining overall system availability, and health verification after each batch prevents cascade failures, keeping the fleet stable throughout deployment windows.

This six-part series covered complete production edge CNN deployment, from foundational architecture through operational excellence. Combining optimized model quantization, efficient inference servers, advanced resource management, and comprehensive operational tooling enables reliable, performant, and scalable edge AI systems that meet real-world deployment requirements. Production maturity requires attention to the entire lifecycle: training, optimization, deployment, monitoring, maintenance, and continuous improvement based on production feedback.
