Advanced Optimization Patterns: Concurrent Multi-Model Inference and Resource Management on Edge Hardware

Advanced Optimization Patterns: Concurrent Multi-Model Inference and Resource Management on Edge Hardware

Edge devices face stringent resource constraints requiring sophisticated optimization strategies beyond basic inference implementation. This post explores advanced optimization patterns for maximizing edge AI performance including memory-aware scheduling coordinating multiple concurrent models, GPU resource pooling minimizing contention, KV cache management for transformer-based models, adaptive batching dynamically grouping requests, thermal-aware resource allocation preventing throttling, and SLA enforcement ensuring quality of service. Production implementations demonstrate 50-70% latency reduction through intelligent resource coordination while maintaining accuracy and reliability.

Part 4 covered building multi-language inference servers with basic request handling. This post advances to production-scale optimization: understanding memory hierarchies and allocation strategies, implementing multi-model coordination without interference, designing adaptive batching systems, managing transformer KV caches efficiently, enforcing service level agreements under resource constraints, and measuring optimization effectiveness through comprehensive benchmarking.

Memory-Aware Scheduling for Multi-Model Deployments

Edge devices frequently serve multiple models simultaneously for different detection tasks, classification requirements, or customer applications. Naive concurrent execution causes GPU memory exhaustion and severe performance degradation.

flowchart TD
    A[Incoming Requests] --> B[Request Classifier]
    B --> C{Model Type}
    C -->|Detection| D[Detection Queue]
    C -->|Classification| E[Classification Queue]
    C -->|Segmentation| F[Segmentation Queue]
    
    D --> G[Memory Allocator]
    E --> G
    F --> G
    
    G --> H{Available Memory?}
    H -->|Yes| I[Load Model]
    H -->|No| J[Eviction Policy]
    J --> K[Unload Idle Model]
    K --> I
    
    I --> L[Execute Inference]
    L --> M[Release Memory]
    M --> N[Return Result]
    
    O[Memory Monitor] --> G
    P[Thermal Monitor] --> G

Multi-Model Memory Scheduler (Python):

#!/usr/bin/env python3
"""
Memory-aware multi-model scheduler for edge inference
"""

import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Optional
import psutil

@dataclass
class ModelMetadata:
    name: str
    engine_path: str
    memory_required: int  # bytes
    priority: int
    last_used: float
    load_count: int
    inference_count: int

