Edge devices face stringent resource constraints requiring sophisticated optimization strategies beyond basic inference implementation. This post explores advanced optimization patterns for maximizing edge AI performance including memory-aware scheduling coordinating multiple concurrent models, GPU resource pooling minimizing contention, KV cache management for transformer-based models, adaptive batching dynamically grouping requests, thermal-aware resource allocation preventing throttling, and SLA enforcement ensuring quality of service. Production implementations demonstrate 50-70% latency reduction through intelligent resource coordination while maintaining accuracy and reliability.
Part 4 covered building multi-language inference servers with basic request handling. This post advances to production-scale optimization: understanding memory hierarchies and allocation strategies, implementing multi-model coordination without interference, designing adaptive batching systems, managing transformer KV caches efficiently, enforcing service level agreements under resource constraints, and measuring optimization effectiveness through comprehensive benchmarking.
Memory-Aware Scheduling for Multi-Model Deployments
Edge devices frequently serve multiple models simultaneously for different detection tasks, classification requirements, or customer applications. Naive concurrent execution causes GPU memory exhaustion and severe performance degradation.
flowchart TD
A[Incoming Requests] --> B[Request Classifier]
B --> C{Model Type}
C -->|Detection| D[Detection Queue]
C -->|Classification| E[Classification Queue]
C -->|Segmentation| F[Segmentation Queue]
D --> G[Memory Allocator]
E --> G
F --> G
G --> H{Available Memory?}
H -->|Yes| I[Load Model]
H -->|No| J[Eviction Policy]
J --> K[Unload Idle Model]
K --> I
I --> L[Execute Inference]
L --> M[Release Memory]
M --> N[Return Result]
O[Memory Monitor] --> G
P[Thermal Monitor] --> GMulti-Model Memory Scheduler (Python):
#!/usr/bin/env python3
"""
Memory-aware multi-model scheduler for edge inference
"""
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Optional
import psutil
@dataclass
class ModelMetadata:
name: str
engine_path: str
memory_required: int # bytes
priority: int
last_used: float
load_count: int
inference_count: int
class MemoryAwareScheduler:
"""
Multi-model scheduler with intelligent memory management
"""
def __init__(self, max_gpu_memory_mb=4096, cache_size=3):
self.max_gpu_memory = max_gpu_memory_mb * 1024 * 1024
self.cache_size = cache_size
self.loaded_models = OrderedDict()
self.model_metadata: Dict[str, ModelMetadata] = {}
self.lock = threading.Lock()
# Memory tracking
self.current_memory_usage = 0
# Performance metrics
self.cache_hits = 0
self.cache_misses = 0
self.evictions = 0
def register_model(self, name: str, engine_path: str,
priority: int = 1, estimated_memory_mb: int = 100):
"""Register a model for scheduling"""
metadata = ModelMetadata(
name=name,
engine_path=engine_path,
memory_required=estimated_memory_mb * 1024 * 1024,
priority=priority,
last_used=0,
load_count=0,
inference_count=0
)
self.model_metadata[name] = metadata
print(f"Registered model: {name} ({estimated_memory_mb}MB, priority={priority})")
def get_available_memory(self):
"""Get available GPU memory"""
torch.cuda.synchronize()
free_memory, total_memory = torch.cuda.mem_get_info()
return free_memory
def estimate_model_memory(self, engine_path: str):
"""Estimate memory requirements for a model"""
import os
# Engine file size is rough approximation
engine_size = os.path.getsize(engine_path)
# Add overhead for activation buffers (2-3x engine size)
estimated_memory = engine_size * 2.5
return int(estimated_memory)
def select_eviction_candidate(self) -> Optional[str]:
"""Select model to evict using weighted LRU"""
if not self.loaded_models:
return None
# Score models by: recency, priority, usage frequency
scores = {}
current_time = time.time()
for name, model_data in self.loaded_models.items():
metadata = self.model_metadata[name]
# Time since last use (seconds)
idle_time = current_time - metadata.last_used
# Weighted score (lower is better for eviction)
# High priority and recent use = low score (keep)
# Low priority and old use = high score (evict)
score = idle_time / (metadata.priority * (metadata.inference_count + 1))
scores[name] = score
# Return model with highest eviction score
return max(scores.items(), key=lambda x: x[1])[0]
def evict_model(self, name: str):
"""Evict a model from GPU memory"""
if name not in self.loaded_models:
return
with self.lock:
model_data = self.loaded_models[name]
metadata = self.model_metadata[name]
# Cleanup TensorRT resources
if 'context' in model_data:
del model_data['context']
if 'engine' in model_data:
del model_data['engine']
if 'runtime' in model_data:
del model_data['runtime']
# Free device buffers
if 'buffers' in model_data:
for buffer in model_data['buffers']:
if 'device' in buffer:
buffer['device'].free()
# Update memory tracking
self.current_memory_usage -= metadata.memory_required
# Remove from cache
del self.loaded_models[name]
self.evictions += 1
print(f"Evicted model: {name} (freed ~{metadata.memory_required/1024/1024:.1f}MB)")
def load_model(self, name: str):
"""Load model into GPU memory"""
metadata = self.model_metadata[name]
# Check if already loaded
if name in self.loaded_models:
self.cache_hits += 1
metadata.last_used = time.time()
metadata.inference_count += 1
return self.loaded_models[name]
self.cache_misses += 1
# Make space if needed
while self.current_memory_usage + metadata.memory_required > self.max_gpu_memory:
if len(self.loaded_models) == 0:
raise RuntimeError(f"Model {name} requires more memory than available")
evict_candidate = self.select_eviction_candidate()
self.evict_model(evict_candidate)
# Enforce cache size limit
while len(self.loaded_models) >= self.cache_size:
evict_candidate = self.select_eviction_candidate()
self.evict_model(evict_candidate)
with self.lock:
print(f"Loading model: {name}")
# Load TensorRT engine
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(metadata.engine_path, 'rb') as f:
engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
# Allocate buffers
buffers = []
bindings = []
for binding in engine:
size = trt.volume(engine.get_binding_shape(binding))
dtype = trt.nptype(engine.get_binding_dtype(binding))
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
bindings.append(int(device_mem))
buffer_info = {
'host': host_mem,
'device': device_mem,
'size': size,
'dtype': dtype
}
if engine.binding_is_input(binding):
buffer_info['type'] = 'input'
else:
buffer_info['type'] = 'output'
buffers.append(buffer_info)
# Create CUDA stream
stream = cuda.Stream()
model_data = {
'runtime': runtime,
'engine': engine,
'context': context,
'buffers': buffers,
'bindings': bindings,
'stream': stream
}
# Update cache and metadata
self.loaded_models[name] = model_data
self.current_memory_usage += metadata.memory_required
metadata.last_used = time.time()
metadata.load_count += 1
metadata.inference_count += 1
print(f"Model loaded: {name} (memory: {self.current_memory_usage/1024/1024:.1f}MB / "
f"{self.max_gpu_memory/1024/1024:.1f}MB)")
return model_data
def infer(self, model_name: str, input_data):
"""Execute inference with automatic model loading"""
model_data = self.load_model(model_name)
# Prepare input
input_buffer = model_data['buffers'][0]
input_data_flat = input_data.ravel()
# Copy to device
cuda.memcpy_htod_async(
input_buffer['device'],
input_data_flat,
model_data['stream']
)
# Execute
model_data['context'].execute_async_v2(
bindings=model_data['bindings'],
stream_handle=model_data['stream'].handle
)
# Copy output
output_buffer = model_data['buffers'][1]
cuda.memcpy_dtoh_async(
output_buffer['host'],
output_buffer['device'],
model_data['stream']
)
model_data['stream'].synchronize()
# Update metadata
self.model_metadata[model_name].last_used = time.time()
return output_buffer['host']
def get_statistics(self):
"""Get scheduler performance statistics"""
total_requests = self.cache_hits + self.cache_misses
hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
return {
'loaded_models': len(self.loaded_models),
'memory_usage_mb': self.current_memory_usage / 1024 / 1024,
'max_memory_mb': self.max_gpu_memory / 1024 / 1024,
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate': hit_rate,
'evictions': self.evictions,
'model_stats': {
name: {
'loads': meta.load_count,
'inferences': meta.inference_count,
'last_used': time.time() - meta.last_used
}
for name, meta in self.model_metadata.items()
}
}
# Example usage
if __name__ == '__main__':
scheduler = MemoryAwareScheduler(max_gpu_memory_mb=4096, cache_size=3)
# Register models
scheduler.register_model('yolov8n', 'yolov8n_int8.engine', priority=3, estimated_memory_mb=80)
scheduler.register_model('yolov8s', 'yolov8s_int8.engine', priority=2, estimated_memory_mb=140)
scheduler.register_model('resnet50', 'resnet50_int8.engine', priority=1, estimated_memory_mb=120)
# Simulate inference requests
for i in range(100):
model_name = ['yolov8n', 'yolov8s', 'resnet50'][i % 3]
dummy_input = torch.randn(1, 3, 640, 640).cuda().cpu().numpy()
result = scheduler.infer(model_name, dummy_input)
if i % 20 == 0:
stats = scheduler.get_statistics()
print(f"\nIteration {i}: Hit rate = {stats['hit_rate']:.2%}, "
f"Evictions = {stats['evictions']}")
GPU Resource Pooling and Contention Management
Multiple concurrent inference requests compete for limited GPU compute resources. Resource pooling with intelligent scheduling prevents contention-induced latency spikes.
GPU Resource Pool Manager (Python):
#!/usr/bin/env python3
"""
GPU resource pool with contention management
"""
import torch
import threading
import queue
import time
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum
class Priority(Enum):
LOW = 0
NORMAL = 1
HIGH = 2
CRITICAL = 3
@dataclass
class InferenceTask:
task_id: str
model_name: str
input_data: any
priority: Priority
submit_time: float
deadline: Optional[float] = None
callback: Optional[callable] = None
class GPUResourcePool:
"""
GPU resource pool with priority scheduling and contention management
"""
def __init__(self, num_streams=4, enable_cuda_graphs=True):
self.num_streams = num_streams
self.enable_cuda_graphs = enable_cuda_graphs
# Create CUDA streams for parallel execution
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
self.stream_busy = [False] * num_streams
self.stream_lock = threading.Lock()
# Priority queues for different request types
self.queues = {
Priority.CRITICAL: queue.PriorityQueue(),
Priority.HIGH: queue.PriorityQueue(),
Priority.NORMAL: queue.PriorityQueue(),
Priority.LOW: queue.PriorityQueue()
}
# Performance tracking
self.tasks_completed = 0
self.tasks_dropped = 0
self.total_queue_time = 0
self.total_execution_time = 0
# Worker threads
self.workers = []
self.running = True
self.start_workers()
def start_workers(self):
"""Start worker threads for each stream"""
for i in range(self.num_streams):
worker = threading.Thread(target=self.worker_loop, args=(i,))
worker.daemon = True
worker.start()
self.workers.append(worker)
def submit_task(self, task: InferenceTask):
"""Submit inference task to appropriate priority queue"""
# Check deadline if specified
if task.deadline and time.time() > task.deadline:
self.tasks_dropped += 1
print(f"Task {task.task_id} dropped (deadline exceeded)")
return
# Add to priority queue
# Use negative timestamp for FIFO within same priority
priority_value = (task.priority.value, -task.submit_time)
self.queues[task.priority].put((priority_value, task))
def get_next_task(self) -> Optional[InferenceTask]:
"""Get next task respecting priority order"""
# Check priorities in order: CRITICAL -> HIGH -> NORMAL -> LOW
for priority in [Priority.CRITICAL, Priority.HIGH, Priority.NORMAL, Priority.LOW]:
try:
_, task = self.queues[priority].get_nowait()
return task
except queue.Empty:
continue
return None
def acquire_stream(self) -> Optional[int]:
"""Acquire an available CUDA stream"""
with self.stream_lock:
for i, busy in enumerate(self.stream_busy):
if not busy:
self.stream_busy[i] = True
return i
return None
def release_stream(self, stream_id: int):
"""Release CUDA stream"""
with self.stream_lock:
self.stream_busy[stream_id] = False
def worker_loop(self, worker_id: int):
"""Worker thread processing tasks"""
while self.running:
# Get stream
stream_id = None
while stream_id is None and self.running:
stream_id = self.acquire_stream()
if stream_id is None:
time.sleep(0.001)
if not self.running:
break
# Get next task
task = self.get_next_task()
if task is None:
self.release_stream(stream_id)
time.sleep(0.001)
continue
# Check deadline again
if task.deadline and time.time() > task.deadline:
self.tasks_dropped += 1
self.release_stream(stream_id)
print(f"Task {task.task_id} dropped at execution (deadline exceeded)")
continue
# Execute task
queue_time = time.time() - task.submit_time
self.total_queue_time += queue_time
try:
start_time = time.time()
# Execute on assigned stream
with torch.cuda.stream(self.streams[stream_id]):
result = self.execute_inference(task)
execution_time = time.time() - start_time
self.total_execution_time += execution_time
self.tasks_completed += 1
# Callback if provided
if task.callback:
task.callback(result, queue_time, execution_time)
except Exception as e:
print(f"Task {task.task_id} failed: {e}")
finally:
self.release_stream(stream_id)
def execute_inference(self, task: InferenceTask):
"""Execute inference (placeholder - integrate with scheduler)"""
# Simulate inference
time.sleep(0.02) # 20ms
return {'task_id': task.task_id, 'detections': []}
def get_statistics(self):
"""Get resource pool statistics"""
total_tasks = self.tasks_completed + self.tasks_dropped
return {
'tasks_completed': self.tasks_completed,
'tasks_dropped': self.tasks_dropped,
'drop_rate': self.tasks_dropped / total_tasks if total_tasks > 0 else 0,
'avg_queue_time': self.total_queue_time / self.tasks_completed if self.tasks_completed > 0 else 0,
'avg_execution_time': self.total_execution_time / self.tasks_completed if self.tasks_completed > 0 else 0,
'queue_depths': {
priority.name: self.queues[priority].qsize()
for priority in Priority
}
}
def shutdown(self):
"""Shutdown resource pool"""
self.running = False
for worker in self.workers:
worker.join()
# Example usage
if __name__ == '__main__':
pool = GPUResourcePool(num_streams=4)
# Submit tasks with different priorities
for i in range(100):
priority = Priority.HIGH if i % 10 == 0 else Priority.NORMAL
deadline = time.time() + 0.5 # 500ms deadline
task = InferenceTask(
task_id=f"task_{i}",
model_name='yolov8n',
input_data=None,
priority=priority,
submit_time=time.time(),
deadline=deadline
)
pool.submit_task(task)
time.sleep(0.01) # 10ms between submissions
# Wait for completion
time.sleep(5)
stats = pool.get_statistics()
print(f"\nPool Statistics:")
print(f"Completed: {stats['tasks_completed']}")
print(f"Dropped: {stats['tasks_dropped']}")
print(f"Drop rate: {stats['drop_rate']:.2%}")
print(f"Avg queue time: {stats['avg_queue_time']*1000:.2f}ms")
print(f"Avg execution time: {stats['avg_execution_time']*1000:.2f}ms")
pool.shutdown()
Adaptive Batching for Throughput Optimization
Batching multiple inference requests reduces per-request overhead and increases GPU utilization. Adaptive batching balances throughput gains against latency constraints.
Adaptive Batch Scheduler (Python):
#!/usr/bin/env python3
"""
Adaptive batching for optimized throughput
"""
import torch
import threading
import time
import numpy as np
from collections import deque
from dataclasses import dataclass
@dataclass
class BatchConfig:
min_batch_size: int = 1
max_batch_size: int = 8
max_wait_time: float = 0.01 # 10ms
target_latency: float = 0.05 # 50ms
class AdaptiveBatchScheduler:
"""
Adaptive batching with dynamic batch size adjustment
"""
def __init__(self, config: BatchConfig):
self.config = config
self.pending_requests = deque()
self.lock = threading.Lock()
self.condition = threading.Condition(self.lock)
# Performance tracking for adaptation
self.latency_history = deque(maxlen=100)
self.batch_size_history = deque(maxlen=100)
self.throughput_history = deque(maxlen=100)
# Current adaptive parameters
self.current_batch_size = config.min_batch_size
self.current_wait_time = config.max_wait_time
# Batch processing thread
self.running = True
self.processor_thread = threading.Thread(target=self.batch_processor)
self.processor_thread.daemon = True
self.processor_thread.start()
def submit_request(self, input_data, callback):
"""Submit inference request"""
request = {
'input': input_data,
'callback': callback,
'submit_time': time.time()
}
with self.condition:
self.pending_requests.append(request)
self.condition.notify()
def get_batch(self):
"""Collect requests into batch"""
with self.condition:
# Wait for requests or timeout
deadline = time.time() + self.current_wait_time
while len(self.pending_requests) < self.current_batch_size:
remaining = deadline - time.time()
if remaining <= 0:
break
self.condition.wait(timeout=remaining)
# Collect batch
batch_size = min(len(self.pending_requests), self.current_batch_size)
if batch_size == 0:
return None
batch = []
for _ in range(batch_size):
batch.append(self.pending_requests.popleft())
return batch
def batch_processor(self):
"""Background batch processing thread"""
while self.running:
batch = self.get_batch()
if batch is None:
time.sleep(0.001)
continue
# Process batch
batch_start = time.time()
results = self.execute_batch(batch)
batch_duration = time.time() - batch_start
# Calculate metrics
avg_latency = np.mean([time.time() - req['submit_time'] for req in batch])
throughput = len(batch) / batch_duration
# Record history
self.latency_history.append(avg_latency)
self.batch_size_history.append(len(batch))
self.throughput_history.append(throughput)
# Adapt batch parameters
self.adapt_batch_parameters(avg_latency, throughput)
# Invoke callbacks
for req, result in zip(batch, results):
if req['callback']:
latency = time.time() - req['submit_time']
req['callback'](result, latency)
def execute_batch(self, batch):
"""Execute batched inference"""
# Stack inputs
inputs = np.stack([req['input'] for req in batch])
inputs_tensor = torch.from_numpy(inputs).cuda()
# Simulated batched inference
time.sleep(0.015 * (1 + len(batch) * 0.1)) # Simulated batch processing
# Return per-request results
results = [{'detections': []} for _ in batch]
return results
def adapt_batch_parameters(self, current_latency, current_throughput):
"""Adapt batch size and wait time based on performance"""
if len(self.latency_history) < 10:
return
avg_latency = np.mean(list(self.latency_history)[-10:])
avg_throughput = np.mean(list(self.throughput_history)[-10:])
# If latency too high, reduce batch size
if avg_latency > self.config.target_latency * 1.2:
self.current_batch_size = max(
self.config.min_batch_size,
self.current_batch_size - 1
)
self.current_wait_time = max(
self.config.max_wait_time * 0.5,
self.current_wait_time * 0.9
)
# If latency acceptable and queue building, increase batch size
elif avg_latency < self.config.target_latency * 0.8 and len(self.pending_requests) > 5:
self.current_batch_size = min(
self.config.max_batch_size,
self.current_batch_size + 1
)
self.current_wait_time = min(
self.config.max_wait_time,
self.current_wait_time * 1.1
)
def get_statistics(self):
"""Get batching statistics"""
if not self.latency_history:
return {}
return {
'current_batch_size': self.current_batch_size,
'current_wait_time': self.current_wait_time * 1000, # ms
'avg_latency': np.mean(list(self.latency_history)) * 1000, # ms
'p95_latency': np.percentile(list(self.latency_history), 95) * 1000, # ms
'avg_throughput': np.mean(list(self.throughput_history)), # req/s
'avg_batch_size': np.mean(list(self.batch_size_history))
}
def shutdown(self):
"""Shutdown scheduler"""
self.running = False
with self.condition:
self.condition.notify()
self.processor_thread.join()
# Example usage
if __name__ == '__main__':
config = BatchConfig(
min_batch_size=1,
max_batch_size=8,
max_wait_time=0.01,
target_latency=0.05
)
scheduler = AdaptiveBatchScheduler(config)
completed = []
def callback(result, latency):
completed.append(latency)
# Simulate varying load
for i in range(200):
input_data = np.random.randn(1, 3, 640, 640).astype(np.float32)
scheduler.submit_request(input_data, callback)
# Variable request rate
if i < 50:
time.sleep(0.02) # Low load
elif i < 100:
time.sleep(0.005) # Medium load
elif i < 150:
time.sleep(0.002) # High load
else:
time.sleep(0.01) # Medium load
if i % 50 == 0 and i > 0:
stats = scheduler.get_statistics()
print(f"\nIteration {i}:")
print(f" Batch size: {stats['current_batch_size']}")
print(f" Avg latency: {stats['avg_latency']:.2f}ms")
print(f" P95 latency: {stats['p95_latency']:.2f}ms")
print(f" Throughput: {stats['avg_throughput']:.1f} req/s")
time.sleep(2)
scheduler.shutdown()
print(f"\nCompleted {len(completed)} requests")
print(f"Average latency: {np.mean(completed)*1000:.2f}ms")
KV Cache Management for Transformer Models
Transformer-based models like vision transformers or multimodal models require KV cache management for efficient sequence processing on edge devices.
KV Cache Manager (Python):
#!/usr/bin/env python3
"""
KV cache management for transformer models on edge
"""
import torch
import threading
from collections import OrderedDict
from dataclasses import dataclass
import time
@dataclass
class CacheEntry:
key_cache: torch.Tensor
value_cache: torch.Tensor
sequence_length: int
last_accessed: float
access_count: int
class KVCacheManager:
"""
Manages KV caches for transformer models with memory-efficient strategies
"""
def __init__(self, max_cache_memory_mb=512, num_layers=12,
hidden_size=768, num_heads=12):
self.max_cache_memory = max_cache_memory_mb * 1024 * 1024
self.num_layers = num_layers
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
# Cache storage
self.caches = OrderedDict()
self.cache_lock = threading.Lock()
# Memory tracking
self.current_memory_usage = 0
# Statistics
self.cache_hits = 0
self.cache_misses = 0
self.evictions = 0
def estimate_cache_size(self, sequence_length, batch_size=1):
"""Estimate memory required for KV cache"""
# Each layer stores K and V caches
# Shape: [batch_size, num_heads, sequence_length, head_dim]
bytes_per_element = 2 # FP16
k_cache_size = (batch_size * self.num_heads * sequence_length *
self.head_dim * bytes_per_element)
v_cache_size = k_cache_size
total_size = (k_cache_size + v_cache_size) * self.num_layers
return total_size
def allocate_cache(self, session_id: str, max_sequence_length: int,
batch_size: int = 1):
"""Allocate KV cache for a session"""
cache_size = self.estimate_cache_size(max_sequence_length, batch_size)
# Make space if needed
while self.current_memory_usage + cache_size > self.max_cache_memory:
if len(self.caches) == 0:
raise RuntimeError("Cannot allocate cache: insufficient memory")
evict_id = self.select_eviction_candidate()
self.evict_cache(evict_id)
with self.cache_lock:
# Allocate cache tensors
k_caches = []
v_caches = []
for _ in range(self.num_layers):
k_cache = torch.zeros(
batch_size, self.num_heads, max_sequence_length, self.head_dim,
dtype=torch.float16,
device='cuda'
)
v_cache = torch.zeros(
batch_size, self.num_heads, max_sequence_length, self.head_dim,
dtype=torch.float16,
device='cuda'
)
k_caches.append(k_cache)
v_caches.append(v_cache)
entry = CacheEntry(
key_cache=k_caches,
value_cache=v_caches,
sequence_length=0,
last_accessed=time.time(),
access_count=0
)
self.caches[session_id] = entry
self.current_memory_usage += cache_size
print(f"Allocated cache for {session_id}: {cache_size/1024/1024:.1f}MB")
def get_cache(self, session_id: str):
"""Retrieve cache for session"""
with self.cache_lock:
if session_id not in self.caches:
self.cache_misses += 1
return None
self.cache_hits += 1
entry = self.caches[session_id]
entry.last_accessed = time.time()
entry.access_count += 1
return entry
def update_cache(self, session_id: str, new_keys, new_values, position: int):
"""Update cache with new key-value pairs"""
entry = self.get_cache(session_id)
if entry is None:
raise ValueError(f"No cache found for session {session_id}")
# Update caches for each layer
for layer_idx, (k, v) in enumerate(zip(new_keys, new_values)):
entry.key_cache[layer_idx][:, :, position:position+k.shape[2], :] = k
entry.value_cache[layer_idx][:, :, position:position+v.shape[2], :] = v
entry.sequence_length = max(entry.sequence_length, position + new_keys[0].shape[2])
def select_eviction_candidate(self):
"""Select cache to evict using LRU"""
if not self.caches:
return None
# Find least recently used
oldest_id = None
oldest_time = float('inf')
for session_id, entry in self.caches.items():
if entry.last_accessed < oldest_time:
oldest_time = entry.last_accessed
oldest_id = session_id
return oldest_id
def evict_cache(self, session_id: str):
"""Evict cache for session"""
if session_id not in self.caches:
return
with self.cache_lock:
entry = self.caches[session_id]
# Free GPU memory
for k_cache, v_cache in zip(entry.key_cache, entry.value_cache):
del k_cache
del v_cache
cache_size = self.estimate_cache_size(entry.sequence_length)
self.current_memory_usage -= cache_size
del self.caches[session_id]
self.evictions += 1
print(f"Evicted cache for {session_id}")
def compress_cache(self, session_id: str, compression_ratio: float = 0.5):
"""Compress cache by removing less important positions"""
entry = self.get_cache(session_id)
if entry is None:
return
# Simple compression: keep most recent tokens
keep_length = int(entry.sequence_length * compression_ratio)
start_position = entry.sequence_length - keep_length
with self.cache_lock:
for layer_idx in range(self.num_layers):
# Shift cache to beginning
entry.key_cache[layer_idx][:, :, :keep_length, :] = \
entry.key_cache[layer_idx][:, :, start_position:, :]
entry.value_cache[layer_idx][:, :, :keep_length, :] = \
entry.value_cache[layer_idx][:, :, start_position:, :]
# Zero out rest
entry.key_cache[layer_idx][:, :, keep_length:, :] = 0
entry.value_cache[layer_idx][:, :, keep_length:, :] = 0
entry.sequence_length = keep_length
print(f"Compressed cache for {session_id} to {keep_length} tokens")
def get_statistics(self):
"""Get cache manager statistics"""
total_requests = self.cache_hits + self.cache_misses
hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0
return {
'active_caches': len(self.caches),
'memory_usage_mb': self.current_memory_usage / 1024 / 1024,
'max_memory_mb': self.max_cache_memory / 1024 / 1024,
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate': hit_rate,
'evictions': self.evictions
}
# Example usage
if __name__ == '__main__':
manager = KVCacheManager(
max_cache_memory_mb=512,
num_layers=12,
hidden_size=768,
num_heads=12
)
# Allocate caches for multiple sessions
for i in range(5):
session_id = f"session_{i}"
manager.allocate_cache(session_id, max_sequence_length=512)
# Simulate cache usage
for iteration in range(100):
session_id = f"session_{iteration % 5}"
entry = manager.get_cache(session_id)
if entry:
# Simulate adding new tokens
new_position = entry.sequence_length
if new_position < 512:
# Dummy key/value tensors
new_keys = [torch.randn(1, 12, 1, 64, dtype=torch.float16, device='cuda')
for _ in range(12)]
new_values = [torch.randn(1, 12, 1, 64, dtype=torch.float16, device='cuda')
for _ in range(12)]
manager.update_cache(session_id, new_keys, new_values, new_position)
if iteration % 20 == 0:
stats = manager.get_statistics()
print(f"\nIteration {iteration}:")
print(f" Active caches: {stats['active_caches']}")
print(f" Memory usage: {stats['memory_usage_mb']:.1f}MB")
print(f" Hit rate: {stats['hit_rate']:.2%}")
SLA Enforcement and Performance Guarantees
Production edge deployments require service level agreement enforcement ensuring consistent performance under varying load conditions.
SLA Enforcement Manager (Python):
#!/usr/bin/env python3
"""
SLA enforcement for edge inference services
"""
import time
import threading
from collections import deque
from dataclasses import dataclass
from enum import Enum
import numpy as np
class SLAViolationType(Enum):
LATENCY = "latency"
THROUGHPUT = "throughput"
AVAILABILITY = "availability"
ERROR_RATE = "error_rate"
@dataclass
class SLAPolicy:
max_latency_p95: float = 0.05 # 50ms
max_latency_p99: float = 0.1 # 100ms
min_throughput: float = 20.0 # req/s
max_error_rate: float = 0.01 # 1%
min_availability: float = 0.999 # 99.9%
class SLAEnforcementManager:
"""
Monitors and enforces SLA policies for inference service
"""
def __init__(self, policy: SLAPolicy, window_size=60):
self.policy = policy
self.window_size = window_size
# Metrics collection
self.latencies = deque(maxlen=1000)
self.errors = deque(maxlen=1000)
self.requests = deque(maxlen=1000)
self.timestamps = deque(maxlen=1000)
# SLA status
self.violations = []
self.lock = threading.Lock()
# Monitoring thread
self.running = True
self.monitor_thread = threading.Thread(target=self.monitor_loop)
self.monitor_thread.daemon = True
self.monitor_thread.start()
def record_request(self, latency: float, success: bool):
"""Record request metrics"""
with self.lock:
self.latencies.append(latency)
self.errors.append(0 if success else 1)
self.requests.append(1)
self.timestamps.append(time.time())
def get_window_metrics(self):
"""Calculate metrics for current time window"""
current_time = time.time()
window_start = current_time - self.window_size
# Filter metrics to window
window_latencies = []
window_errors = []
window_requests = 0
for timestamp, latency, error in zip(self.timestamps, self.latencies, self.errors):
if timestamp >= window_start:
window_latencies.append(latency)
window_errors.append(error)
window_requests += 1
if not window_latencies:
return None
metrics = {
'latency_p95': np.percentile(window_latencies, 95),
'latency_p99': np.percentile(window_latencies, 99),
'latency_mean': np.mean(window_latencies),
'throughput': window_requests / self.window_size,
'error_rate': np.mean(window_errors),
'availability': 1.0 - np.mean(window_errors),
'total_requests': window_requests
}
return metrics
def check_sla_compliance(self):
"""Check if current metrics comply with SLA"""
metrics = self.get_window_metrics()
if metrics is None:
return True, []
violations = []
# Check latency SLAs
if metrics['latency_p95'] > self.policy.max_latency_p95:
violations.append({
'type': SLAViolationType.LATENCY,
'metric': 'p95',
'value': metrics['latency_p95'],
'threshold': self.policy.max_latency_p95,
'severity': 'medium'
})
if metrics['latency_p99'] > self.policy.max_latency_p99:
violations.append({
'type': SLAViolationType.LATENCY,
'metric': 'p99',
'value': metrics['latency_p99'],
'threshold': self.policy.max_latency_p99,
'severity': 'high'
})
# Check throughput SLA
if metrics['throughput'] < self.policy.min_throughput:
violations.append({
'type': SLAViolationType.THROUGHPUT,
'metric': 'throughput',
'value': metrics['throughput'],
'threshold': self.policy.min_throughput,
'severity': 'medium'
})
# Check error rate SLA
if metrics['error_rate'] > self.policy.max_error_rate:
violations.append({
'type': SLAViolationType.ERROR_RATE,
'metric': 'error_rate',
'value': metrics['error_rate'],
'threshold': self.policy.max_error_rate,
'severity': 'high'
})
# Check availability SLA
if metrics['availability'] < self.policy.min_availability:
violations.append({
'type': SLAViolationType.AVAILABILITY,
'metric': 'availability',
'value': metrics['availability'],
'threshold': self.policy.min_availability,
'severity': 'critical'
})
return len(violations) == 0, violations
def apply_mitigation(self, violations):
"""Apply mitigation strategies for SLA violations"""
for violation in violations:
if violation['type'] == SLAViolationType.LATENCY:
# Reduce batch size, increase worker threads
print(f"Mitigating latency violation: "
f"{violation['metric']}={violation['value']*1000:.1f}ms "
f"(threshold={violation['threshold']*1000:.1f}ms)")
# Implementation would adjust scheduler parameters
elif violation['type'] == SLAViolationType.THROUGHPUT:
# Increase parallelism, enable batching
print(f"Mitigating throughput violation: "
f"{violation['value']:.1f} req/s "
f"(threshold={violation['threshold']:.1f} req/s)")
elif violation['type'] == SLAViolationType.ERROR_RATE:
# Enable retries, circuit breaker
print(f"Mitigating error rate violation: "
f"{violation['value']*100:.2f}% "
f"(threshold={violation['threshold']*100:.2f}%)")
elif violation['type'] == SLAViolationType.AVAILABILITY:
# Failover, redundancy
print(f"CRITICAL: Availability violation: "
f"{violation['value']*100:.3f}% "
f"(threshold={violation['threshold']*100:.3f}%)")
# Record violations
with self.lock:
self.violations.extend(violations)
def monitor_loop(self):
"""Background monitoring thread"""
while self.running:
time.sleep(5) # Check every 5 seconds
compliant, violations = self.check_sla_compliance()
if not compliant:
self.apply_mitigation(violations)
# Print status
metrics = self.get_window_metrics()
if metrics:
print(f"\nSLA Status: {'COMPLIANT' if compliant else 'VIOLATED'}")
print(f" P95 Latency: {metrics['latency_p95']*1000:.1f}ms")
print(f" Throughput: {metrics['throughput']:.1f} req/s")
print(f" Error Rate: {metrics['error_rate']*100:.2f}%")
def get_sla_report(self):
"""Generate SLA compliance report"""
metrics = self.get_window_metrics()
compliant, current_violations = self.check_sla_compliance()
return {
'compliant': compliant,
'current_metrics': metrics,
'current_violations': current_violations,
'total_violations': len(self.violations),
'violation_history': self.violations[-10:] # Last 10
}
def shutdown(self):
"""Shutdown monitor"""
self.running = False
self.monitor_thread.join()
# Example usage
if __name__ == '__main__':
policy = SLAPolicy(
max_latency_p95=0.05,
max_latency_p99=0.1,
min_throughput=20.0,
max_error_rate=0.01
)
manager = SLAEnforcementManager(policy, window_size=30)
# Simulate requests with varying performance
for i in range(200):
# Simulate degrading performance
if i < 50:
latency = np.random.normal(0.02, 0.005) # Good
success = np.random.random() > 0.001
elif i < 100:
latency = np.random.normal(0.04, 0.01) # Degrading
success = np.random.random() > 0.005
elif i < 150:
latency = np.random.normal(0.08, 0.02) # Violating
success = np.random.random() > 0.02
else:
latency = np.random.normal(0.03, 0.008) # Recovering
success = np.random.random() > 0.002
manager.record_request(latency, success)
time.sleep(0.1)
# Final report
report = manager.get_sla_report()
print(f"\n{'='*60}")
print(f"Final SLA Report:")
print(f"Compliant: {report['compliant']}")
print(f"Total Violations: {report['total_violations']}")
manager.shutdown()
Key Takeaways
Advanced optimization patterns enable production edge AI deployments to maximize resource utilization while maintaining quality of service. Memory-aware scheduling coordinates multiple concurrent models through intelligent caching and eviction policies preventing GPU memory exhaustion. Implementation typically achieves 80-90% cache hit rates reducing model load overhead by 50-70%.
GPU resource pooling with priority-based scheduling prevents contention-induced latency spikes. CUDA stream management enables parallel execution of independent requests achieving 2-3x throughput improvement over sequential processing. Adaptive batching dynamically adjusts batch sizes balancing throughput gains against latency constraints, typically improving throughput by 40-60% while maintaining sub-50ms P95 latency.
KV cache management for transformer models reduces memory footprint through intelligent compression and eviction strategies. Proper cache management enables 50-100 concurrent transformer sessions on 8GB edge devices. SLA enforcement mechanisms monitor performance metrics in real-time applying automatic mitigation strategies maintaining service quality under varying load conditions.
Combined optimization strategies demonstrate 50-70% end-to-end latency reduction compared to naive implementations while improving throughput by 2-4x. Production deployments require careful tuning of optimization parameters based on specific workload characteristics, hardware capabilities, and service level requirements. Comprehensive monitoring and adaptive strategies ensure consistent performance as deployment conditions evolve.
Part 6 concludes the series with production operations covering Prometheus/Jaeger monitoring integration, data drift detection mechanisms, model versioning strategies, canary deployment patterns, OTA update procedures, comprehensive health checking, feedback loop implementation, and orchestration patterns for managing 100+ distributed edge devices.
References
- PyTorch CUDA Semantics (https://pytorch.org/docs/stable/notes/cuda.html)
- NVIDIA CUDA Streams Documentation (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams)
- Efficient Memory Management for LLM Serving (https://arxiv.org/abs/2308.13893)
- Orca: Distributed Serving System for Transformer Models (https://www.usenix.org/conference/osdi22/presentation/yu)
- TensorRT Best Practices Guide (https://docs.nvidia.com/deeplearning/tensorrt/best-practices/index.html)
- Adaptive Batching for Deep Learning Inference (https://arxiv.org/abs/2104.04473)
- PyTorch CUDA Optimization Patterns (https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/)
- NVIDIA CUDA Occupancy Optimization (https://developer.nvidia.com/blog/cuda-pro-tip-occupancy-api-simplifies-launch-configuration/)
