# Source code for websockets.auth

"""
:mod:`websockets.auth` provides HTTP Basic Authentication according to
:rfc:`7235` and :rfc:`7617`.

"""


import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union

from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol


# Names exported by ``from websockets.auth import *``; the remaining
# module-level names (is_credentials, Credentials) are not listed here.
__all__ = [
    "BasicAuthWebSocketServerProtocol",
    "basic_auth_protocol_factory",
]

# Type of one hard-coded (username, password) pair.
Credentials = Tuple[str, str]


def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` is a single ``(username, password)`` pair.

    ``value`` qualifies when it unpacks into exactly two :class:`str` items.

    Strings and dicts are rejected even when they hold exactly two items,
    because unpacking them yields characters or keys rather than a username
    and a password (e.g. ``"ab"`` or ``{"alice": ..., "bob": ...}``).

    One-shot iterators (objects whose ``iter()`` returns themselves, such as
    generators) are rejected without being advanced, so that probing them
    does not consume their contents.

    """
    if isinstance(value, (str, dict)):
        return False
    try:
        # Unpacking an iterator would consume the items it is asked about.
        if iter(value) is value:
            return False
        username, password = value
    except (TypeError, ValueError):
        # Not iterable, or not exactly two items.
        return False
    else:
        return isinstance(username, str) and isinstance(password, str)


class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    """

    # Name of the authenticated user. It stays ``None`` until
    # process_request() accepts the client's credentials, so reading it
    # before then no longer raises AttributeError.
    username: Optional[str] = None

    def __init__(
        self,
        *args: Any,
        realm: str,
        check_credentials: Callable[[str, str], Awaitable[bool]],
        **kwargs: Any,
    ) -> None:
        self.realm = realm
        self.check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    def _unauthorized(self, body: bytes) -> HTTPResponse:
        """
        Build a HTTP 401 response with the given body.

        The ``WWW-Authenticate`` header carries a Basic challenge for this
        protocol's realm.

        """
        return (
            http.HTTPStatus.UNAUTHORIZED,
            [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
            body,
        )

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Check HTTP Basic Auth and return a HTTP 401 response if needed.

        Missing, unparseable, and rejected credentials all produce a 401
        response with a ``WWW-Authenticate`` challenge. Only the body text
        tells them apart.

        If authentication succeeds, the username of the authenticated user is
        stored in the ``username`` attribute. The request is then handed to
        the parent class's ``process_request``.

        """
        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return self._unauthorized(b"Missing credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return self._unauthorized(b"Unsupported credentials\n")

        if not await self.check_credentials(username, password):
            return self._unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)
def basic_auth_protocol_factory(
    realm: str,
    credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
    check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
    create_protocol: Type[
        BasicAuthWebSocketServerProtocol
    ] = BasicAuthWebSocketServerProtocol,
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    ``basic_auth_protocol_factory`` is designed to integrate with
    :func:`~websockets.server.serve` like this::

        websockets.serve(
            ...,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    ``realm`` indicates the scope of protection. It should contain only ASCII
    characters because the encoding of non-ASCII characters is undefined.
    Refer to section 2.2 of :rfc:`7235` for details.

    ``credentials`` defines hard coded authorized credentials. It can be a
    ``(username, password)`` pair or a list of such pairs.

    ``check_credentials`` defines a coroutine that checks whether credentials
    are authorized. This coroutine receives ``username`` and ``password``
    arguments and returns a :class:`bool`.

    One of ``credentials`` or ``check_credentials`` must be provided but not
    both.

    By default, ``basic_auth_protocol_factory`` creates a factory for building
    :class:`BasicAuthWebSocketServerProtocol` instances. You can override this
    with the ``create_protocol`` parameter.

    :param realm: scope of protection
    :param credentials: hard coded credentials
    :param check_credentials: coroutine that verifies credentials
    :param create_protocol: protocol class to instantiate
    :raises TypeError: if the credentials argument has the wrong type

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        candidate: Any = credentials
        # Materialize one-shot iterators (e.g. generators) before probing
        # them: is_credentials() inspects its argument by unpacking it, which
        # could otherwise consume the pairs before they are read below.
        if isinstance(candidate, Iterable) and iter(candidate) is candidate:
            candidate = list(candidate)

        if is_credentials(candidate):
            credentials_list = [tuple(candidate)]
        elif isinstance(candidate, Iterable):
            credentials_list = list(candidate)
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        # A single pair and a list of pairs share one lookup table. As a
        # result, a pair given as a list (["user", "pass"]) matches just like
        # a tuple; the former tuple == list comparison always failed.
        credentials_dict = dict(credentials_list)

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            # Constant-time comparison, so response timing doesn't reveal how
            # much of the password matched. UTF-8 bytes are compared because
            # hmac.compare_digest() rejects non-ASCII str arguments.
            return hmac.compare_digest(
                expected_password.encode(), password.encode()
            )

    return functools.partial(
        create_protocol, realm=realm, check_credentials=check_credentials
    )