0
0
Fork 0

Streaming: replace npmlog with pino & pino-http (#27828)

This commit is contained in:
Emelia Smith 2024-01-18 19:40:25 +01:00 committed by GitHub
parent f866413e72
commit 1335083bed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 593 additions and 252 deletions

View file

@ -10,12 +10,11 @@ const dotenv = require('dotenv');
const express = require('express');
const Redis = require('ioredis');
const { JSDOM } = require('jsdom');
const log = require('npmlog');
const pg = require('pg');
const dbUrlToConfig = require('pg-connection-string').parse;
const uuid = require('uuid');
const WebSocket = require('ws');
const { logger, httpLogger, initializeLogLevel, attachWebsocketHttpLogger, createWebsocketLogger } = require('./logging');
const { setupMetrics } = require('./metrics');
const { isTruthy } = require("./utils");
@ -28,15 +27,30 @@ dotenv.config({
path: path.resolve(__dirname, path.join('..', dotenvFile))
});
log.level = process.env.LOG_LEVEL || 'verbose';
initializeLogLevel(process.env, environment);
/**
* Declares the result type for accountFromToken / accountFromRequest.
*
* Note: This is here because jsdoc doesn't like importing types that
* are nested in functions
* @typedef ResolvedAccount
* @property {string} accessTokenId
* @property {string[]} scopes
* @property {string} accountId
* @property {string[]} chosenLanguages
* @property {string} deviceId
*/
/**
* @param {Object.<string, any>} config
*/
const createRedisClient = async (config) => {
const { redisParams, redisUrl } = config;
// @ts-ignore
const client = new Redis(redisUrl, redisParams);
client.on('error', (err) => log.error('Redis Client Error!', err));
// @ts-ignore
client.on('error', (err) => logger.error({ err }, 'Redis Client Error!'));
return client;
};
@ -61,12 +75,12 @@ const parseJSON = (json, req) => {
*/
if (req) {
if (req.accountId) {
log.warn(req.requestId, `Error parsing message from user ${req.accountId}: ${err}`);
req.log.error({ err }, `Error parsing message from user ${req.accountId}`);
} else {
log.silly(req.requestId, `Error parsing message from ${req.remoteAddress}: ${err}`);
req.log.error({ err }, `Error parsing message from ${req.remoteAddress}`);
}
} else {
log.warn(`Error parsing message from redis: ${err}`);
logger.error({ err }, `Error parsing message from redis`);
}
return null;
}
@ -105,6 +119,7 @@ const pgConfigFromEnv = (env) => {
baseConfig.password = env.DB_PASS;
}
} else {
// @ts-ignore
baseConfig = pgConfigs[environment];
if (env.DB_SSLMODE) {
@ -149,6 +164,7 @@ const redisConfigFromEnv = (env) => {
// redisParams.path takes precedence over host and port.
if (env.REDIS_URL && env.REDIS_URL.startsWith('unix://')) {
// @ts-ignore
redisParams.path = env.REDIS_URL.slice(7);
}
@ -195,6 +211,7 @@ const startServer = async () => {
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
app.use(httpLogger);
app.use(cors());
// Handle eventsource & other http requests:
@ -202,32 +219,37 @@ const startServer = async () => {
// Handle upgrade requests:
server.on('upgrade', async function handleUpgrade(request, socket, head) {
// Setup the HTTP logger, since websocket upgrades don't get the usual http
// logger. This decorates the `request` object.
attachWebsocketHttpLogger(request);
request.log.info("HTTP Upgrade Requested");
/** @param {Error} err */
const onSocketError = (err) => {
log.error(`Error with websocket upgrade: ${err}`);
request.log.error({ error: err }, err.message);
};
socket.on('error', onSocketError);
// Authenticate:
try {
await accountFromRequest(request);
} catch (err) {
log.error(`Error authenticating request: ${err}`);
/** @type {ResolvedAccount} */
let resolvedAccount;
try {
resolvedAccount = await accountFromRequest(request);
} catch (err) {
// Unfortunately for using the on('upgrade') setup, we need to manually
// write a HTTP Response to the Socket to close the connection upgrade
// attempt, so the following code is to handle all of that.
const statusCode = err.status ?? 401;
/** @type {Record<string, string | number>} */
/** @type {Record<string, string | number | import('pino-http').ReqId>} */
const headers = {
'Connection': 'close',
'Content-Type': 'text/plain',
'Content-Length': 0,
'X-Request-Id': request.id,
// TODO: Send the error message via header so it can be debugged in
// developer tools
'X-Error-Message': err.status ? err.toString() : 'An unexpected error occurred'
};
// Ensure the socket is closed once we've finished writing to it:
@ -238,15 +260,28 @@ const startServer = async () => {
// Write the HTTP response manually:
socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
// Finally, log the error:
request.log.error({
err,
res: {
statusCode,
headers
}
}, err.toString());
return;
}
// Remove the error handler, wss.handleUpgrade has its own:
socket.removeListener('error', onSocketError);
wss.handleUpgrade(request, socket, head, function done(ws) {
// Remove the error handler:
socket.removeListener('error', onSocketError);
request.log.info("Authenticated request & upgraded to WebSocket connection");
const wsLogger = createWebsocketLogger(request, resolvedAccount);
// Start the connection:
wss.emit('connection', ws, request);
wss.emit('connection', ws, request, wsLogger);
});
});
@ -273,9 +308,9 @@ const startServer = async () => {
// When checking metrics in the browser, the favicon is requested this
// prevents the request from falling through to the API Router, which would
// error for this endpoint:
app.get('/favicon.ico', (req, res) => res.status(404).end());
app.get('/favicon.ico', (_req, res) => res.status(404).end());
app.get('/api/v1/streaming/health', (req, res) => {
app.get('/api/v1/streaming/health', (_req, res) => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end('OK');
});
@ -285,7 +320,7 @@ const startServer = async () => {
res.set('Content-Type', metrics.register.contentType);
res.end(await metrics.register.metrics());
} catch (ex) {
log.error(ex);
req.log.error(ex);
res.status(500).end();
}
});
@ -319,7 +354,7 @@ const startServer = async () => {
const callbacks = subs[channel];
log.silly(`New message on channel ${redisPrefix}${channel}`);
logger.debug(`New message on channel ${redisPrefix}${channel}`);
if (!callbacks) {
return;
@ -343,17 +378,16 @@ const startServer = async () => {
* @param {SubscriptionListener} callback
*/
const subscribe = (channel, callback) => {
log.silly(`Adding listener for ${channel}`);
logger.debug(`Adding listener for ${channel}`);
subs[channel] = subs[channel] || [];
if (subs[channel].length === 0) {
log.verbose(`Subscribe ${channel}`);
logger.debug(`Subscribe ${channel}`);
redisSubscribeClient.subscribe(channel, (err, count) => {
if (err) {
log.error(`Error subscribing to ${channel}`);
}
else {
logger.error(`Error subscribing to ${channel}`);
} else if (typeof count === 'number') {
redisSubscriptions.set(count);
}
});
@ -367,7 +401,7 @@ const startServer = async () => {
* @param {SubscriptionListener} callback
*/
const unsubscribe = (channel, callback) => {
log.silly(`Removing listener for ${channel}`);
logger.debug(`Removing listener for ${channel}`);
if (!subs[channel]) {
return;
@ -376,12 +410,11 @@ const startServer = async () => {
subs[channel] = subs[channel].filter(item => item !== callback);
if (subs[channel].length === 0) {
log.verbose(`Unsubscribe ${channel}`);
logger.debug(`Unsubscribe ${channel}`);
redisSubscribeClient.unsubscribe(channel, (err, count) => {
if (err) {
log.error(`Error unsubscribing to ${channel}`);
}
else {
logger.error(`Error unsubscribing to ${channel}`);
} else if (typeof count === 'number') {
redisSubscriptions.set(count);
}
});
@ -390,45 +423,13 @@ const startServer = async () => {
};
/**
* @param {any} req
* @param {any} res
* @param {function(Error=): void} next
*/
const setRequestId = (req, res, next) => {
req.requestId = uuid.v4();
res.header('X-Request-Id', req.requestId);
next();
};
/**
* @param {any} req
* @param {any} res
* @param {function(Error=): void} next
*/
const setRemoteAddress = (req, res, next) => {
req.remoteAddress = req.connection.remoteAddress;
next();
};
/**
* @param {any} req
* @param {http.IncomingMessage & ResolvedAccount} req
* @param {string[]} necessaryScopes
* @returns {boolean}
*/
const isInScope = (req, necessaryScopes) =>
req.scopes.some(scope => necessaryScopes.includes(scope));
/**
* @typedef ResolvedAccount
* @property {string} accessTokenId
* @property {string[]} scopes
* @property {string} accountId
* @property {string[]} chosenLanguages
* @property {string} deviceId
*/
/**
* @param {string} token
* @param {any} req
@ -441,6 +442,7 @@ const startServer = async () => {
return;
}
// @ts-ignore
client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
done();
@ -451,6 +453,7 @@ const startServer = async () => {
if (result.rows.length === 0) {
err = new Error('Invalid access token');
// @ts-ignore
err.status = 401;
reject(err);
@ -485,6 +488,7 @@ const startServer = async () => {
if (!authorization && !accessToken) {
const err = new Error('Missing access token');
// @ts-ignore
err.status = 401;
reject(err);
@ -529,15 +533,16 @@ const startServer = async () => {
};
/**
* @param {any} req
* @param {http.IncomingMessage & ResolvedAccount} req
* @param {import('pino').Logger} logger
* @param {string|undefined} channelName
* @returns {Promise.<void>}
*/
const checkScopes = (req, channelName) => new Promise((resolve, reject) => {
log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`);
const checkScopes = (req, logger, channelName) => new Promise((resolve, reject) => {
logger.debug(`Checking OAuth scopes for ${channelName}`);
// When accessing public channels, no scopes are needed
if (PUBLIC_CHANNELS.includes(channelName)) {
if (channelName && PUBLIC_CHANNELS.includes(channelName)) {
resolve();
return;
}
@ -564,6 +569,7 @@ const startServer = async () => {
}
const err = new Error('Access token does not cover required scopes');
// @ts-ignore
err.status = 401;
reject(err);
@ -577,38 +583,40 @@ const startServer = async () => {
/**
* @param {any} req
* @param {SystemMessageHandlers} eventHandlers
* @returns {function(object): void}
* @returns {SubscriptionListener}
*/
const createSystemMessageListener = (req, eventHandlers) => {
return message => {
if (!message?.event) {
return;
}
const { event } = message;
log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);
req.log.debug(`System message for ${req.accountId}: ${event}`);
if (event === 'kill') {
log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
req.log.debug(`Closing connection for ${req.accountId} due to expired access token`);
eventHandlers.onKill();
} else if (event === 'filters_changed') {
log.verbose(req.requestId, `Invalidating filters cache for ${req.accountId}`);
req.log.debug(`Invalidating filters cache for ${req.accountId}`);
req.cachedFilters = null;
}
};
};
/**
* @param {any} req
* @param {any} res
* @param {http.IncomingMessage & ResolvedAccount} req
* @param {http.OutgoingMessage} res
*/
const subscribeHttpToSystemChannel = (req, res) => {
const accessTokenChannelId = `timeline:access_token:${req.accessTokenId}`;
const systemChannelId = `timeline:system:${req.accountId}`;
const listener = createSystemMessageListener(req, {
onKill() {
res.end();
},
});
res.on('close', () => {
@ -641,13 +649,14 @@ const startServer = async () => {
// the connection, as there's nothing to stream back
if (!channelName) {
const err = new Error('Unknown channel requested');
// @ts-ignore
err.status = 400;
next(err);
return;
}
accountFromRequest(req).then(() => checkScopes(req, channelName)).then(() => {
accountFromRequest(req).then(() => checkScopes(req, req.log, channelName)).then(() => {
subscribeHttpToSystemChannel(req, res);
}).then(() => {
next();
@ -663,22 +672,28 @@ const startServer = async () => {
* @param {function(Error=): void} next
*/
const errorMiddleware = (err, req, res, next) => {
log.error(req.requestId, err.toString());
req.log.error({ err }, err.toString());
if (res.headersSent) {
next(err);
return;
}
res.writeHead(err.status || 500, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: err.status ? err.toString() : 'An unexpected error occurred' }));
const hasStatusCode = Object.hasOwnProperty.call(err, 'status');
// @ts-ignore
const statusCode = hasStatusCode ? err.status : 500;
const errorMessage = hasStatusCode ? err.toString() : 'An unexpected error occurred';
res.writeHead(statusCode, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: errorMessage }));
};
/**
* @param {array} arr
* @param {any[]} arr
* @param {number=} shift
* @returns {string}
*/
// @ts-ignore
const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', ');
/**
@ -695,6 +710,7 @@ const startServer = async () => {
return;
}
// @ts-ignore
client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (err, result) => {
done();
@ -709,34 +725,43 @@ const startServer = async () => {
});
/**
* @param {string[]} ids
* @param {any} req
* @param {string[]} channelIds
* @param {http.IncomingMessage & ResolvedAccount} req
* @param {import('pino').Logger} log
* @param {function(string, string): void} output
* @param {undefined | function(string[], SubscriptionListener): void} attachCloseHandler
* @param {'websocket' | 'eventsource'} destinationType
* @param {boolean=} needsFiltering
* @returns {SubscriptionListener}
*/
const streamFrom = (ids, req, output, attachCloseHandler, destinationType, needsFiltering = false) => {
const accountId = req.accountId || req.remoteAddress;
log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}`);
const streamFrom = (channelIds, req, log, output, attachCloseHandler, destinationType, needsFiltering = false) => {
log.info({ channelIds }, `Starting stream`);
/**
* @param {string} event
* @param {object|string} payload
*/
const transmit = (event, payload) => {
// TODO: Replace "string"-based delete payloads with object payloads:
const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload;
messagesSent.labels({ type: destinationType }).inc(1);
log.silly(req.requestId, `Transmitting for ${accountId}: ${event} ${encodedPayload}`);
log.debug({ event, payload }, `Transmitting ${event} to ${req.accountId}`);
output(event, encodedPayload);
};
// The listener used to process each message off the redis subscription,
// message here is an object with an `event` and `payload` property. Some
// events also include a queued_at value, but this is being removed shortly.
/** @type {SubscriptionListener} */
const listener = message => {
if (!message?.event || !message?.payload) {
return;
}
const { event, payload } = message;
// Streaming only needs to apply filtering to some channels and only to
@ -759,7 +784,7 @@ const startServer = async () => {
// Filter based on language:
if (Array.isArray(req.chosenLanguages) && payload.language !== null && req.chosenLanguages.indexOf(payload.language) === -1) {
log.silly(req.requestId, `Message ${payload.id} filtered by language (${payload.language})`);
log.debug(`Message ${payload.id} filtered by language (${payload.language})`);
return;
}
@ -770,6 +795,7 @@ const startServer = async () => {
}
// Filter based on domain blocks, blocks, mutes, or custom filters:
// @ts-ignore
const targetAccountIds = [payload.account.id].concat(payload.mentions.map(item => item.id));
const accountDomain = payload.account.acct.split('@')[1];
@ -781,6 +807,7 @@ const startServer = async () => {
}
const queries = [
// @ts-ignore
client.query(`SELECT 1
FROM blocks
WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
@ -793,10 +820,13 @@ const startServer = async () => {
];
if (accountDomain) {
// @ts-ignore
queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain]));
}
// @ts-ignore
if (!payload.filtered && !req.cachedFilters) {
// @ts-ignore
queries.push(client.query('SELECT filter.id AS id, filter.phrase AS title, filter.context AS context, filter.expires_at AS expires_at, filter.action AS filter_action, keyword.keyword AS keyword, keyword.whole_word AS whole_word FROM custom_filter_keywords keyword JOIN custom_filters filter ON keyword.custom_filter_id = filter.id WHERE filter.account_id = $1 AND (filter.expires_at IS NULL OR filter.expires_at > NOW())', [req.accountId]));
}
@ -819,9 +849,11 @@ const startServer = async () => {
// Handling for constructing the custom filters and caching them on the request
// TODO: Move this logic out of the message handling lifecycle
// @ts-ignore
if (!req.cachedFilters) {
const filterRows = values[accountDomain ? 2 : 1].rows;
// @ts-ignore
req.cachedFilters = filterRows.reduce((cache, filter) => {
if (cache[filter.id]) {
cache[filter.id].keywords.push([filter.keyword, filter.whole_word]);
@ -851,7 +883,9 @@ const startServer = async () => {
// needs to be done in a separate loop as the database returns one
// filterRow per keyword, so we need all the keywords before
// constructing the regular expression
// @ts-ignore
Object.keys(req.cachedFilters).forEach((key) => {
// @ts-ignore
req.cachedFilters[key].regexp = new RegExp(req.cachedFilters[key].keywords.map(([keyword, whole_word]) => {
let expr = keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
@ -872,13 +906,16 @@ const startServer = async () => {
// Apply cachedFilters against the payload, constructing a
// `filter_results` array of FilterResult entities
// @ts-ignore
if (req.cachedFilters) {
const status = payload;
// TODO: Calculate searchableContent in Ruby on Rails:
// @ts-ignore
const searchableContent = ([status.spoiler_text || '', status.content].concat((status.poll && status.poll.options) ? status.poll.options.map(option => option.title) : [])).concat(status.media_attachments.map(att => att.description)).join('\n\n').replace(/<br\s*\/?>/g, '\n').replace(/<\/p><p>/g, '\n\n');
const searchableTextContent = JSDOM.fragment(searchableContent).textContent;
const now = new Date();
// @ts-ignore
const filter_results = Object.values(req.cachedFilters).reduce((results, cachedFilter) => {
// Check the filter hasn't expired before applying:
if (cachedFilter.expires_at !== null && cachedFilter.expires_at < now) {
@ -926,12 +963,12 @@ const startServer = async () => {
});
};
ids.forEach(id => {
channelIds.forEach(id => {
subscribe(`${redisPrefix}${id}`, listener);
});
if (typeof attachCloseHandler === 'function') {
attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener);
attachCloseHandler(channelIds.map(id => `${redisPrefix}${id}`), listener);
}
return listener;
@ -943,8 +980,6 @@ const startServer = async () => {
* @returns {function(string, string): void}
*/
const streamToHttp = (req, res) => {
const accountId = req.accountId || req.remoteAddress;
const channelName = channelNameFromPath(req);
connectedClients.labels({ type: 'eventsource' }).inc();
@ -963,7 +998,8 @@ const startServer = async () => {
const heartbeat = setInterval(() => res.write(':thump\n'), 15000);
req.on('close', () => {
log.verbose(req.requestId, `Ending stream for ${accountId}`);
req.log.info({ accountId: req.accountId }, `Ending stream`);
// We decrement these counters here instead of in streamHttpEnd as in that
// method we don't have knowledge of the channel names
connectedClients.labels({ type: 'eventsource' }).dec();
@ -1007,15 +1043,15 @@ const startServer = async () => {
*/
const streamToWs = (req, ws, streamName) => (event, payload) => {
if (ws.readyState !== ws.OPEN) {
log.error(req.requestId, 'Tried writing to closed socket');
req.log.error('Tried writing to closed socket');
return;
}
const message = JSON.stringify({ stream: streamName, event, payload });
ws.send(message, (/** @type {Error} */ err) => {
ws.send(message, (/** @type {Error|undefined} */ err) => {
if (err) {
log.error(req.requestId, `Failed to send to websocket: ${err}`);
req.log.error({err}, `Failed to send to websocket`);
}
});
};
@ -1032,20 +1068,19 @@ const startServer = async () => {
app.use(api);
api.use(setRequestId);
api.use(setRemoteAddress);
api.use(authenticationMiddleware);
api.use(errorMiddleware);
api.get('/api/v1/streaming/*', (req, res) => {
// @ts-ignore
channelNameToIds(req, channelNameFromPath(req), req.query).then(({ channelIds, options }) => {
const onSend = streamToHttp(req, res);
const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds));
streamFrom(channelIds, req, onSend, onEnd, 'eventsource', options.needsFiltering);
// @ts-ignore
streamFrom(channelIds, req, req.log, onSend, onEnd, 'eventsource', options.needsFiltering);
}).catch(err => {
log.verbose(req.requestId, 'Subscription error:', err.toString());
res.log.info({ err }, 'Subscription error:', err.toString());
httpNotFound(res);
});
});
@ -1197,6 +1232,7 @@ const startServer = async () => {
break;
case 'list':
// @ts-ignore
authorizeListAccess(params.list, req).then(() => {
resolve({
channelIds: [`timeline:list:${params.list}`],
@ -1218,9 +1254,9 @@ const startServer = async () => {
* @returns {string[]}
*/
const streamNameFromChannelName = (channelName, params) => {
if (channelName === 'list') {
if (channelName === 'list' && params.list) {
return [channelName, params.list];
} else if (['hashtag', 'hashtag:local'].includes(channelName)) {
} else if (['hashtag', 'hashtag:local'].includes(channelName) && params.tag) {
return [channelName, params.tag];
} else {
return [channelName];
@ -1229,8 +1265,9 @@ const startServer = async () => {
/**
* @typedef WebSocketSession
* @property {WebSocket} websocket
* @property {http.IncomingMessage} request
* @property {WebSocket & { isAlive: boolean}} websocket
* @property {http.IncomingMessage & ResolvedAccount} request
* @property {import('pino').Logger} logger
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
*/
@ -1240,8 +1277,8 @@ const startServer = async () => {
* @param {StreamParams} params
* @returns {void}
*/
const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) => {
checkScopes(request, channelName).then(() => channelNameToIds(request, channelName, params)).then(({
const subscribeWebsocketToChannel = ({ websocket, request, logger, subscriptions }, channelName, params) => {
checkScopes(request, logger, channelName).then(() => channelNameToIds(request, channelName, params)).then(({
channelIds,
options,
}) => {
@ -1249,9 +1286,9 @@ const startServer = async () => {
return;
}
const onSend = streamToWs(request, socket, streamNameFromChannelName(channelName, params));
const onSend = streamToWs(request, websocket, streamNameFromChannelName(channelName, params));
const stopHeartbeat = subscriptionHeartbeat(channelIds);
const listener = streamFrom(channelIds, request, onSend, undefined, 'websocket', options.needsFiltering);
const listener = streamFrom(channelIds, request, logger, onSend, undefined, 'websocket', options.needsFiltering);
connectedChannels.labels({ type: 'websocket', channel: channelName }).inc();
@ -1261,14 +1298,17 @@ const startServer = async () => {
stopHeartbeat,
};
}).catch(err => {
log.verbose(request.requestId, 'Subscription error:', err.toString());
socket.send(JSON.stringify({ error: err.toString() }));
logger.error({ err }, 'Subscription error');
websocket.send(JSON.stringify({ error: err.toString() }));
});
};
const removeSubscription = (subscriptions, channelIds, request) => {
log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`);
/**
* @param {WebSocketSession} session
* @param {string[]} channelIds
*/
const removeSubscription = ({ request, logger, subscriptions }, channelIds) => {
logger.info({ channelIds, accountId: request.accountId }, `Ending stream`);
const subscription = subscriptions[channelIds.join(';')];
@ -1292,16 +1332,17 @@ const startServer = async () => {
* @param {StreamParams} params
* @returns {void}
*/
const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) => {
const unsubscribeWebsocketFromChannel = (session, channelName, params) => {
const { websocket, request, logger } = session;
channelNameToIds(request, channelName, params).then(({ channelIds }) => {
removeSubscription(subscriptions, channelIds, request);
removeSubscription(session, channelIds);
}).catch(err => {
log.verbose(request.requestId, 'Unsubscribe error:', err);
logger.error({err}, 'Unsubscribe error');
// If we have a socket that is alive and open still, send the error back to the client:
// FIXME: In other parts of the code ws === socket
if (socket.isAlive && socket.readyState === socket.OPEN) {
socket.send(JSON.stringify({ error: "Error unsubscribing from channel" }));
if (websocket.isAlive && websocket.readyState === websocket.OPEN) {
websocket.send(JSON.stringify({ error: "Error unsubscribing from channel" }));
}
});
};
@ -1309,16 +1350,14 @@ const startServer = async () => {
/**
* @param {WebSocketSession} session
*/
const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
const subscribeWebsocketToSystemChannel = ({ websocket, request, subscriptions }) => {
const accessTokenChannelId = `timeline:access_token:${request.accessTokenId}`;
const systemChannelId = `timeline:system:${request.accountId}`;
const listener = createSystemMessageListener(request, {
onKill() {
socket.close();
websocket.close();
},
});
subscribe(`${redisPrefix}${accessTokenChannelId}`, listener);
@ -1355,18 +1394,15 @@ const startServer = async () => {
/**
* @param {WebSocket & { isAlive: boolean }} ws
* @param {http.IncomingMessage} req
* @param {http.IncomingMessage & ResolvedAccount} req
* @param {import('pino').Logger} log
*/
function onConnection(ws, req) {
function onConnection(ws, req, log) {
// Note: url.parse could throw, which would terminate the connection, so we
// increment the connected clients metric straight away when we establish
// the connection, without waiting:
connectedClients.labels({ type: 'websocket' }).inc();
// Setup request properties:
req.requestId = uuid.v4();
req.remoteAddress = ws._socket.remoteAddress;
// Setup connection keep-alive state:
ws.isAlive = true;
ws.on('pong', () => {
@ -1377,8 +1413,9 @@ const startServer = async () => {
* @type {WebSocketSession}
*/
const session = {
socket: ws,
websocket: ws,
request: req,
logger: log,
subscriptions: {},
};
@ -1386,27 +1423,30 @@ const startServer = async () => {
const subscriptions = Object.keys(session.subscriptions);
subscriptions.forEach(channelIds => {
removeSubscription(session.subscriptions, channelIds.split(';'), req);
removeSubscription(session, channelIds.split(';'));
});
// Decrement the metrics for connected clients:
connectedClients.labels({ type: 'websocket' }).dec();
// ensure garbage collection:
session.socket = null;
session.request = null;
session.subscriptions = {};
// We need to delete the session object as to ensure it correctly gets
// garbage collected, without doing this we could accidentally hold on to
// references to the websocket, the request, and the logger, causing
// memory leaks.
//
// @ts-ignore
delete session;
});
// Note: immediately after the `error` event is emitted, the `close` event
// is emitted. As such, all we need to do is log the error here.
ws.on('error', (err) => {
log.error('websocket', err.toString());
ws.on('error', (/** @type {Error} */ err) => {
log.error(err);
});
ws.on('message', (data, isBinary) => {
if (isBinary) {
log.warn('websocket', 'Received binary data, closing connection');
log.warn('Received binary data, closing connection');
ws.close(1003, 'The mastodon streaming server does not support binary messages');
return;
}
@ -1441,18 +1481,20 @@ const startServer = async () => {
setInterval(() => {
wss.clients.forEach(ws => {
// @ts-ignore
if (ws.isAlive === false) {
ws.terminate();
return;
}
// @ts-ignore
ws.isAlive = false;
ws.ping('', false);
});
}, 30000);
attachServerWithConfig(server, address => {
log.warn(`Streaming API now listening on ${address}`);
logger.info(`Streaming API now listening on ${address}`);
});
const onExit = () => {
@ -1460,8 +1502,10 @@ const startServer = async () => {
process.exit(0);
};
/** @param {Error} err */
const onError = (err) => {
log.error(err);
logger.error(err);
server.close();
process.exit(0);
};
@ -1485,7 +1529,7 @@ const attachServerWithConfig = (server, onSuccess) => {
}
});
} else {
server.listen(+process.env.PORT || 4000, process.env.BIND || '127.0.0.1', () => {
server.listen(+(process.env.PORT || 4000), process.env.BIND || '127.0.0.1', () => {
if (onSuccess) {
onSuccess(`${server.address().address}:${server.address().port}`);
}