Real-Time WebSocket Architecture Series: Part 4 – Authentication & Security

Real-Time WebSocket Architecture Series: Part 4 – Authentication & Security

This entry is part 4 of 8 in the Real-Time WebSocket Architecture Series.

Welcome to Part 4! In Part 3, we implemented rooms and namespaces. Now we tackle the most critical aspect: security. This comprehensive guide covers JWT authentication, the most common WebSocket vulnerabilities catalogued by OWASP, and production-ready security practices.

WebSocket Security Challenges

WebSockets introduce unique security challenges:

  • Persistent connections: Increased attack surface
  • No built-in authentication: Must implement custom auth
  • Cross-origin attacks: Cross-Site WebSocket Hijacking (CSWSH), the WebSocket form of CSRF
  • Injection attacks: XSS, SQL injection through messages
  • DoS vulnerabilities: Unlimited connections can overwhelm servers

Common WebSocket Vulnerabilities (OWASP)

  1. Cross-Site WebSocket Hijacking (CSWSH)
  2. Insufficient authentication/authorization
  3. Injection attacks (SQL, XSS, Command)
  4. Denial of Service (DoS)
  5. Man-in-the-Middle attacks
  6. Broken access control
  7. Sensitive data exposure
  8. Insufficient input validation

Security Architecture

sequenceDiagram
    participant C as Client
    participant Auth as Auth API
    participant WS as WebSocket Server
    
    C->>Auth: POST /login
    Auth->>Auth: Verify & Generate JWT
    Auth-->>C: Access + Refresh Token
    
    C->>WS: Connect (JWT in handshake auth payload)
    WS->>WS: Validate JWT
    WS->>WS: Check Origin
    WS->>WS: Rate Limit Check
    
    alt All Checks Pass
        WS-->>C: Connection Accepted
    else Failed
        WS-->>C: Connection Rejected
    end

Implementation

Install Dependencies

npm install express socket.io jsonwebtoken bcryptjs dotenv express-rate-limit validator helmet cors

Environment Setup

# .env
PORT=3000
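# Generate each secret as a long random value (256+ bits), e.g.:
#   node -e "console.log(require('crypto').randomBytes(32).toString('hex'))"
JWT_SECRET=your-secret-key-256-bits
JWT_EXPIRES_IN=1h
# Use a different secret for refresh tokens so one can never pass as the other
JWT_REFRESH_SECRET=your-refresh-secret
JWT_REFRESH_EXPIRES_IN=7d
# Comma-separated list of exact origins (scheme://host[:port]) allowed to connect
ALLOWED_ORIGINS=http://localhost:3000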

Auth Service

// services/AuthService.js
const jwt = require('jsonwebtoken');
const bcrypt = require('bcryptjs');

// Pin the algorithm on both sign and verify so verify() never accepts a token
// signed with a different algorithm than the one we issue (algorithm confusion).
const JWT_ALGORITHM = 'HS256';

// bcrypt cost factor (2^12 rounds); raise as hardware gets faster.
const BCRYPT_COST = 12;

/**
 * Stateless helpers for issuing/verifying JWTs and hashing passwords.
 * Secrets and lifetimes are read from process.env at call time.
 */
class AuthService {
  /**
   * Sign a short-lived access token.
   * @param {object} payload - Claims to embed (e.g. userId, username, email).
   * @returns {string} Signed JWT expiring after JWT_EXPIRES_IN (default '1h').
   */
  static generateAccessToken(payload) {
    return jwt.sign(payload, process.env.JWT_SECRET, {
      algorithm: JWT_ALGORITHM,
      expiresIn: process.env.JWT_EXPIRES_IN || '1h'
    });
  }

  /**
   * Sign a long-lived refresh token that carries only the user id.
   * @param {{userId: string}} payload
   * @returns {string} Signed JWT expiring after JWT_REFRESH_EXPIRES_IN (default '7d').
   */
  static generateRefreshToken(payload) {
    return jwt.sign(
      { userId: payload.userId },
      process.env.JWT_REFRESH_SECRET,
      {
        algorithm: JWT_ALGORITHM,
        // Honour the configured lifetime instead of a hard-coded '7d'.
        expiresIn: process.env.JWT_REFRESH_EXPIRES_IN || '7d'
      }
    );
  }

  /**
   * Verify an access token and return its decoded claims.
   * @param {string} token
   * @returns {object} Decoded payload.
   * @throws {Error} message 'TOKEN_EXPIRED' if the token has expired,
   *   'INVALID_TOKEN' for any other verification failure.
   */
  static verifyAccessToken(token) {
    try {
      return jwt.verify(token, process.env.JWT_SECRET, {
        algorithms: [JWT_ALGORITHM]
      });
    } catch (error) {
      if (error.name === 'TokenExpiredError') {
        throw new Error('TOKEN_EXPIRED', { cause: error });
      }
      throw new Error('INVALID_TOKEN', { cause: error });
    }
  }

  /**
   * Hash a plaintext password with bcrypt (a fresh salt is generated per call).
   * @param {string} password
   * @returns {Promise<string>} bcrypt hash.
   */
  static async hashPassword(password) {
    return bcrypt.hash(password, BCRYPT_COST);
  }

  /**
   * Compare a plaintext password against a stored bcrypt hash.
   * @param {string} password
   * @param {string} hash
   * @returns {Promise<boolean>} true when they match.
   */
  static async comparePassword(password, hash) {
    return bcrypt.compare(password, hash);
  }
}

module.exports = AuthService;

Validation Service

// services/ValidationService.js
const validator = require('validator');

// bcrypt only uses the first 72 bytes of its input; anything longer would be
// silently ignored, so reject it instead of giving users a false sense of strength.
const BCRYPT_MAX_BYTES = 72;

/**
 * Input validation/sanitization helpers. Validators return
 * { valid: true } or { valid: false, message } and never throw.
 */
class ValidationService {
  /**
   * @param {*} email - Untrusted value from a request body.
   * @returns {{valid: boolean, message?: string}}
   */
  static validateEmail(email) {
    // validator.isEmail throws on non-strings, so check the type first.
    if (typeof email !== 'string' || !validator.isEmail(email)) {
      return { valid: false, message: 'Invalid email' };
    }
    return { valid: true };
  }

  /**
   * Require a string of 8+ chars (max 72 bytes) with upper, lower and a digit.
   * @param {*} password - Untrusted value from a request body.
   * @returns {{valid: boolean, message?: string}}
   */
  static validatePassword(password) {
    // The type check stops arrays/objects from passing via .length and coercion.
    if (typeof password !== 'string' || password.length < 8) {
      return { valid: false, message: 'Password min 8 chars' };
    }
    if (Buffer.byteLength(password, 'utf8') > BCRYPT_MAX_BYTES) {
      return { valid: false, message: 'Password max 72 bytes' };
    }
    if (!/[A-Z]/.test(password) || !/[a-z]/.test(password)) {
      return { valid: false, message: 'Need upper and lowercase' };
    }
    if (!/[0-9]/.test(password)) {
      return { valid: false, message: 'Need number' };
    }
    return { valid: true };
  }

  /**
   * Trim and HTML-escape a chat message; non-strings become ''.
   * @param {*} message
   * @returns {string}
   */
  static sanitizeMessage(message) {
    if (typeof message !== 'string') return '';
    return validator.escape(message.trim());
  }
}

module.exports = ValidationService;

Auth Routes

// routes/auth.js
const crypto = require('crypto');
const express = require('express');
const jwt = require('jsonwebtoken');
const rateLimit = require('express-rate-limit');
const AuthService = require('../services/AuthService');
const ValidationService = require('../services/ValidationService');

const router = express.Router();

// In-memory stores for this demo; use a database in production.
const users = new Map();          // userId -> { id, username, email, password (bcrypt hash) }
const userIdsByEmail = new Map(); // normalized email -> userId, for O(1) lookups
const refreshTokens = new Map();  // refresh token -> { userId }

// Usernames are broadcast to every connected socket, so restrict them to a
// character set that cannot carry HTML or script.
const USERNAME_PATTERN = /^[A-Za-z0-9_]{3,30}$/;

// Shared budget for credential-bearing requests (register + login).
const authLimiter = rateLimit({
  windowMs: 15 * 60 * 1000,
  max: 5,
  message: 'Too many attempts'
});

// Token refreshes are routine (about once per access-token lifetime), so they
// get their own, looser budget instead of consuming login attempts.
const refreshLimiter = rateLimit({
  windowMs: 15 * 60 * 1000,
  max: 30,
  message: 'Too many attempts'
});

/** Trim and lowercase an email so lookups are case-insensitive; non-strings pass through. */
function normalizeEmail(email) {
  return typeof email === 'string' ? email.trim().toLowerCase() : email;
}

/** Public view of a user record; never includes the password hash. */
function toPublicUser(user) {
  return { id: user.id, username: user.username, email: user.email };
}

/** Issue a new access/refresh token pair for `user` and record the refresh token. */
function issueTokens(user) {
  const accessToken = AuthService.generateAccessToken({
    userId: user.id,
    username: user.username,
    email: user.email
  });
  const refreshToken = AuthService.generateRefreshToken({ userId: user.id });
  refreshTokens.set(refreshToken, { userId: user.id });
  return { accessToken, refreshToken };
}

/**
 * POST /register: create an account.
 * Body: { username, email, password }.
 * 201 with tokens + public user; 400 on invalid input; 409 if the email is taken.
 */
router.post('/register', authLimiter, async (req, res) => {
  try {
    const { username, password } = req.body ?? {};
    const email = normalizeEmail(req.body?.email);

    if (typeof username !== 'string' || !USERNAME_PATTERN.test(username)) {
      return res.status(400).json({
        error: 'Username must be 3-30 letters, numbers or underscores'
      });
    }

    const emailValid = ValidationService.validateEmail(email);
    if (!emailValid.valid) {
      return res.status(400).json({ error: emailValid.message });
    }

    const passValid = ValidationService.validatePassword(password);
    if (!passValid.valid) {
      return res.status(400).json({ error: passValid.message });
    }

    if (userIdsByEmail.has(email)) {
      return res.status(409).json({ error: 'Email exists' });
    }

    const hashedPassword = await AuthService.hashPassword(password);

    // Re-check after the await: a concurrent request for the same email may
    // have completed while this one was hashing.
    if (userIdsByEmail.has(email)) {
      return res.status(409).json({ error: 'Email exists' });
    }

    const user = {
      id: crypto.randomBytes(16).toString('hex'),
      username,
      email,
      password: hashedPassword
    };
    users.set(user.id, user);
    userIdsByEmail.set(email, user.id);

    res.status(201).json({
      success: true,
      ...issueTokens(user),
      user: toPublicUser(user)
    });
  } catch (error) {
    console.error('Registration failed:', error);
    res.status(500).json({ error: 'Registration failed' });
  }
});

/**
 * POST /login: exchange credentials for tokens.
 * Body: { email, password }. 400 on malformed body, 401 on bad credentials.
 */
router.post('/login', authLimiter, async (req, res) => {
  try {
    const { password } = req.body ?? {};
    const email = normalizeEmail(req.body?.email);

    // Reject malformed bodies up front; bcrypt.compare throws on non-strings.
    if (typeof email !== 'string' || typeof password !== 'string') {
      return res.status(400).json({ error: 'Email and password required' });
    }

    const user = users.get(userIdsByEmail.get(email));
    if (!user) {
      return res.status(401).json({ error: 'Invalid credentials' });
    }

    const isValid = await AuthService.comparePassword(password, user.password);
    if (!isValid) {
      return res.status(401).json({ error: 'Invalid credentials' });
    }

    res.json({
      success: true,
      ...issueTokens(user),
      user: toPublicUser(user)
    });
  } catch (error) {
    console.error('Login failed:', error);
    res.status(500).json({ error: 'Login failed' });
  }
});

/**
 * POST /refresh: exchange a valid refresh token for a new token pair.
 * Body: { refreshToken }. Refresh tokens are single-use (rotated): the old one
 * is revoked and a new one is returned alongside the new access token.
 */
router.post('/refresh', refreshLimiter, (req, res) => {
  const { refreshToken } = req.body ?? {};
  if (typeof refreshToken !== 'string') {
    return res.status(400).json({ error: 'Refresh token required' });
  }

  let decoded;
  try {
    decoded = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET, {
      algorithms: ['HS256']
    });
  } catch (error) {
    // Expired or forged: make sure it can never be used again.
    refreshTokens.delete(refreshToken);
    return res.status(401).json({ error: 'Invalid refresh token' });
  }

  // A validly signed token must also still be on record (i.e. not yet
  // rotated or revoked) and belong to an existing user.
  const stored = refreshTokens.get(refreshToken);
  const user = stored && users.get(stored.userId);
  if (!user || stored.userId !== decoded.userId) {
    return res.status(401).json({ error: 'Invalid refresh token' });
  }

  refreshTokens.delete(refreshToken);
  res.json({ success: true, ...issueTokens(user) });
});

module.exports = router;

Secure WebSocket Server

// server.js
require('dotenv').config();
const express = require('express');
const http = require('http');
const { Server } = require('socket.io');
const helmet = require('helmet');
const cors = require('cors');
const AuthService = require('./services/AuthService');
const ValidationService = require('./services/ValidationService');
const authRoutes = require('./routes/auth');

// Fail fast with a clear message instead of a TypeError on `.split`.
if (!process.env.ALLOWED_ORIGINS) {
  throw new Error('ALLOWED_ORIGINS must be set (comma-separated list of origins)');
}
// Parsed once; trimming lets "a, b" work as expected.
const allowedOrigins = process.env.ALLOWED_ORIGINS
  .split(',')
  .map((origin) => origin.trim())
  .filter(Boolean);

const RATE_WINDOW_MS = 60 * 1000;
const MAX_CONNECTIONS_PER_WINDOW = 10; // per IP
const MAX_MESSAGES_PER_WINDOW = 30;    // per socket
const MAX_MESSAGE_LENGTH = 500;        // characters, measured before escaping

const app = express();
const server = http.createServer(app);

app.use(helmet());
app.use(express.json({ limit: '10kb' }));
app.use(cors({ origin: allowedOrigins }));
app.use('/api/auth', authRoutes);

const io = new Server(server, {
  cors: { origin: allowedOrigins },
  maxHttpBufferSize: 1e6
});

const connections = new Map();        // socket.id -> { userId, username }
const messageRateLimits = new Map();  // socket.id -> recent message timestamps
const connectionAttempts = new Map(); // ip -> recent connection timestamps

/** Keep only the timestamps that fall inside the current sliding window. */
function recentTimestamps(timestamps, now) {
  return timestamps.filter((t) => now - t < RATE_WINDOW_MS);
}

// Middlewares run in registration order. The cheap checks run first, so floods
// and cross-site attempts are rejected before any JWT work is done.

// 1. Connection rate limiting (per IP).
// NOTE: behind a reverse proxy handshake.address is the proxy's IP; derive
// the client IP from a trusted X-Forwarded-For header in that setup.
io.use((socket, next) => {
  const ip = socket.handshake.address;
  const now = Date.now();
  const recent = recentTimestamps(connectionAttempts.get(ip) || [], now);

  if (recent.length >= MAX_CONNECTIONS_PER_WINDOW) {
    connectionAttempts.set(ip, recent);
    return next(new Error('TOO_MANY_CONNECTIONS'));
  }

  recent.push(now);
  connectionAttempts.set(ip, recent);
  next();
});

// 2. Origin validation (defends against Cross-Site WebSocket Hijacking).
io.use((socket, next) => {
  const { origin } = socket.handshake.headers;
  if (!origin || !allowedOrigins.includes(origin)) {
    return next(new Error('INVALID_ORIGIN'));
  }
  next();
});

// 3. JWT authentication. TOKEN_EXPIRED is passed through unchanged because the
// client's connect_error handler uses it to trigger a token refresh.
io.use((socket, next) => {
  const token = socket.handshake.auth?.token;
  if (!token) return next(new Error('NO_TOKEN'));

  let decoded;
  try {
    decoded = AuthService.verifyAccessToken(token);
  } catch (error) {
    const code = error.message === 'TOKEN_EXPIRED' ? 'TOKEN_EXPIRED' : 'AUTH_FAILED';
    return next(new Error(code));
  }

  socket.user = {
    id: decoded.userId,
    username: decoded.username,
    email: decoded.email
  };
  next();
});

// Drop IPs with no attempts in the current window so the map cannot grow
// without bound. unref() keeps this timer from holding the process open.
setInterval(() => {
  const now = Date.now();
  for (const [ip, timestamps] of connectionAttempts) {
    const recent = recentTimestamps(timestamps, now);
    if (recent.length === 0) {
      connectionAttempts.delete(ip);
    } else {
      connectionAttempts.set(ip, recent);
    }
  }
}, RATE_WINDOW_MS).unref();

io.on('connection', (socket) => {
  console.log(`User: ${socket.user.username}`);

  connections.set(socket.id, {
    userId: socket.user.id,
    username: socket.user.username
  });
  messageRateLimits.set(socket.id, []);

  // Usernames are broadcast to every client; escape them like message bodies.
  const displayName = ValidationService.sanitizeMessage(socket.user.username);

  socket.on('send-message', (data) => {
    // Per-socket message rate limit.
    const now = Date.now();
    const recent = recentTimestamps(messageRateLimits.get(socket.id) || [], now);
    if (recent.length >= MAX_MESSAGES_PER_WINDOW) {
      messageRateLimits.set(socket.id, recent);
      return socket.emit('error', { message: 'Rate limit exceeded' });
    }
    recent.push(now);
    messageRateLimits.set(socket.id, recent);

    // `data` is client-controlled and may be null or not an object at all.
    const raw = typeof data?.message === 'string' ? data.message.trim() : '';
    // Check the length before escaping: escaping expands characters such as
    // '&' -> '&amp;', which would wrongly reject messages near the limit.
    if (!raw || raw.length > MAX_MESSAGE_LENGTH) {
      return socket.emit('error', { message: 'Invalid message' });
    }
    const sanitized = ValidationService.sanitizeMessage(raw);

    io.emit('receive-message', {
      username: displayName,
      message: sanitized,
      timestamp: new Date().toISOString()
    });
  });

  socket.on('disconnect', () => {
    connections.delete(socket.id);
    messageRateLimits.delete(socket.id);
  });
});

server.listen(process.env.PORT || 3000);

Client Implementation

// client.js
// NOTE: tokens in localStorage are readable by any script on the page, so XSS
// prevention (escaping, CSP) is what keeps them safe.
let socket;
let token = localStorage.getItem('token');
let refreshToken = localStorage.getItem('refreshToken');

// connect_error messages meaning the stored credentials are unusable.
// Anything else (network errors, TOO_MANY_CONNECTIONS, ...) is treated as
// transient and must not log the user out.
const AUTH_ERRORS = new Set(['NO_TOKEN', 'AUTH_FAILED', 'INVALID_TOKEN']);

/**
 * Log in, persist the tokens and open the socket.
 * @param {string} email
 * @param {string} password
 * @returns {Promise<object>} The public user returned by the server.
 * @throws {Error} With the server's error message on failure.
 */
async function login(email, password) {
  const res = await fetch('/api/auth/login', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ email, password })
  });

  const data = await res.json();
  if (!res.ok) throw new Error(data.error);

  token = data.accessToken;
  refreshToken = data.refreshToken;
  localStorage.setItem('token', token);
  localStorage.setItem('refreshToken', refreshToken);

  connectSocket();
  return data.user;
}

/**
 * Open an authenticated Socket.IO connection. On TOKEN_EXPIRED it refreshes
 * the access token once and retries on the same socket; on auth failures it
 * logs out.
 */
function connectSocket() {
  // Don't leak a second live connection if called again (e.g. re-login).
  if (socket) socket.disconnect();

  const current = io({ auth: { token } });
  socket = current;

  // Allow one refresh per connection attempt, so a token that is rejected
  // as expired right after a refresh cannot cause an endless loop.
  let refreshedSinceConnect = false;
  current.on('connect', () => {
    refreshedSinceConnect = false;
  });

  current.on('connect_error', async (err) => {
    if (err.message === 'TOKEN_EXPIRED' && !refreshedSinceConnect) {
      refreshedSinceConnect = true;
      try {
        await refreshAccessToken();
      } catch (refreshError) {
        console.warn('Token refresh failed:', refreshError.message);
        logout();
        return;
      }
      // Reuse this socket with the new token instead of creating another one.
      current.auth.token = token;
      current.connect();
    } else if (err.message === 'TOKEN_EXPIRED' || AUTH_ERRORS.has(err.message)) {
      logout();
    } else {
      console.warn('Socket connection error:', err.message);
    }
  });
}

/**
 * Exchange the stored refresh token for a new access token, and store a
 * rotated refresh token if the server returns one.
 * @throws {Error} If the server rejects the refresh or returns no token.
 */
async function refreshAccessToken() {
  const res = await fetch('/api/auth/refresh', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ refreshToken })
  });

  // Error pages may not be JSON; treat them as an empty body.
  const data = await res.json().catch(() => ({}));
  if (!res.ok || !data.accessToken) {
    throw new Error(data.error || 'Token refresh failed');
  }

  token = data.accessToken;
  localStorage.setItem('token', token);

  if (data.refreshToken) {
    refreshToken = data.refreshToken;
    localStorage.setItem('refreshToken', refreshToken);
  }
}

/** Clear stored credentials, close the socket and go to the login page. */
function logout() {
  localStorage.removeItem('token');
  localStorage.removeItem('refreshToken');
  if (socket) socket.disconnect();
  window.location.href = '/login.html';
}

Security Best Practices

  1. Use WSS (wss://) always in production
  2. Validate Origin header on every connection
  3. Implement rate limiting for connections and messages
  4. Sanitize all input to prevent XSS
  5. Use strong JWT secrets (256+ bits)
  6. Set short token expiration (1 hour)
  7. Implement refresh tokens for better UX
  8. Hash passwords with bcrypt (cost factor 12+)
  9. Validate all data before processing
  10. Limit message size (prevent DoS)
  11. Monitor and log security events
  12. Use HTTPS for all HTTP endpoints

CSRF Protection

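Implement CSRF tokens for additional security. They matter most when the socket is authenticated by cookies, which the browser attaches to cross-site requests automatically. When the JWT is passed explicitly in the handshake `auth` payload, as above, Origin validation is the primary defense against CSWSH.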

// Sketch only: the three parts below run in different places (server at
// login, browser client, server middleware) and are not one runnable file.

// --- Server, at login ---
// Needs Node's crypto module in scope (const crypto = require('crypto')).
// NOTE(review): the token must also be stored server-side (e.g. keyed by the
// user or session) and returned to the client; neither step is shown here.
// Generate CSRF token on login
const csrfToken = crypto.randomBytes(32).toString('hex');

// --- Browser client ---
// Send in WebSocket auth
socket = io({
  auth: { token, csrfToken }
});

// --- Server, Socket.IO middleware ---
// Verify on server
// verifyCSRF is a placeholder not defined in this article. It should compare
// the submitted value with the stored token, preferably in constant time
// (e.g. crypto.timingSafeEqual), and reject a missing value.
io.use((socket, next) => {
  const csrf = socket.handshake.auth.csrfToken;
  if (!verifyCSRF(csrf)) {
    return next(new Error('CSRF_FAILED'));
  }
  next();
});

What's Next

In Part 5: Scaling with Redis, we'll scale across multiple servers using Redis pub/sub and sticky sessions!


Part 4 of the 8-part Real-Time WebSocket Architecture Series.

Navigate: << Real-Time WebSocket Architecture Series: Part 3 – Essential Features (Rooms, Namespaces & Events) | Real-Time WebSocket Architecture Series: Part 5 – Scaling with Redis >>

Written by:

373 Posts

View All Posts
Follow Me :
How to whitelist website on AdBlocker?

How to whitelist website on AdBlocker?

  1. Click the AdBlock Plus icon in the top-right corner of your browser.
  2. Click "Enabled on this site" in the AdBlock Plus menu.
  3. Refresh the page and continue browsing the site.