class MemoryAwareScheduler:
    """
    Multi-model scheduler with intelligent memory management
    """
    
    def __init__(self, max_gpu_memory_mb=4096, cache_size=3):
        self.max_gpu_memory = max_gpu_memory_mb * 1024 * 1024
        self.cache_size = cache_size
        self.loaded_models = OrderedDict()
        self.model_metadata: Dict[str, ModelMetadata] = {}
        self.lock = threading.Lock()
        
        # Memory tracking
        self.current_memory_usage = 0
        
        # Performance metrics
        self.cache_hits = 0
        self.cache_misses = 0
        self.evictions = 0
        
    def register_model(self, name: str, engine_path: str, 
                       priority: int = 1, estimated_memory_mb: int = 100):
        """Register a model for scheduling"""
        
        metadata = ModelMetadata(
            name=name,
            engine_path=engine_path,
            memory_required=estimated_memory_mb * 1024 * 1024,
            priority=priority,
            last_used=0,
            load_count=0,
            inference_count=0
        )
        
        self.model_metadata[name] = metadata
        print(f"Registered model: {name} ({estimated_memory_mb}MB, priority={priority})")
    
    def get_available_memory(self):
        """Get available GPU memory"""
        torch.cuda.synchronize()
        free_memory, total_memory = torch.cuda.mem_get_info()
        return free_memory
    
    def estimate_model_memory(self, engine_path: str):
        """Estimate memory requirements for a model"""
        import os
        
        # Engine file size is rough approximation
        engine_size = os.path.getsize(engine_path)
        
        # Add overhead for activation buffers (2-3x engine size)
        estimated_memory = engine_size * 2.5
        
        return int(estimated_memory)
    
    def select_eviction_candidate(self) -> Optional[str]:
        """Select model to evict using weighted LRU"""
        
        if not self.loaded_models:
            return None
        
        # Score models by: recency, priority, usage frequency
        scores = {}
        current_time = time.time()
        
        for name, model_data in self.loaded_models.items():
            metadata = self.model_metadata[name]
            
            # Time since last use (seconds)
            idle_time = current_time - metadata.last_used
            
            # Weighted score (lower is better for eviction)
            # High priority and recent use = low score (keep)
            # Low priority and old use = high score (evict)
            score = idle_time / (metadata.priority * (metadata.inference_count + 1))
            scores[name] = score
        
        # Return model with highest eviction score
        return max(scores.items(), key=lambda x: x[1])[0]
    
    def evict_model(self, name: str):
        """Evict a model from GPU memory"""
        
        if name not in self.loaded_models:
            return
        
        with self.lock:
            model_data = self.loaded_models[name]
            metadata = self.model_metadata[name]
            
            # Cleanup TensorRT resources
            if 'context' in model_data:
                del model_data['context']
            if 'engine' in model_data:
                del model_data['engine']
            if 'runtime' in model_data:
                del model_data['runtime']
            
            # Free device buffers
            if 'buffers' in model_data:
                for buffer in model_data['buffers']:
                    if 'device' in buffer:
                        buffer['device'].free()
            
            # Update memory tracking
            self.current_memory_usage -= metadata.memory_required
            
            # Remove from cache
            del self.loaded_models[name]
            self.evictions += 1
            
            print(f"Evicted model: {name} (freed ~{metadata.memory_required/1024/1024:.1f}MB)")
    
    def load_model(self, name: str):
        """Load model into GPU memory"""
        
        metadata = self.model_metadata[name]
        
        # Check if already loaded
        if name in self.loaded_models:
            self.cache_hits += 1
            metadata.last_used = time.time()
            metadata.inference_count += 1
            return self.loaded_models[name]
        
        self.cache_misses += 1
        
        # Make space if needed
        while self.current_memory_usage + metadata.memory_required > self.max_gpu_memory:
            if len(self.loaded_models) == 0:
                raise RuntimeError(f"Model {name} requires more memory than available")
            
            evict_candidate = self.select_eviction_candidate()
            self.evict_model(evict_candidate)
        
        # Enforce cache size limit
        while len(self.loaded_models) >= self.cache_size:
            evict_candidate = self.select_eviction_candidate()
            self.evict_model(evict_candidate)
        
        with self.lock:
            print(f"Loading model: {name}")
            
            # Load TensorRT engine
            logger = trt.Logger(trt.Logger.WARNING)
            runtime = trt.Runtime(logger)
            
            with open(metadata.engine_path, 'rb') as f:
                engine = runtime.deserialize_cuda_engine(f.read())
            
            context = engine.create_execution_context()
            
            # Allocate buffers
            buffers = []
            bindings = []
            
            for binding in engine:
                size = trt.volume(engine.get_binding_shape(binding))
                dtype = trt.nptype(engine.get_binding_dtype(binding))
                
                host_mem = cuda.pagelocked_empty(size, dtype)
                device_mem = cuda.mem_alloc(host_mem.nbytes)
                
                bindings.append(int(device_mem))
                
                buffer_info = {
                    'host': host_mem,
                    'device': device_mem,
                    'size': size,
                    'dtype': dtype
                }
                
                if engine.binding_is_input(binding):
                    buffer_info['type'] = 'input'
                else:
                    buffer_info['type'] = 'output'
                
                buffers.append(buffer_info)
            
            # Create CUDA stream
            stream = cuda.Stream()
            
            model_data = {
                'runtime': runtime,
                'engine': engine,
                'context': context,
                'buffers': buffers,
                'bindings': bindings,
                'stream': stream
            }
            
            # Update cache and metadata
            self.loaded_models[name] = model_data
            self.current_memory_usage += metadata.memory_required
            metadata.last_used = time.time()
            metadata.load_count += 1
            metadata.inference_count += 1
            
            print(f"Model loaded: {name} (memory: {self.current_memory_usage/1024/1024:.1f}MB / "
                  f"{self.max_gpu_memory/1024/1024:.1f}MB)")
            
            return model_data
    
    def infer(self, model_name: str, input_data):
        """Execute inference with automatic model loading"""
        
        model_data = self.load_model(model_name)
        
        # Prepare input
        input_buffer = model_data['buffers'][0]
        input_data_flat = input_data.ravel()
        
        # Copy to device
        cuda.memcpy_htod_async(
            input_buffer['device'],
            input_data_flat,
            model_data['stream']
        )
        
        # Execute
        model_data['context'].execute_async_v2(
            bindings=model_data['bindings'],
            stream_handle=model_data['stream'].handle
        )
        
        # Copy output
        output_buffer = model_data['buffers'][1]
        cuda.memcpy_dtoh_async(
            output_buffer['host'],
            output_buffer['device'],
            model_data['stream']
        )
        
        model_data['stream'].synchronize()
        
        # Update metadata
        self.model_metadata[model_name].last_used = time.time()
        
        return output_buffer['host']
    
    def get_statistics(self):
        """Get scheduler performance statistics"""
        
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
        
        return {
            'loaded_models': len(self.loaded_models),
            'memory_usage_mb': self.current_memory_usage / 1024 / 1024,
            'max_memory_mb': self.max_gpu_memory / 1024 / 1024,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': hit_rate,
            'evictions': self.evictions,
            'model_stats': {
                name: {
                    'loads': meta.load_count,
                    'inferences': meta.inference_count,
                    'last_used': time.time() - meta.last_used
                }
                for name, meta in self.model_metadata.items()
            }
        }

# Example usage
if __name__ == '__main__':
    scheduler = MemoryAwareScheduler(max_gpu_memory_mb=4096, cache_size=3)
    
    # Register models
    scheduler.register_model('yolov8n', 'yolov8n_int8.engine', priority=3, estimated_memory_mb=80)
    scheduler.register_model('yolov8s', 'yolov8s_int8.engine', priority=2, estimated_memory_mb=140)
    scheduler.register_model('resnet50', 'resnet50_int8.engine', priority=1, estimated_memory_mb=120)
    
    # Simulate inference requests
    for i in range(100):
        model_name = ['yolov8n', 'yolov8s', 'resnet50'][i % 3]
        dummy_input = torch.randn(1, 3, 640, 640).cuda().cpu().numpy()
        
        result = scheduler.infer(model_name, dummy_input)
        
        if i % 20 == 0:
            stats = scheduler.get_statistics()
            print(f"\nIteration {i}: Hit rate = {stats['hit_rate']:.2%}, "
                  f"Evictions = {stats['evictions']}")

