[RFC] Return 401 for an authentication error on WebSockets (#3411)
* Return 401 for an authentication error on WebSocket * Use upgradeReq instead of a custom object
This commit is contained in:
		
							parent
							
								
									6fb5ac2410
								
							
						
					
					
						commit
						23e09cc6b7
					
				
					 1 changed files with 48 additions and 39 deletions
				
			
		|  | @ -95,7 +95,6 @@ const startWorker = (workerId) => { | ||||||
|   const app    = express(); |   const app    = express(); | ||||||
|   const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); |   const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); | ||||||
|   const server = http.createServer(app); |   const server = http.createServer(app); | ||||||
|   const wss    = new WebSocket.Server({ server }); |  | ||||||
|   const redisNamespace = process.env.REDIS_NAMESPACE || null; |   const redisNamespace = process.env.REDIS_NAMESPACE || null; | ||||||
| 
 | 
 | ||||||
|   const redisParams = { |   const redisParams = { | ||||||
|  | @ -186,14 +185,10 @@ const startWorker = (workerId) => { | ||||||
|     }); |     }); | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   const authenticationMiddleware = (req, res, next) => { |   const accountFromRequest = (req, next) => { | ||||||
|     if (req.method === 'OPTIONS') { |     const authorization = req.headers.authorization; | ||||||
|       next(); |     const location = url.parse(req.url, true); | ||||||
|       return; |     const accessToken = location.query.access_token; | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     const authorization = req.get('Authorization'); |  | ||||||
|     const accessToken = req.query.access_token; |  | ||||||
| 
 | 
 | ||||||
|     if (!authorization && !accessToken) { |     if (!authorization && !accessToken) { | ||||||
|       const err = new Error('Missing access token'); |       const err = new Error('Missing access token'); | ||||||
|  | @ -208,6 +203,26 @@ const startWorker = (workerId) => { | ||||||
|     accountFromToken(token, req, next); |     accountFromToken(token, req, next); | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|  |   const wsVerifyClient = (info, cb) => { | ||||||
|  |     accountFromRequest(info.req, err => { | ||||||
|  |       if (!err) { | ||||||
|  |         cb(true, undefined, undefined); | ||||||
|  |       } else { | ||||||
|  |         log.error(info.req.requestId, err.toString()); | ||||||
|  |         cb(false, 401, 'Unauthorized'); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |   const authenticationMiddleware = (req, res, next) => { | ||||||
|  |     if (req.method === 'OPTIONS') { | ||||||
|  |       next(); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     accountFromRequest(req, next); | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|   const errorMiddleware = (err, req, res, next) => { |   const errorMiddleware = (err, req, res, next) => { | ||||||
|     log.error(req.requestId, err.toString()); |     log.error(req.requestId, err.toString()); | ||||||
|     res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); |     res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); | ||||||
|  | @ -352,10 +367,12 @@ const startWorker = (workerId) => { | ||||||
|     streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); |     streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); | ||||||
|   }); |   }); | ||||||
| 
 | 
 | ||||||
|  |   const wss    = new WebSocket.Server({ server, verifyClient: wsVerifyClient }); | ||||||
|  | 
 | ||||||
|   wss.on('connection', ws => { |   wss.on('connection', ws => { | ||||||
|     const location = url.parse(ws.upgradeReq.url, true); |     const req      = ws.upgradeReq; | ||||||
|     const token    = location.query.access_token; |     const location = url.parse(req.url, true); | ||||||
|     const req      = { requestId: uuid.v4() }; |     req.requestId  = uuid.v4(); | ||||||
| 
 | 
 | ||||||
|     ws.isAlive = true; |     ws.isAlive = true; | ||||||
| 
 | 
 | ||||||
|  | @ -363,33 +380,25 @@ const startWorker = (workerId) => { | ||||||
|       ws.isAlive = true; |       ws.isAlive = true; | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     accountFromToken(token, req, err => { |     switch(location.query.stream) { | ||||||
|       if (err) { |     case 'user': | ||||||
|         log.error(req.requestId, err); |       streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); | ||||||
|         ws.close(); |       break; | ||||||
|         return; |     case 'public': | ||||||
|       } |       streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||||
| 
 |       break; | ||||||
|       switch(location.query.stream) { |     case 'public:local': | ||||||
|       case 'user': |       streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||||
|         streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); |       break; | ||||||
|         break; |     case 'hashtag': | ||||||
|       case 'public': |       streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||||
|         streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); |       break; | ||||||
|         break; |     case 'hashtag:local': | ||||||
|       case 'public:local': |       streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||||
|         streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); |       break; | ||||||
|         break; |     default: | ||||||
|       case 'hashtag': |       ws.close(); | ||||||
|         streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); |     } | ||||||
|         break; |  | ||||||
|       case 'hashtag:local': |  | ||||||
|         streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); |  | ||||||
|         break; |  | ||||||
|       default: |  | ||||||
|         ws.close(); |  | ||||||
|       } |  | ||||||
|     }); |  | ||||||
|   }); |   }); | ||||||
| 
 | 
 | ||||||
|   const wsInterval = setInterval(() => { |   const wsInterval = setInterval(() => { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue