From ef057584fd2714d94666f9ffef4aa89147eda72c Mon Sep 17 00:00:00 2001 From: Eugen Rochko Date: Tue, 11 Aug 2020 18:24:59 +0200 Subject: [PATCH] Add support for managing multiple stream subscriptions in a single connection (#14524) --- app/javascript/mastodon/actions/streaming.js | 99 ++- app/javascript/mastodon/stream.js | 288 +++++-- streaming/index.js | 754 +++++++++++++------ 3 files changed, 817 insertions(+), 324 deletions(-) diff --git a/app/javascript/mastodon/actions/streaming.js b/app/javascript/mastodon/actions/streaming.js index d998fcac48..beb5c6a4a9 100644 --- a/app/javascript/mastodon/actions/streaming.js +++ b/app/javascript/mastodon/actions/streaming.js @@ -1,3 +1,5 @@ +// @ts-check + import { connectStream } from '../stream'; import { updateTimeline, @@ -19,24 +21,59 @@ import { getLocale } from '../locales'; const { messages } = getLocale(); -export function connectTimelineStream (timelineId, path, pollingRefresh = null, accept = null) { +/** + * @param {number} max + * @return {number} + */ +const randomUpTo = max => + Math.floor(Math.random() * Math.floor(max)); - return connectStream (path, pollingRefresh, (dispatch, getState) => { +/** + * @param {string} timelineId + * @param {string} channelName + * @param {Object.} params + * @param {Object} options + * @param {function(Function, Function): void} [options.fallback] + * @param {function(object): boolean} [options.accept] + * @return {function(): void} + */ +export const connectTimelineStream = (timelineId, channelName, params = {}, options = {}) => + connectStream(channelName, params, (dispatch, getState) => { const locale = getState().getIn(['meta', 'locale']); + let pollingId; + + /** + * @param {function(Function, Function): void} fallback + */ + const useFallback = fallback => { + fallback(dispatch, () => { + pollingId = setTimeout(() => useFallback(fallback), 20000 + randomUpTo(20000)); + }); + }; + return { onConnect() { dispatch(connectTimeline(timelineId)); + + if (pollingId) { + clearTimeout(pollingId); + pollingId = null; + } }, onDisconnect() { dispatch(disconnectTimeline(timelineId)); + + if (options.fallback) { + pollingId = setTimeout(() => useFallback(options.fallback), randomUpTo(40000)); + } }, onReceive (data) { switch(data.event) { case 'update': - dispatch(updateTimeline(timelineId, JSON.parse(data.payload), accept)); + dispatch(updateTimeline(timelineId, JSON.parse(data.payload), options.accept)); break; case 'delete': dispatch(deleteFromTimelines(data.payload)); @@ -63,17 +100,59 @@ export function connectTimelineStream (timelineId, path, pollingRefresh = null, }, }; }); -} +/** + * @param {Function} dispatch + * @param {function(): void} done + */ const refreshHomeTimelineAndNotification = (dispatch, done) => { dispatch(expandHomeTimeline({}, () => dispatch(expandNotifications({}, () => dispatch(fetchAnnouncements(done)))))); }; -export const connectUserStream = () => connectTimelineStream('home', 'user', refreshHomeTimelineAndNotification); -export const connectCommunityStream = ({ onlyMedia } = {}) => connectTimelineStream(`community${onlyMedia ? ':media' : ''}`, `public:local${onlyMedia ? ':media' : ''}`); -export const connectPublicStream = ({ onlyMedia, onlyRemote } = {}) => connectTimelineStream(`public${onlyRemote ? ':remote' : ''}${onlyMedia ? ':media' : ''}`, `public${onlyRemote ? ':remote' : ''}${onlyMedia ? ':media' : ''}`); -export const connectHashtagStream = (id, tag, local, accept) => connectTimelineStream(`hashtag:${id}${local ? ':local' : ''}`, `hashtag${local ? ':local' : ''}&tag=${tag}`, null, accept); -export const connectDirectStream = () => connectTimelineStream('direct', 'direct'); -export const connectListStream = id => connectTimelineStream(`list:${id}`, `list&list=${id}`); +/** + * @return {function(): void} + */ +export const connectUserStream = () => + connectTimelineStream('home', 'user', {}, { fallback: refreshHomeTimelineAndNotification }); + +/** + * @param {Object} options + * @param {boolean} [options.onlyMedia] + * @return {function(): void} + */ +export const connectCommunityStream = ({ onlyMedia } = {}) => + connectTimelineStream(`community${onlyMedia ? ':media' : ''}`, `public:local${onlyMedia ? ':media' : ''}`); + +/** + * @param {Object} options + * @param {boolean} [options.onlyMedia] + * @param {boolean} [options.onlyRemote] + * @return {function(): void} + */ +export const connectPublicStream = ({ onlyMedia, onlyRemote } = {}) => + connectTimelineStream(`public${onlyRemote ? ':remote' : ''}${onlyMedia ? ':media' : ''}`, `public${onlyRemote ? ':remote' : ''}${onlyMedia ? ':media' : ''}`); + +/** + * @param {string} columnId + * @param {string} tagName + * @param {boolean} onlyLocal + * @param {function(object): boolean} accept + * @return {function(): void} + */ +export const connectHashtagStream = (columnId, tagName, onlyLocal, accept) => + connectTimelineStream(`hashtag:${columnId}${onlyLocal ? ':local' : ''}`, `hashtag${onlyLocal ? ':local' : ''}`, { tag: tagName }, { accept }); + +/** + * @return {function(): void} + */ +export const connectDirectStream = () => + connectTimelineStream('direct', 'direct'); + +/** + * @param {string} listId + * @return {function(): void} + */ +export const connectListStream = listId => + connectTimelineStream(`list:${listId}`, 'list', { list: listId }); diff --git a/app/javascript/mastodon/stream.js b/app/javascript/mastodon/stream.js index 0cb2b228f3..640455b33d 100644 --- a/app/javascript/mastodon/stream.js +++ b/app/javascript/mastodon/stream.js @@ -1,87 +1,236 @@ +// @ts-check + import WebSocketClient from '@gamestdio/websocket'; -const randomIntUpTo = max => Math.floor(Math.random() * Math.floor(max)); +/** + * @type {WebSocketClient | undefined} + */ +let sharedConnection; -const knownEventTypes = [ - 'update', - 'delete', - 'notification', - 'conversation', - 'filters_changed', -]; +/** + * @typedef Subscription + * @property {string} channelName + * @property {Object.} params + * @property {function(): void} onConnect + * @property {function(StreamEvent): void} onReceive + * @property {function(): void} onDisconnect + */ -export function connectStream(path, pollingRefresh = null, callbacks = () => ({ onConnect() {}, onDisconnect() {}, onReceive() {} })) { - return (dispatch, getState) => { - const streamingAPIBaseURL = getState().getIn(['meta', 'streaming_api_base_url']); - const accessToken = getState().getIn(['meta', 'access_token']); - const { onConnect, onDisconnect, onReceive } = callbacks(dispatch, getState); + /** + * @typedef StreamEvent + * @property {string} event + * @property {object} payload + */ - let polling = null; +/** + * @type {Array.} + */ +const subscriptions = []; - const setupPolling = () => { - pollingRefresh(dispatch, () => { - polling = setTimeout(() => setupPolling(), 20000 + randomIntUpTo(20000)); - }); - }; +/** + * @type {Object.} + */ +const subscriptionCounters = {}; - const clearPolling = () => { - if (polling) { - clearTimeout(polling); - polling = null; +/** + * @param {Subscription} subscription + */ +const addSubscription = subscription => { + subscriptions.push(subscription); +}; + +/** + * @param {Subscription} subscription + */ +const removeSubscription = subscription => { + const index = subscriptions.indexOf(subscription); + + if (index !== -1) { + subscriptions.splice(index, 1); + } +}; + +/** + * @param {Subscription} subscription + */ +const subscribe = ({ channelName, params, onConnect }) => { + const key = channelNameWithInlineParams(channelName, params); + + subscriptionCounters[key] = subscriptionCounters[key] || 0; + + if (subscriptionCounters[key] === 0) { + sharedConnection.send(JSON.stringify({ type: 'subscribe', stream: channelName, ...params })); + } + + subscriptionCounters[key] += 1; + onConnect(); +}; + +/** + * @param {Subscription} subscription + */ +const unsubscribe = ({ channelName, params, onDisconnect }) => { + const key = channelNameWithInlineParams(channelName, params); + + subscriptionCounters[key] = subscriptionCounters[key] || 1; + + if (subscriptionCounters[key] === 1 && sharedConnection.readyState === WebSocketClient.OPEN) { + sharedConnection.send(JSON.stringify({ type: 'unsubscribe', stream: channelName, ...params })); + } + + subscriptionCounters[key] -= 1; + onDisconnect(); +}; + +const sharedCallbacks = { + connected () { + subscriptions.forEach(subscription => subscribe(subscription)); + }, + + received (data) { + const { stream } = data; + + subscriptions.filter(({ channelName, params }) => { + const streamChannelName = stream[0]; + + if (stream.length === 1) { + return channelName === streamChannelName; } - }; - const subscription = getStream(streamingAPIBaseURL, accessToken, path, { + const streamIdentifier = stream[1]; + + if (['hashtag', 'hashtag:local'].includes(channelName)) { + return channelName === streamChannelName && params.tag === streamIdentifier; + } else if (channelName === 'list') { + return channelName === streamChannelName && params.list === streamIdentifier; + } + + return false; + }).forEach(subscription => { + subscription.onReceive(data); + }); + }, + + disconnected () { + subscriptions.forEach(({ onDisconnect }) => onDisconnect()); + }, + + reconnected () { + subscriptions.forEach(subscription => subscribe(subscription)); + }, +}; + +/** + * @param {string} channelName + * @param {Object.} params + * @return {string} + */ +const channelNameWithInlineParams = (channelName, params) => { + if (Object.keys(params).length === 0) { + return channelName; + } + + return `${channelName}&${Object.keys(params).map(key => `${key}=${params[key]}`).join('&')}`; +}; + +/** + * @param {string} channelName + * @param {Object.} params + * @param {function(Function, Function): { onConnect: (function(): void), onReceive: (function(StreamEvent): void), onDisconnect: (function(): void) }} callbacks + * @return {function(): void} + */ +export const connectStream = (channelName, params, callbacks) => (dispatch, getState) => { + const streamingAPIBaseURL = getState().getIn(['meta', 'streaming_api_base_url']); + const accessToken = getState().getIn(['meta', 'access_token']); + const { onConnect, onReceive, onDisconnect } = callbacks(dispatch, getState); + + // If we cannot use a websockets connection, we must fall back + // to using individual connections for each channel + if (!streamingAPIBaseURL.startsWith('ws')) { + const connection = createConnection(streamingAPIBaseURL, accessToken, channelNameWithInlineParams(channelName, params), { connected () { - if (pollingRefresh) { - clearPolling(); - } - onConnect(); }, - disconnected () { - if (pollingRefresh) { - polling = setTimeout(() => setupPolling(), randomIntUpTo(40000)); - } - - onDisconnect(); - }, - received (data) { onReceive(data); }, - reconnected () { - if (pollingRefresh) { - clearPolling(); - pollingRefresh(dispatch); - } - - onConnect(); + disconnected () { + onDisconnect(); }, + reconnected () { + onConnect(); + }, }); - const disconnect = () => { - if (subscription) { - subscription.close(); - } - - clearPolling(); + return () => { + connection.close(); }; + } - return disconnect; + const subscription = { + channelName, + params, + onConnect, + onReceive, + onDisconnect, }; -} + addSubscription(subscription); -export default function getStream(streamingAPIBaseURL, accessToken, stream, { connected, received, disconnected, reconnected }) { - const params = stream.split('&'); - stream = params.shift(); + // If a connection is open, we can execute the subscription right now. Otherwise, + // because we have already registered it, it will be executed on connect + + if (!sharedConnection) { + sharedConnection = /** @type {WebSocketClient} */ (createConnection(streamingAPIBaseURL, accessToken, '', sharedCallbacks)); + } else if (sharedConnection.readyState === WebSocketClient.OPEN) { + subscribe(subscription); + } + + return () => { + removeSubscription(subscription); + unsubscribe(subscription); + }; +}; + +const KNOWN_EVENT_TYPES = [ + 'update', + 'delete', + 'notification', + 'conversation', + 'filters_changed', + 'encrypted_message', + 'announcement', + 'announcement.delete', + 'announcement.reaction', +]; + +/** + * @param {MessageEvent} e + * @param {function(StreamEvent): void} received + */ +const handleEventSourceMessage = (e, received) => { + received({ + event: e.type, + payload: e.data, + }); +}; + +/** + * @param {string} streamingAPIBaseURL + * @param {string} accessToken + * @param {string} channelName + * @param {{ connected: Function, received: function(StreamEvent): void, disconnected: Function, reconnected: Function }} callbacks + * @return {WebSocketClient | EventSource} + */ +const createConnection = (streamingAPIBaseURL, accessToken, channelName, { connected, received, disconnected, reconnected }) => { + const params = channelName.split('&'); + + channelName = params.shift(); if (streamingAPIBaseURL.startsWith('ws')) { - params.unshift(`stream=${stream}`); const ws = new WebSocketClient(`${streamingAPIBaseURL}/api/v1/streaming/?${params.join('&')}`, accessToken); ws.onopen = connected; @@ -92,11 +241,19 @@ export default function getStream(streamingAPIBaseURL, accessToken, stream, { co return ws; } - stream = stream.replace(/:/g, '/'); + channelName = channelName.replace(/:/g, '/'); + + if (channelName.endsWith(':media')) { + channelName = channelName.replace('/media', ''); + params.push('only_media=true'); + } + params.push(`access_token=${accessToken}`); - const es = new EventSource(`${streamingAPIBaseURL}/api/v1/streaming/${stream}?${params.join('&')}`); + + const es = new EventSource(`${streamingAPIBaseURL}/api/v1/streaming/${channelName}?${params.join('&')}`); let firstConnect = true; + es.onopen = () => { if (firstConnect) { firstConnect = false; @@ -105,15 +262,12 @@ export default function getStream(streamingAPIBaseURL, accessToken, stream, { co reconnected(); } }; - for (let type of knownEventTypes) { - es.addEventListener(type, (e) => { - received({ - event: e.type, - payload: e.data, - }); - }); - } - es.onerror = disconnected; + + KNOWN_EVENT_TYPES.forEach(type => { + es.addEventListener(type, e => handleEventSourceMessage(/** @type {MessageEvent} */ (e), received)); + }); + + es.onerror = /** @type {function(): void} */ (disconnected); return es; }; diff --git a/streaming/index.js b/streaming/index.js index 39e70c1ba7..7c0c6a465e 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -1,3 +1,5 @@ +// @ts-check + const os = require('os'); const throng = require('throng'); const dotenv = require('dotenv'); @@ -12,7 +14,7 @@ const uuid = require('uuid'); const fs = require('fs'); const env = process.env.NODE_ENV || 'development'; -const alwaysRequireAuth = process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true'; +const alwaysRequireAuth = process.env.LIMITED_FEDERATION_MODE === 'true' || process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true'; dotenv.config({ path: env === 'production' ? '.env.production' : '.env', @@ -20,6 +22,10 @@ dotenv.config({ log.level = process.env.LOG_LEVEL || 'verbose'; +/** + * @param {string} dbUrl + * @return {Object.} + */ const dbUrlToConfig = (dbUrl) => { if (!dbUrl) { return {}; @@ -53,6 +59,10 @@ const dbUrlToConfig = (dbUrl) => { return config; }; +/** + * @param {Object.} defaultConfig + * @param {string} redisUrl + */ const redisUrlToClient = (defaultConfig, redisUrl) => { const config = defaultConfig; @@ -108,6 +118,7 @@ const startWorker = (workerId) => { } const app = express(); + app.set('trusted proxy', process.env.TRUSTED_PROXY_IP || 'loopback,uniquelocal'); const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); @@ -130,6 +141,9 @@ const startWorker = (workerId) => { const redisSubscribeClient = redisUrlToClient(redisParams, process.env.REDIS_URL); const redisClient = redisUrlToClient(redisParams, process.env.REDIS_URL); + /** + * @type {Object.>} + */ const subs = {}; redisSubscribeClient.on('message', (channel, message) => { @@ -144,11 +158,11 @@ const startWorker = (workerId) => { callbacks.forEach(callback => callback(message)); }); + /** + * @param {string[]} channels + * @return {function(): void} + */ const subscriptionHeartbeat = channels => { - if (!Array.isArray(channels)) { - channels = [channels]; - } - const interval = 6 * 60; const tellSubscribed = () => { @@ -164,25 +178,65 @@ const startWorker = (workerId) => { }; }; + /** + * @param {string} channel + * @param {function(string): void} callback + */ const subscribe = (channel, callback) => { log.silly(`Adding listener for ${channel}`); subs[channel] = subs[channel] || []; + if (subs[channel].length === 0) { log.verbose(`Subscribe ${channel}`); redisSubscribeClient.subscribe(channel); } + subs[channel].push(callback); }; + /** + * @param {string} channel + * @param {function(string): void} callback + */ const unsubscribe = (channel, callback) => { log.silly(`Removing listener for ${channel}`); + + if (!subs[channel]) { + return; + } + subs[channel] = subs[channel].filter(item => item !== callback); + if (subs[channel].length === 0) { log.verbose(`Unsubscribe ${channel}`); redisSubscribeClient.unsubscribe(channel); } }; + const FALSE_VALUES = [ + false, + 0, + "0", + "f", + "F", + "false", + "FALSE", + "off", + "OFF" + ]; + + /** + * @param {any} value + * @return {boolean} + */ + const isTruthy = value => + value && !FALSE_VALUES.includes(value); + + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const allowCrossDomain = (req, res, next) => { res.header('Access-Control-Allow-Origin', '*'); res.header('Access-Control-Allow-Headers', 'Authorization, Accept, Cache-Control'); @@ -191,6 +245,11 @@ const startWorker = (workerId) => { next(); }; + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const setRequestId = (req, res, next) => { req.requestId = uuid.v4(); res.header('X-Request-Id', req.requestId); @@ -198,16 +257,26 @@ const startWorker = (workerId) => { next(); }; + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} + */ const setRemoteAddress = (req, res, next) => { req.remoteAddress = req.connection.remoteAddress; next(); }; - const accountFromToken = (token, allowedScopes, req, next) => { + /** + * @param {string} token + * @param {any} req + * @return {Promise.} + */ + const accountFromToken = (token, req) => new Promise((resolve, reject) => { pgPool.connect((err, client, done) => { if (err) { - next(err); + reject(err); return; } @@ -215,62 +284,88 @@ const startWorker = (workerId) => { done(); if (err) { - next(err); + reject(err); return; } if (result.rows.length === 0) { err = new Error('Invalid access token'); - err.statusCode = 401; + err.status = 401; - next(err); - return; - } - - const scopes = result.rows[0].scopes.split(' '); - - if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) { - err = new Error('Access token does not cover required scopes'); - err.statusCode = 401; - - next(err); + reject(err); return; } + req.scopes = result.rows[0].scopes.split(' '); req.accountId = result.rows[0].account_id; req.chosenLanguages = result.rows[0].chosen_languages; - req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope)); + req.allowNotifications = req.scopes.some(scope => ['read', 'read:notifications'].includes(scope)); req.deviceId = result.rows[0].device_id; - next(); + resolve(); }); }); - }; + }); - const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => { + /** + * @param {any} req + * @param {boolean=} required + * @return {Promise.} + */ + const accountFromRequest = (req, required = true) => new Promise((resolve, reject) => { const authorization = req.headers.authorization; - const location = url.parse(req.url, true); - const accessToken = location.query.access_token || req.headers['sec-websocket-protocol']; + const location = url.parse(req.url, true); + const accessToken = location.query.access_token || req.headers['sec-websocket-protocol']; if (!authorization && !accessToken) { if (required) { const err = new Error('Missing access token'); - err.statusCode = 401; + err.status = 401; - next(err); + reject(err); return; } else { - next(); + resolve(); return; } } const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken; - accountFromToken(token, allowedScopes, req, next); + resolve(accountFromToken(token, req)); + }); + + /** + * @param {any} req + * @return {string} + */ + const channelNameFromPath = req => { + const { path, query } = req; + const onlyMedia = isTruthy(query.only_media); + + switch(path) { + case '/api/v1/streaming/user': + return 'user'; + case '/api/v1/streaming/user/notification': + return 'user:notification'; + case '/api/v1/streaming/public': + return onlyMedia ? 'public:media' : 'public'; + case '/api/v1/streaming/public/local': + return onlyMedia ? 'public:local:media' : 'public:local'; + case '/api/v1/streaming/public/remote': + return onlyMedia ? 'public:remote:media' : 'public:remote'; + case '/api/v1/streaming/hashtag': + return 'hashtag'; + case '/api/v1/streaming/hashtag/local': + return 'hashtag:local'; + case '/api/v1/streaming/direct': + return 'direct'; + case '/api/v1/streaming/list': + return 'list'; + } }; - const PUBLIC_STREAMS = [ + const PUBLIC_CHANNELS = [ 'public', 'public:media', 'public:local', @@ -281,95 +376,148 @@ const startWorker = (workerId) => { 'hashtag:local', ]; - const wsVerifyClient = (info, cb) => { - const location = url.parse(info.req.url, true); - const authRequired = alwaysRequireAuth || !PUBLIC_STREAMS.some(stream => stream === location.query.stream); - const allowedScopes = []; + /** + * @param {any} req + * @param {string} channelName + * @return {Promise.} + */ + const checkScopes = (req, channelName) => new Promise((resolve, reject) => { + log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`); - if (authRequired) { - allowedScopes.push('read'); - if (location.query.stream === 'user:notification') { - allowedScopes.push('read:notifications'); - } else { - allowedScopes.push('read:statuses'); - } + // When accessing public channels, no scopes are needed + if (PUBLIC_CHANNELS.includes(channelName)) { + resolve(); + return; } - accountFromRequest(info.req, err => { - if (!err) { - cb(true, undefined, undefined); - } else { - log.error(info.req.requestId, err.toString()); - cb(false, 401, 'Unauthorized'); - } - }, authRequired, allowedScopes); + // The `read` scope has the highest priority, if the token has it + // then it can access all streams + const requiredScopes = ['read']; + + // When accessing specifically the notifications stream, + // we need a read:notifications, while in all other cases, + // we can allow access with read:statuses. Mind that the + // user stream will not contain notifications unless + // the token has either read or read:notifications scope + // as well, this is handled separately. + if (channelName === 'user:notification') { + requiredScopes.push('read:notifications'); + } else { + requiredScopes.push('read:statuses'); + } + + if (requiredScopes.some(requiredScope => req.scopes.includes(requiredScope))) { + resolve(); + return; + } + + const err = new Error('Access token does not cover required scopes'); + err.status = 401; + + reject(err); + }); + + /** + * @param {any} info + * @param {function(boolean, number, string): void} callback + */ + const wsVerifyClient = (info, callback) => { + // When verifying the websockets connection, we no longer pre-emptively + // check OAuth scopes and drop the connection if they're missing. We only + // drop the connection if access without token is not allowed by environment + // variables. OAuth scope checks are moved to the point of subscription + // to a specific stream. + + accountFromRequest(info.req, alwaysRequireAuth).then(() => { + callback(true, undefined, undefined); + }).catch(err => { + log.error(info.req.requestId, err.toString()); + callback(false, 401, 'Unauthorized'); + }); }; - const PUBLIC_ENDPOINTS = [ - '/api/v1/streaming/public', - '/api/v1/streaming/public/local', - '/api/v1/streaming/public/remote', - '/api/v1/streaming/hashtag', - '/api/v1/streaming/hashtag/local', - ]; - + /** + * @param {any} req + * @param {any} res + * @param {function(Error=): void} next + */ const authenticationMiddleware = (req, res, next) => { if (req.method === 'OPTIONS') { next(); return; } - const authRequired = alwaysRequireAuth || !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path); - const allowedScopes = []; - - if (authRequired) { - allowedScopes.push('read'); - if (req.path === '/api/v1/streaming/user/notification') { - allowedScopes.push('read:notifications'); - } else { - allowedScopes.push('read:statuses'); - } - } - - accountFromRequest(req, next, authRequired, allowedScopes); - }; - - const errorMiddleware = (err, req, res, {}) => { - log.error(req.requestId, err.toString()); - res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ error: err.statusCode ? err.toString() : 'An unexpected error occurred' })); - }; - - const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', '); - - const authorizeListAccess = (id, req, next) => { - pgPool.connect((err, client, done) => { - if (err) { - next(false); - return; - } - - client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [id], (err, result) => { - done(); - - if (err || result.rows.length === 0 || result.rows[0].account_id !== req.accountId) { - next(false); - return; - } - - next(true); - }); + accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => { + next(); + }).catch(err => { + next(err); }); }; + /** + * @param {Error} err + * @param {any} req + * @param {any} res + * @param {function(Error=): void} next + */ + const errorMiddleware = (err, req, res, next) => { + log.error(req.requestId, err.toString()); + + if (res.headersSent) { + return next(err); + } + + res.writeHead(err.status || 500, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: err.status ? err.toString() : 'An unexpected error occurred' })); + }; + + /** + * @param {array} + * @param {number=} shift + * @return {string} + */ + const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', '); + + /** + * @param {string} listId + * @param {any} req + * @return {Promise.} + */ + const authorizeListAccess = (listId, req) => new Promise((resolve, reject) => { + const { accountId } = req; + + pgPool.connect((err, client, done) => { + if (err) { + reject(); + return; + } + + client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (err, result) => { + done(); + + if (err || result.rows.length === 0 || result.rows[0].account_id !== accountId) { + reject(); + return; + } + + resolve(); + }); + }); + }); + + /** + * @param {string[]} ids + * @param {any} req + * @param {function(string, string): void} output + * @param {function(string[], function(string): void): void} attachCloseHandler + * @param {boolean=} needsFiltering + * @param {boolean=} notificationOnly + * @return {function(string): void} + */ const streamFrom = (ids, req, output, attachCloseHandler, needsFiltering = false, notificationOnly = false) => { const accountId = req.accountId || req.remoteAddress; const streamType = notificationOnly ? ' (notification)' : ''; - if (!Array.isArray(ids)) { - ids = [ids]; - } - log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}${streamType}`); const listener = message => { @@ -447,10 +595,18 @@ const startWorker = (workerId) => { subscribe(`${redisPrefix}${id}`, listener); }); - attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener); + if (attachCloseHandler) { + attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener); + } + + return listener; }; - // Setup stream output to HTTP + /** + * @param {any} req + * @param {any} res + * @return {function(string, string): void} + */ const streamToHttp = (req, res) => { const accountId = req.accountId || req.remoteAddress; @@ -473,12 +629,12 @@ const startWorker = (workerId) => { }; }; - // Setup stream end for HTTP - const streamHttpEnd = (req, closeHandler = false) => (ids, listener) => { - if (!Array.isArray(ids)) { - ids = [ids]; - } - + /** + * @param {any} req + * @param {function(): void} [closeHandler] + * @return {function(string[], function(string): void)} + */ + const streamHttpEnd = (req, closeHandler = undefined) => (ids, listener) => { req.on('close', () => { ids.forEach(id => { unsubscribe(id, listener); @@ -490,37 +646,24 @@ const startWorker = (workerId) => { }); }; - // Setup stream output to WebSockets - const streamToWs = (req, ws) => (event, payload) => { + /** + * @param {any} req + * @param {any} ws + * @param {string[]} streamName + * @return {function(string, string): void} + */ + const streamToWs = (req, ws, streamName) => (event, payload) => { if (ws.readyState !== ws.OPEN) { log.error(req.requestId, 'Tried writing to closed socket'); return; } - ws.send(JSON.stringify({ event, payload })); - }; - - // Setup stream end for WebSockets - const streamWsEnd = (req, ws, closeHandler = false) => (id, listener) => { - const accountId = req.accountId || req.remoteAddress; - - ws.on('close', () => { - log.verbose(req.requestId, `Ending stream for ${accountId}`); - unsubscribe(id, listener); - if (closeHandler) { - closeHandler(); - } - }); - - ws.on('error', () => { - log.verbose(req.requestId, `Ending stream for ${accountId}`); - unsubscribe(id, listener); - if (closeHandler) { - closeHandler(); - } - }); + ws.send(JSON.stringify({ stream: streamName, event, payload })); }; + /** + * @param {any} res + */ const httpNotFound = res => { res.writeHead(404, { 'Content-Type': 'application/json' }); res.end(JSON.stringify({ error: 'Not found' })); @@ -538,157 +681,267 @@ const startWorker = (workerId) => { app.use(authenticationMiddleware); app.use(errorMiddleware); - app.get('/api/v1/streaming/user', (req, res) => { - const channels = [`timeline:${req.accountId}`]; + app.get('/api/v1/streaming/*', (req, res) => { + channelNameToIds(req, channelNameFromPath(req), req.query).then(({ channelIds, options }) => { + const onSend = streamToHttp(req, res); + const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds)); - if (req.deviceId) { - channels.push(`timeline:${req.accountId}:${req.deviceId}`); - } - - streamFrom(channels, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channels))); - }); - - app.get('/api/v1/streaming/user/notification', (req, res) => { - streamFrom(`timeline:${req.accountId}`, req, streamToHttp(req, res), streamHttpEnd(req), false, true); - }); - - app.get('/api/v1/streaming/public', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:media' : 'timeline:public'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/public/local', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:local:media' : 'timeline:public:local'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/public/remote', (req, res) => { - const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true'; - const channel = onlyMedia ? 'timeline:public:remote:media' : 'timeline:public:remote'; - - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/direct', (req, res) => { - const channel = `timeline:direct:${req.accountId}`; - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel)), true); - }); - - app.get('/api/v1/streaming/hashtag', (req, res) => { - const { tag } = req.query; - - if (!tag || tag.length === 0) { + streamFrom(channelIds, req, onSend, onEnd, options.needsFiltering, options.notificationOnly); + }).catch(err => { + log.verbose(req.requestId, 'Subscription error:', err.toString()); httpNotFound(res); - return; - } - - streamFrom(`timeline:hashtag:${tag.toLowerCase()}`, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/hashtag/local', (req, res) => { - const { tag } = req.query; - - if (!tag || tag.length === 0) { - httpNotFound(res); - return; - } - - streamFrom(`timeline:hashtag:${tag.toLowerCase()}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); - }); - - app.get('/api/v1/streaming/list', (req, res) => { - const listId = req.query.list; - - authorizeListAccess(listId, req, authorized => { - if (!authorized) { - httpNotFound(res); - return; - } - - const channel = `timeline:list:${listId}`; - streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel))); }); }); const wss = new WebSocketServer({ server, verifyClient: wsVerifyClient }); - wss.on('connection', (ws, req) => { - const location = url.parse(req.url, true); - req.requestId = uuid.v4(); - req.remoteAddress = ws._socket.remoteAddress; + /** + * @typedef StreamParams + * @property {string} [tag] + * @property {string} [list] + * @property {string} [only_media] + */ - let channel; - - switch(location.query.stream) { + /** + * @param {any} req + * @param {string} name + * @param {StreamParams} params + * @return {Promise.<{ channelIds: string[], options: { needsFiltering: boolean, notificationOnly: boolean } }>} + */ + const channelNameToIds = (req, name, params) => new Promise((resolve, reject) => { + switch(name) { case 'user': - channel = [`timeline:${req.accountId}`]; + resolve({ + channelIds: req.deviceId ? [`timeline:${req.accountId}`, `timeline:${req.accountId}:${req.deviceId}`] : [`timeline:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: false }, + }); - if (req.deviceId) { - channel.push(`timeline:${req.accountId}:${req.deviceId}`); - } - - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel))); break; case 'user:notification': - streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws), false, true); + resolve({ + channelIds: [`timeline:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: true }, + }); + break; case 'public': - streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'public:local': - streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:local'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'public:remote': - streamFrom('timeline:public:remote', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:remote'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'public:media': - streamFrom('timeline:public:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:media'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'public:local:media': - streamFrom('timeline:public:local:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:local:media'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'public:remote:media': - streamFrom('timeline:public:remote:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true); + resolve({ + channelIds: ['timeline:public:remote:media'], + options: { needsFiltering: true, notificationOnly: false }, + }); + break; case 'direct': - channel = `timeline:direct:${req.accountId}`; - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel)), true); + resolve({ + channelIds: [`timeline:direct:${req.accountId}`], + options: { needsFiltering: false, notificationOnly: false }, + }); + break; case 'hashtag': - if (!location.query.tag || location.query.tag.length === 0) { - ws.close(); - return; + if (!params.tag || params.tag.length === 0) { + reject('No tag for stream provided'); + } else { + resolve({ + channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}`], + options: { needsFiltering: true, notificationOnly: false }, + }); } - streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); break; case 'hashtag:local': - if (!location.query.tag || location.query.tag.length === 0) { - ws.close(); + if (!params.tag || params.tag.length === 0) { + reject('No tag for stream provided'); + } else { + resolve({ + channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}:local`], + options: { needsFiltering: true, notificationOnly: false }, + }); + } + + break; + case 'list': + authorizeListAccess(params.list, req).then(() => { + resolve({ + channelIds: [`timeline:list:${params.list}`], + options: { needsFiltering: false, notificationOnly: false }, + }); + }).catch(() => { + reject('Not authorized to stream this list'); + }); + + break; + default: + reject('Unknown stream type'); + } + }); + + /** + * @param {string} channelName + * @param {StreamParams} params + * @return {string[]} + */ + const streamNameFromChannelName = (channelName, params) => { + if (channelName === 'list') { + return [channelName, params.list]; + } else if (['hashtag', 'hashtag:local'].includes(channelName)) { + return [channelName, params.tag]; + } else { + return [channelName]; + } + }; + + /** + * @typedef WebSocketSession + * @property {any} socket + * @property {any} request + * @property {Object.} subscriptions + */ + + /** + * @param {WebSocketSession} session + * @param {string} channelName + * @param {StreamParams} params + */ + const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) => + checkScopes(request, channelName).then(() => channelNameToIds(request, channelName, params)).then(({ channelIds, options }) => { + if (subscriptions[channelIds.join(';')]) { return; } - streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); - break; - case 'list': - const listId = location.query.list; + const onSend = streamToWs(request, socket, streamNameFromChannelName(channelName, params)); + const stopHeartbeat = subscriptionHeartbeat(channelIds); + const listener = streamFrom(channelIds, request, onSend, undefined, options.needsFiltering, options.notificationOnly); - authorizeListAccess(listId, req, authorized => { - if (!authorized) { - ws.close(); - return; - } + subscriptions[channelIds.join(';')] = { + listener, + stopHeartbeat, + }; + }).catch(err => { + log.verbose(request.requestId, 'Subscription error:', err.toString()); + socket.send(JSON.stringify({ error: err.toString() })); + }); - channel = `timeline:list:${listId}`; - streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel))); + /** + * @param {WebSocketSession} session + * @param {string} channelName + * @param {StreamParams} params + */ + const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) => + channelNameToIds(request, channelName, params).then(({ channelIds }) => { + log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`); + + const { listener, stopHeartbeat } = subscriptions[channelIds.join(';')]; + + if (!listener) { + return; + } + + channelIds.forEach(channelId => { + unsubscribe(`${redisPrefix}${channelId}`, listener); }); - break; - default: - ws.close(); + + stopHeartbeat(); + + subscriptions[channelIds.join(';')] = undefined; + }).catch(err => { + log.verbose(request.requestId, 'Unsubscription error:', err); + socket.send(JSON.stringify({ error: err.toString() })); + }); + + /** + * @param {string|string[]} arrayOrString + * @return {string} + */ + const firstParam = arrayOrString => { + if (Array.isArray(arrayOrString)) { + return arrayOrString[0]; + } else { + return arrayOrString; + } + }; + + wss.on('connection', (ws, req) => { + const location = url.parse(req.url, true); + + req.requestId = uuid.v4(); + req.remoteAddress = ws._socket.remoteAddress; + + /** + * @type {WebSocketSession} + */ + const session = { + socket: ws, + request: req, + subscriptions: {}, + }; + + const onEnd = () => { + const keys = Object.keys(session.subscriptions); + + keys.forEach(channelIds => { + const { listener, stopHeartbeat } = session.subscriptions[channelIds]; + + channelIds.split(';').forEach(channelId => { + unsubscribe(`${redisPrefix}${channelId}`, listener); + }); + + stopHeartbeat(); + }); + }; + + ws.on('close', onEnd); + ws.on('error', onEnd); + + ws.on('message', data => { + const { type, stream, ...params } = JSON.parse(data); + + if (type === 'subscribe') { + subscribeWebsocketToChannel(session, firstParam(stream), params); + } else if (type === 'unsubscribe') { + unsubscribeWebsocketFromChannel(session, firstParam(stream), params) + } else { + // Unknown action type + } + }); + + if (location.query.stream) { + subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); } }); @@ -716,6 +969,10 @@ const startWorker = (workerId) => { process.on('uncaughtException', onError); }; +/** + * @param {any} server + * @param {function(string): void} [onSuccess] + */ const attachServerWithConfig = (server, onSuccess) => { if (process.env.SOCKET || process.env.PORT && isNaN(+process.env.PORT)) { server.listen(process.env.SOCKET || process.env.PORT, () => { @@ -733,6 +990,9 @@ const attachServerWithConfig = (server, onSuccess) => { } }; +/** + * @param {function(Error=): void} onSuccess + */ const onPortAvailable = onSuccess => { const testServer = http.createServer();