GPU Resource Pooling and Contention Management

Multiple concurrent inference requests compete for limited GPU compute resources. Resource pooling with intelligent scheduling prevents contention-induced latency spikes.

GPU Resource Pool Manager (Python):

#!/usr/bin/env python3
"""
GPU resource pool with contention management
"""

import torch
import threading
import queue
import time
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum

class Priority(Enum):
    LOW = 0
    NORMAL = 1
    HIGH = 2
    CRITICAL = 3

@dataclass
class InferenceTask:
    task_id: str
    model_name: str
    input_data: any
    priority: Priority
    submit_time: float
    deadline: Optional[float] = None
    callback: Optional[callable] = None

class GPUResourcePool:
    """
    GPU resource pool with priority scheduling and contention management
    """
    
    def __init__(self, num_streams=4, enable_cuda_graphs=True):
        self.num_streams = num_streams
        self.enable_cuda_graphs = enable_cuda_graphs
        
        # Create CUDA streams for parallel execution
        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
        self.stream_busy = [False] * num_streams
        self.stream_lock = threading.Lock()
        
        # Priority queues for different request types
        self.queues = {
            Priority.CRITICAL: queue.PriorityQueue(),
            Priority.HIGH: queue.PriorityQueue(),
            Priority.NORMAL: queue.PriorityQueue(),
            Priority.LOW: queue.PriorityQueue()
        }
        
        # Performance tracking
        self.tasks_completed = 0
        self.tasks_dropped = 0
        self.total_queue_time = 0
        self.total_execution_time = 0
        
        # Worker threads
        self.workers = []
        self.running = True
        
        self.start_workers()
    
    def start_workers(self):
        """Start worker threads for each stream"""
        for i in range(self.num_streams):
            worker = threading.Thread(target=self.worker_loop, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
    
    def submit_task(self, task: InferenceTask):
        """Submit inference task to appropriate priority queue"""
        
        # Check deadline if specified
        if task.deadline and time.time() > task.deadline:
            self.tasks_dropped += 1
            print(f"Task {task.task_id} dropped (deadline exceeded)")
            return
        
        # Add to priority queue
        # Use negative timestamp for FIFO within same priority
        priority_value = (task.priority.value, -task.submit_time)
        self.queues[task.priority].put((priority_value, task))
    
    def get_next_task(self) -> Optional[InferenceTask]:
        """Get next task respecting priority order"""
        
        # Check priorities in order: CRITICAL -> HIGH -> NORMAL -> LOW
        for priority in [Priority.CRITICAL, Priority.HIGH, Priority.NORMAL, Priority.LOW]:
            try:
                _, task = self.queues[priority].get_nowait()
                return task
            except queue.Empty:
                continue
        
        return None
    
    def acquire_stream(self) -> Optional[int]:
        """Acquire an available CUDA stream"""
        
        with self.stream_lock:
            for i, busy in enumerate(self.stream_busy):
                if not busy:
                    self.stream_busy[i] = True
                    return i
        
        return None
    
    def release_stream(self, stream_id: int):
        """Release CUDA stream"""
        
        with self.stream_lock:
            self.stream_busy[stream_id] = False
    
    def worker_loop(self, worker_id: int):
        """Worker thread processing tasks"""
        
        while self.running:
            # Get stream
            stream_id = None
            while stream_id is None and self.running:
                stream_id = self.acquire_stream()
                if stream_id is None:
                    time.sleep(0.001)
            
            if not self.running:
                break
            
            # Get next task
            task = self.get_next_task()
            
            if task is None:
                self.release_stream(stream_id)
                time.sleep(0.001)
                continue
            
            # Check deadline again
            if task.deadline and time.time() > task.deadline:
                self.tasks_dropped += 1
                self.release_stream(stream_id)
                print(f"Task {task.task_id} dropped at execution (deadline exceeded)")
                continue
            
            # Execute task
            queue_time = time.time() - task.submit_time
            self.total_queue_time += queue_time
            
            try:
                start_time = time.time()
                
                # Execute on assigned stream
                with torch.cuda.stream(self.streams[stream_id]):
                    result = self.execute_inference(task)
                
                execution_time = time.time() - start_time
                self.total_execution_time += execution_time
                self.tasks_completed += 1
                
                # Callback if provided
                if task.callback:
                    task.callback(result, queue_time, execution_time)
                
            except Exception as e:
                print(f"Task {task.task_id} failed: {e}")
            
            finally:
                self.release_stream(stream_id)
    
    def execute_inference(self, task: InferenceTask):
        """Execute inference (placeholder - integrate with scheduler)"""
        
        # Simulate inference
        time.sleep(0.02)  # 20ms
        
        return {'task_id': task.task_id, 'detections': []}
    
    def get_statistics(self):
        """Get resource pool statistics"""
        
        total_tasks = self.tasks_completed + self.tasks_dropped
        
        return {
            'tasks_completed': self.tasks_completed,
            'tasks_dropped': self.tasks_dropped,
            'drop_rate': self.tasks_dropped / total_tasks if total_tasks > 0 else 0,
            'avg_queue_time': self.total_queue_time / self.tasks_completed if self.tasks_completed > 0 else 0,
            'avg_execution_time': self.total_execution_time / self.tasks_completed if self.tasks_completed > 0 else 0,
            'queue_depths': {
                priority.name: self.queues[priority].qsize()
                for priority in Priority
            }
        }
    
    def shutdown(self):
        """Shutdown resource pool"""
        self.running = False
        for worker in self.workers:
            worker.join()

# Example usage
if __name__ == '__main__':
    pool = GPUResourcePool(num_streams=4)
    
    # Submit tasks with different priorities
    for i in range(100):
        priority = Priority.HIGH if i % 10 == 0 else Priority.NORMAL
        deadline = time.time() + 0.5  # 500ms deadline
        
        task = InferenceTask(
            task_id=f"task_{i}",
            model_name='yolov8n',
            input_data=None,
            priority=priority,
            submit_time=time.time(),
            deadline=deadline
        )
        
        pool.submit_task(task)
        time.sleep(0.01)  # 10ms between submissions
    
    # Wait for completion
    time.sleep(5)
    
    stats = pool.get_statistics()
    print(f"\nPool Statistics:")
    print(f"Completed: {stats['tasks_completed']}")
    print(f"Dropped: {stats['tasks_dropped']}")
    print(f"Drop rate: {stats['drop_rate']:.2%}")
    print(f"Avg queue time: {stats['avg_queue_time']*1000:.2f}ms")
    print(f"Avg execution time: {stats['avg_execution_time']*1000:.2f}ms")
    
    pool.shutdown()

Adaptive Batching for Throughput Optimization

Batching multiple inference requests reduces per-request overhead and increases GPU utilization. Adaptive batching balances throughput gains against latency constraints.

Adaptive Batch Scheduler (Python):

#!/usr/bin/env python3
"""
Adaptive batching for optimized throughput
"""

import torch
import threading
import time
import numpy as np
from collections import deque
from dataclasses import dataclass

@dataclass
class BatchConfig:
    min_batch_size: int = 1
    max_batch_size: int = 8
    max_wait_time: float = 0.01  # 10ms
    target_latency: float = 0.05  # 50ms

class AdaptiveBatchScheduler:
    """
    Adaptive batching with dynamic batch size adjustment
    """
    
    def __init__(self, config: BatchConfig):
        self.config = config
        self.pending_requests = deque()
        self.lock = threading.Lock()
        self.condition = threading.Condition(self.lock)
        
        # Performance tracking for adaptation
        self.latency_history = deque(maxlen=100)
        self.batch_size_history = deque(maxlen=100)
        self.throughput_history = deque(maxlen=100)
        
        # Current adaptive parameters
        self.current_batch_size = config.min_batch_size
        self.current_wait_time = config.max_wait_time
        
        # Batch processing thread
        self.running = True
        self.processor_thread = threading.Thread(target=self.batch_processor)
        self.processor_thread.daemon = True
        self.processor_thread.start()
    
    def submit_request(self, input_data, callback):
        """Submit inference request"""
        
        request = {
            'input': input_data,
            'callback': callback,
            'submit_time': time.time()
        }
        
        with self.condition:
            self.pending_requests.append(request)
            self.condition.notify()
    
    def get_batch(self):
        """Collect requests into batch"""
        
        with self.condition:
            # Wait for requests or timeout
            deadline = time.time() + self.current_wait_time
            
            while len(self.pending_requests) < self.current_batch_size:
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                
                self.condition.wait(timeout=remaining)
            
            # Collect batch
            batch_size = min(len(self.pending_requests), self.current_batch_size)
            
            if batch_size == 0:
                return None
            
            batch = []
            for _ in range(batch_size):
                batch.append(self.pending_requests.popleft())
            
            return batch
    
    def batch_processor(self):
        """Background batch processing thread"""
        
        while self.running:
            batch = self.get_batch()
            
            if batch is None:
                time.sleep(0.001)
                continue
            
            # Process batch
            batch_start = time.time()
            results = self.execute_batch(batch)
            batch_duration = time.time() - batch_start
            
            # Calculate metrics
            avg_latency = np.mean([time.time() - req['submit_time'] for req in batch])
            throughput = len(batch) / batch_duration
            
            # Record history
            self.latency_history.append(avg_latency)
            self.batch_size_history.append(len(batch))
            self.throughput_history.append(throughput)
            
            # Adapt batch parameters
            self.adapt_batch_parameters(avg_latency, throughput)
            
            # Invoke callbacks
            for req, result in zip(batch, results):
                if req['callback']:
                    latency = time.time() - req['submit_time']
                    req['callback'](result, latency)
    
    def execute_batch(self, batch):
        """Execute batched inference"""
        
        # Stack inputs
        inputs = np.stack([req['input'] for req in batch])
        inputs_tensor = torch.from_numpy(inputs).cuda()
        
        # Simulated batched inference
        time.sleep(0.015 * (1 + len(batch) * 0.1))  # Simulated batch processing
        
        # Return per-request results
        results = [{'detections': []} for _ in batch]
        return results
    
    def adapt_batch_parameters(self, current_latency, current_throughput):
        """Adapt batch size and wait time based on performance"""
        
        if len(self.latency_history) < 10:
            return
        
        avg_latency = np.mean(list(self.latency_history)[-10:])
        avg_throughput = np.mean(list(self.throughput_history)[-10:])
        
        # If latency too high, reduce batch size
        if avg_latency > self.config.target_latency * 1.2:
            self.current_batch_size = max(
                self.config.min_batch_size,
                self.current_batch_size - 1
            )
            self.current_wait_time = max(
                self.config.max_wait_time * 0.5,
                self.current_wait_time * 0.9
            )
        
        # If latency acceptable and queue building, increase batch size
        elif avg_latency < self.config.target_latency * 0.8 and len(self.pending_requests) > 5:
            self.current_batch_size = min(
                self.config.max_batch_size,
                self.current_batch_size + 1
            )
            self.current_wait_time = min(
                self.config.max_wait_time,
                self.current_wait_time * 1.1
            )
    
    def get_statistics(self):
        """Get batching statistics"""
        
        if not self.latency_history:
            return {}
        
        return {
            'current_batch_size': self.current_batch_size,
            'current_wait_time': self.current_wait_time * 1000,  # ms
            'avg_latency': np.mean(list(self.latency_history)) * 1000,  # ms
            'p95_latency': np.percentile(list(self.latency_history), 95) * 1000,  # ms
            'avg_throughput': np.mean(list(self.throughput_history)),  # req/s
            'avg_batch_size': np.mean(list(self.batch_size_history))
        }
    
    def shutdown(self):
        """Shutdown scheduler"""
        self.running = False
        with self.condition:
            self.condition.notify()
        self.processor_thread.join()

# Example usage
if __name__ == '__main__':
    config = BatchConfig(
        min_batch_size=1,
        max_batch_size=8,
        max_wait_time=0.01,
        target_latency=0.05
    )
    
    scheduler = AdaptiveBatchScheduler(config)
    
    completed = []
    def callback(result, latency):
        completed.append(latency)
    
    # Simulate varying load
    for i in range(200):
        input_data = np.random.randn(1, 3, 640, 640).astype(np.float32)
        scheduler.submit_request(input_data, callback)
        
        # Variable request rate
        if i < 50:
            time.sleep(0.02)  # Low load
        elif i < 100:
            time.sleep(0.005)  # Medium load
        elif i < 150:
            time.sleep(0.002)  # High load
        else:
            time.sleep(0.01)  # Medium load
        
        if i % 50 == 0 and i > 0:
            stats = scheduler.get_statistics()
            print(f"\nIteration {i}:")
            print(f"  Batch size: {stats['current_batch_size']}")
            print(f"  Avg latency: {stats['avg_latency']:.2f}ms")
            print(f"  P95 latency: {stats['p95_latency']:.2f}ms")
            print(f"  Throughput: {stats['avg_throughput']:.1f} req/s")
    
    time.sleep(2)
    scheduler.shutdown()
    
    print(f"\nCompleted {len(completed)} requests")
    print(f"Average latency: {np.mean(completed)*1000:.2f}ms")

KV Cache Management for Transformer Models

Transformer-based models like vision transformers or multimodal models require KV cache management for efficient sequence processing on edge devices.

KV Cache Manager (Python):

#!/usr/bin/env python3
"""
KV cache management for transformer models on edge
"""

import torch
import threading
from collections import OrderedDict
from dataclasses import dataclass
import time

@dataclass
class CacheEntry:
    key_cache: torch.Tensor
    value_cache: torch.Tensor
    sequence_length: int
    last_accessed: float
    access_count: int

class KVCacheManager:
    """
    Manages KV caches for transformer models with memory-efficient strategies
    """
    
    def __init__(self, max_cache_memory_mb=512, num_layers=12, 
                 hidden_size=768, num_heads=12):
        self.max_cache_memory = max_cache_memory_mb * 1024 * 1024
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        # Cache storage
        self.caches = OrderedDict()
        self.cache_lock = threading.Lock()
        
        # Memory tracking
        self.current_memory_usage = 0
        
        # Statistics
        self.cache_hits = 0
        self.cache_misses = 0
        self.evictions = 0
    
    def estimate_cache_size(self, sequence_length, batch_size=1):
        """Estimate memory required for KV cache"""
        
        # Each layer stores K and V caches
        # Shape: [batch_size, num_heads, sequence_length, head_dim]
        bytes_per_element = 2  # FP16
        
        k_cache_size = (batch_size * self.num_heads * sequence_length * 
                       self.head_dim * bytes_per_element)
        v_cache_size = k_cache_size
        
        total_size = (k_cache_size + v_cache_size) * self.num_layers
        
        return total_size
    
    def allocate_cache(self, session_id: str, max_sequence_length: int, 
                      batch_size: int = 1):
        """Allocate KV cache for a session"""
        
        cache_size = self.estimate_cache_size(max_sequence_length, batch_size)
        
        # Make space if needed
        while self.current_memory_usage + cache_size > self.max_cache_memory:
            if len(self.caches) == 0:
                raise RuntimeError("Cannot allocate cache: insufficient memory")
            
            evict_id = self.select_eviction_candidate()
            self.evict_cache(evict_id)
        
        with self.cache_lock:
            # Allocate cache tensors
            k_caches = []
            v_caches = []
            
            for _ in range(self.num_layers):
                k_cache = torch.zeros(
                    batch_size, self.num_heads, max_sequence_length, self.head_dim,
                    dtype=torch.float16,
                    device='cuda'
                )
                v_cache = torch.zeros(
                    batch_size, self.num_heads, max_sequence_length, self.head_dim,
                    dtype=torch.float16,
                    device='cuda'
                )
                
                k_caches.append(k_cache)
                v_caches.append(v_cache)
            
            entry = CacheEntry(
                key_cache=k_caches,
                value_cache=v_caches,
                sequence_length=0,
                last_accessed=time.time(),
                access_count=0
            )
            
            self.caches[session_id] = entry
            self.current_memory_usage += cache_size
            
            print(f"Allocated cache for {session_id}: {cache_size/1024/1024:.1f}MB")
    
    def get_cache(self, session_id: str):
        """Retrieve cache for session"""
        
        with self.cache_lock:
            if session_id not in self.caches:
                self.cache_misses += 1
                return None
            
            self.cache_hits += 1
            entry = self.caches[session_id]
            entry.last_accessed = time.time()
            entry.access_count += 1
            
            return entry
    
    def update_cache(self, session_id: str, new_keys, new_values, position: int):
        """Update cache with new key-value pairs"""
        
        entry = self.get_cache(session_id)
        if entry is None:
            raise ValueError(f"No cache found for session {session_id}")
        
        # Update caches for each layer
        for layer_idx, (k, v) in enumerate(zip(new_keys, new_values)):
            entry.key_cache[layer_idx][:, :, position:position+k.shape[2], :] = k
            entry.value_cache[layer_idx][:, :, position:position+v.shape[2], :] = v
        
        entry.sequence_length = max(entry.sequence_length, position + new_keys[0].shape[2])
    
    def select_eviction_candidate(self):
        """Select cache to evict using LRU"""
        
        if not self.caches:
            return None
        
        # Find least recently used
        oldest_id = None
        oldest_time = float('inf')
        
        for session_id, entry in self.caches.items():
            if entry.last_accessed < oldest_time:
                oldest_time = entry.last_accessed
                oldest_id = session_id
        
        return oldest_id
    
    def evict_cache(self, session_id: str):
        """Evict cache for session"""
        
        if session_id not in self.caches:
            return
        
        with self.cache_lock:
            entry = self.caches[session_id]
            
            # Free GPU memory
            for k_cache, v_cache in zip(entry.key_cache, entry.value_cache):
                del k_cache
                del v_cache
            
            cache_size = self.estimate_cache_size(entry.sequence_length)
            self.current_memory_usage -= cache_size
            
            del self.caches[session_id]
            self.evictions += 1
            
            print(f"Evicted cache for {session_id}")
    
    def compress_cache(self, session_id: str, compression_ratio: float = 0.5):
        """Compress cache by removing less important positions"""
        
        entry = self.get_cache(session_id)
        if entry is None:
            return
        
        # Simple compression: keep most recent tokens
        keep_length = int(entry.sequence_length * compression_ratio)
        start_position = entry.sequence_length - keep_length
        
        with self.cache_lock:
            for layer_idx in range(self.num_layers):
                # Shift cache to beginning
                entry.key_cache[layer_idx][:, :, :keep_length, :] = \
                    entry.key_cache[layer_idx][:, :, start_position:, :]
                entry.value_cache[layer_idx][:, :, :keep_length, :] = \
                    entry.value_cache[layer_idx][:, :, start_position:, :]
                
                # Zero out rest
                entry.key_cache[layer_idx][:, :, keep_length:, :] = 0
                entry.value_cache[layer_idx][:, :, keep_length:, :] = 0
            
            entry.sequence_length = keep_length
            
            print(f"Compressed cache for {session_id} to {keep_length} tokens")
    
    def get_statistics(self):
        """Get cache manager statistics"""
        
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
        
        return {
            'active_caches': len(self.caches),
            'memory_usage_mb': self.current_memory_usage / 1024 / 1024,
            'max_memory_mb': self.max_cache_memory / 1024 / 1024,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': hit_rate,
            'evictions': self.evictions
        }

# Example usage
if __name__ == '__main__':
    manager = KVCacheManager(
        max_cache_memory_mb=512,
        num_layers=12,
        hidden_size=768,
        num_heads=12
    )
    
    # Allocate caches for multiple sessions
    for i in range(5):
        session_id = f"session_{i}"
        manager.allocate_cache(session_id, max_sequence_length=512)
    
    # Simulate cache usage
    for iteration in range(100):
        session_id = f"session_{iteration % 5}"
        entry = manager.get_cache(session_id)
        
        if entry:
            # Simulate adding new tokens
            new_position = entry.sequence_length
            if new_position < 512:
                # Dummy key/value tensors
                new_keys = [torch.randn(1, 12, 1, 64, dtype=torch.float16, device='cuda') 
                           for _ in range(12)]
                new_values = [torch.randn(1, 12, 1, 64, dtype=torch.float16, device='cuda') 
                             for _ in range(12)]
                
                manager.update_cache(session_id, new_keys, new_values, new_position)
        
        if iteration % 20 == 0:
            stats = manager.get_statistics()
            print(f"\nIteration {iteration}:")
            print(f"  Active caches: {stats['active_caches']}")
            print(f"  Memory usage: {stats['memory_usage_mb']:.1f}MB")
            print(f"  Hit rate: {stats['hit_rate']:.2%}")

SLA Enforcement and Performance Guarantees

Production edge deployments require service level agreement enforcement ensuring consistent performance under varying load conditions.

SLA Enforcement Manager (Python):

#!/usr/bin/env python3
"""
SLA enforcement for edge inference services
"""

import time
import threading
from collections import deque
from dataclasses import dataclass
from enum import Enum
import numpy as np

class SLAViolationType(Enum):
    LATENCY = "latency"
    THROUGHPUT = "throughput"
    AVAILABILITY = "availability"
    ERROR_RATE = "error_rate"

@dataclass
class SLAPolicy:
    max_latency_p95: float = 0.05  # 50ms
    max_latency_p99: float = 0.1   # 100ms
    min_throughput: float = 20.0   # req/s
    max_error_rate: float = 0.01   # 1%
    min_availability: float = 0.999  # 99.9%

class SLAEnforcementManager:
    """
    Monitors and enforces SLA policies for inference service
    """
    
    def __init__(self, policy: SLAPolicy, window_size=60):
        self.policy = policy
        self.window_size = window_size
        
        # Metrics collection
        self.latencies = deque(maxlen=1000)
        self.errors = deque(maxlen=1000)
        self.requests = deque(maxlen=1000)
        self.timestamps = deque(maxlen=1000)
        
        # SLA status
        self.violations = []
        self.lock = threading.Lock()
        
        # Monitoring thread
        self.running = True
        self.monitor_thread = threading.Thread(target=self.monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
    
    def record_request(self, latency: float, success: bool):
        """Record request metrics"""
        
        with self.lock:
            self.latencies.append(latency)
            self.errors.append(0 if success else 1)
            self.requests.append(1)
            self.timestamps.append(time.time())
    
    def get_window_metrics(self):
        """Calculate metrics for current time window"""
        
        current_time = time.time()
        window_start = current_time - self.window_size
        
        # Filter metrics to window
        window_latencies = []
        window_errors = []
        window_requests = 0
        
        for timestamp, latency, error in zip(self.timestamps, self.latencies, self.errors):
            if timestamp >= window_start:
                window_latencies.append(latency)
                window_errors.append(error)
                window_requests += 1
        
        if not window_latencies:
            return None
        
        metrics = {
            'latency_p95': np.percentile(window_latencies, 95),
            'latency_p99': np.percentile(window_latencies, 99),
            'latency_mean': np.mean(window_latencies),
            'throughput': window_requests / self.window_size,
            'error_rate': np.mean(window_errors),
            'availability': 1.0 - np.mean(window_errors),
            'total_requests': window_requests
        }
        
        return metrics
    
    def check_sla_compliance(self):
        """Check if current metrics comply with SLA"""
        
        metrics = self.get_window_metrics()
        
        if metrics is None:
            return True, []
        
        violations = []
        
        # Check latency SLAs
        if metrics['latency_p95'] > self.policy.max_latency_p95:
            violations.append({
                'type': SLAViolationType.LATENCY,
                'metric': 'p95',
                'value': metrics['latency_p95'],
                'threshold': self.policy.max_latency_p95,
                'severity': 'medium'
            })
        
        if metrics['latency_p99'] > self.policy.max_latency_p99:
            violations.append({
                'type': SLAViolationType.LATENCY,
                'metric': 'p99',
                'value': metrics['latency_p99'],
                'threshold': self.policy.max_latency_p99,
                'severity': 'high'
            })
        
        # Check throughput SLA
        if metrics['throughput'] < self.policy.min_throughput:
            violations.append({
                'type': SLAViolationType.THROUGHPUT,
                'metric': 'throughput',
                'value': metrics['throughput'],
                'threshold': self.policy.min_throughput,
                'severity': 'medium'
            })
        
        # Check error rate SLA
        if metrics['error_rate'] > self.policy.max_error_rate:
            violations.append({
                'type': SLAViolationType.ERROR_RATE,
                'metric': 'error_rate',
                'value': metrics['error_rate'],
                'threshold': self.policy.max_error_rate,
                'severity': 'high'
            })
        
        # Check availability SLA
        if metrics['availability'] < self.policy.min_availability:
            violations.append({
                'type': SLAViolationType.AVAILABILITY,
                'metric': 'availability',
                'value': metrics['availability'],
                'threshold': self.policy.min_availability,
                'severity': 'critical'
            })
        
        return len(violations) == 0, violations
    
    def apply_mitigation(self, violations):
        """Apply mitigation strategies for SLA violations"""
        
        for violation in violations:
            if violation['type'] == SLAViolationType.LATENCY:
                # Reduce batch size, increase worker threads
                print(f"Mitigating latency violation: "
                      f"{violation['metric']}={violation['value']*1000:.1f}ms "
                      f"(threshold={violation['threshold']*1000:.1f}ms)")
                # Implementation would adjust scheduler parameters
            
            elif violation['type'] == SLAViolationType.THROUGHPUT:
                # Increase parallelism, enable batching
                print(f"Mitigating throughput violation: "
                      f"{violation['value']:.1f} req/s "
                      f"(threshold={violation['threshold']:.1f} req/s)")
            
            elif violation['type'] == SLAViolationType.ERROR_RATE:
                # Enable retries, circuit breaker
                print(f"Mitigating error rate violation: "
                      f"{violation['value']*100:.2f}% "
                      f"(threshold={violation['threshold']*100:.2f}%)")
            
            elif violation['type'] == SLAViolationType.AVAILABILITY:
                # Failover, redundancy
                print(f"CRITICAL: Availability violation: "
                      f"{violation['value']*100:.3f}% "
                      f"(threshold={violation['threshold']*100:.3f}%)")
        
        # Record violations
        with self.lock:
            self.violations.extend(violations)
    
    def monitor_loop(self):
        """Background monitoring thread"""
        
        while self.running:
            time.sleep(5)  # Check every 5 seconds
            
            compliant, violations = self.check_sla_compliance()
            
            if not compliant:
                self.apply_mitigation(violations)
            
            # Print status
            metrics = self.get_window_metrics()
            if metrics:
                print(f"\nSLA Status: {'COMPLIANT' if compliant else 'VIOLATED'}")
                print(f"  P95 Latency: {metrics['latency_p95']*1000:.1f}ms")
                print(f"  Throughput: {metrics['throughput']:.1f} req/s")
                print(f"  Error Rate: {metrics['error_rate']*100:.2f}%")
    
    def get_sla_report(self):
        """Generate SLA compliance report"""
        
        metrics = self.get_window_metrics()
        compliant, current_violations = self.check_sla_compliance()
        
        return {
            'compliant': compliant,
            'current_metrics': metrics,
            'current_violations': current_violations,
            'total_violations': len(self.violations),
            'violation_history': self.violations[-10:]  # Last 10
        }
    
    def shutdown(self):
        """Shutdown monitor"""
        self.running = False
        self.monitor_thread.join()

# Example usage
if __name__ == '__main__':
    policy = SLAPolicy(
        max_latency_p95=0.05,
        max_latency_p99=0.1,
        min_throughput=20.0,
        max_error_rate=0.01
    )
    
    manager = SLAEnforcementManager(policy, window_size=30)
    
    # Simulate requests with varying performance
    for i in range(200):
        # Simulate degrading performance
        if i < 50:
            latency = np.random.normal(0.02, 0.005)  # Good
            success = np.random.random() > 0.001
        elif i < 100:
            latency = np.random.normal(0.04, 0.01)  # Degrading
            success = np.random.random() > 0.005
        elif i < 150:
            latency = np.random.normal(0.08, 0.02)  # Violating
            success = np.random.random() > 0.02
        else:
            latency = np.random.normal(0.03, 0.008)  # Recovering
            success = np.random.random() > 0.002
        
        manager.record_request(latency, success)
        time.sleep(0.1)
    
    # Final report
    report = manager.get_sla_report()
    print(f"\n{'='*60}")
    print(f"Final SLA Report:")
    print(f"Compliant: {report['compliant']}")
    print(f"Total Violations: {report['total_violations']}")
    
    manager.shutdown()

Key Takeaways

Advanced optimization patterns enable production edge AI deployments to maximize resource utilization while maintaining quality of service. Memory-aware scheduling coordinates multiple concurrent models through intelligent caching and eviction policies preventing GPU memory exhaustion. Implementation typically achieves 80-90% cache hit rates reducing model load overhead by 50-70%.

GPU resource pooling with priority-based scheduling prevents contention-induced latency spikes. CUDA stream management enables parallel execution of independent requests achieving 2-3x throughput improvement over sequential processing. Adaptive batching dynamically adjusts batch sizes balancing throughput gains against latency constraints, typically improving throughput by 40-60% while maintaining sub-50ms P95 latency.

KV cache management for transformer models reduces memory footprint through intelligent compression and eviction strategies. Proper cache management enables 50-100 concurrent transformer sessions on 8GB edge devices. SLA enforcement mechanisms monitor performance metrics in real-time applying automatic mitigation strategies maintaining service quality under varying load conditions.

Combined optimization strategies demonstrate 50-70% end-to-end latency reduction compared to naive implementations while improving throughput by 2-4x. Production deployments require careful tuning of optimization parameters based on specific workload characteristics, hardware capabilities, and service level requirements. Comprehensive monitoring and adaptive strategies ensure consistent performance as deployment conditions evolve.

Part 6 concludes the series with production operations covering Prometheus/Jaeger monitoring integration, data drift detection mechanisms, model versioning strategies, canary deployment patterns, OTA update procedures, comprehensive health checking, feedback loop implementation, and orchestration patterns for managing 100+ distributed edge devices.

References

Written by:

534 Posts

View All Posts
Follow Me :
How to whitelist website on AdBlocker?

How to whitelist website on AdBlocker?

  1. 1 Click on the AdBlock Plus icon on the top right corner of your browser
  2. 2 Click on "Enabled on this site" from the AdBlock Plus option
  3. 3 Refresh the page and start browsing the site