"""
:mod:`websockets.auth` provides HTTP Basic Authentication according to
:rfc:`7235` and :rfc:`7617`.
"""
import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union

from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol
# Names exported by ``from websockets.auth import *``.
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
# A ``(username, password)`` pair of strings.
Credentials = Tuple[str, str]
def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` is a ``(username, password)`` pair.

    A pair is any object that unpacks into exactly two :class:`str` items,
    with the exception of :class:`str` itself.

    :param value: object to inspect
    :return: ``True`` if ``value`` is a pair of strings, ``False`` otherwise

    """
    # A str is iterable too: a two-character string such as "ab" would
    # unpack into "a" and "b" and be mistaken for a pair of credentials.
    if isinstance(value, str):
        return False
    try:
        username, password = value
    except (TypeError, ValueError):
        # TypeError: not iterable; ValueError: not exactly two items.
        return False
    return isinstance(username, str) and isinstance(password, str)
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    :param realm: scope of protection, sent back to the client in the
        ``WWW-Authenticate`` header of every HTTP 401 response
    :param check_credentials: coroutine called with the ``username`` and
        ``password`` found in the request; a false result rejects the
        request

    All other positional and keyword arguments are forwarded unchanged to
    the parent class.

    """

    # Username of the authenticated user. Stays ``None`` until
    # ``process_request`` has accepted the credentials, so that reading the
    # attribute earlier (or after a rejection) doesn't raise AttributeError.
    username: Optional[str] = None

    def __init__(
        self,
        *args: Any,
        realm: str,
        check_credentials: Callable[[str, str], Awaitable[bool]],
        **kwargs: Any,
    ) -> None:
        self.realm = realm
        self.check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Check HTTP Basic Auth and return a HTTP 401 response if needed.

        The request is rejected with HTTP 401 when the ``Authorization``
        header is missing, when it cannot be parsed as Basic credentials,
        or when ``check_credentials`` returns a false value.

        If authentication succeeds, the username of the authenticated user
        is stored in the ``username`` attribute and the request is handed
        over to the parent class.

        :param path: request path, passed on to the parent class
        :param request_headers: request headers
        :return: a HTTP 401 response if authentication fails, else whatever
            the parent class's ``process_request`` returns

        """

        def unauthorized(body: bytes) -> HTTPResponse:
            # Every rejection carries the same status and challenge header;
            # only the explanatory body differs.
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                body,
            )

        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return unauthorized(b"Missing credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return unauthorized(b"Unsupported credentials\n")

        if not await self.check_credentials(username, password):
            return unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)
def basic_auth_protocol_factory(
    realm: str,
    credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
    check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
    create_protocol: Type[
        BasicAuthWebSocketServerProtocol
    ] = BasicAuthWebSocketServerProtocol,
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    ``basic_auth_protocol_factory`` is designed to integrate with
    :func:`~websockets.server.serve` like this::

        websockets.serve(
            ...,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    ``realm`` indicates the scope of protection. It should contain only ASCII
    characters because the encoding of non-ASCII characters is undefined.
    Refer to section 2.2 of :rfc:`7235` for details.

    ``credentials`` defines hard coded authorized credentials. It can be a
    ``(username, password)`` pair or a list of such pairs. Hard coded
    passwords are compared with :func:`hmac.compare_digest` rather than
    ``==``.

    ``check_credentials`` defines a coroutine that checks whether credentials
    are authorized. This coroutine receives ``username`` and ``password``
    arguments and returns a :class:`bool`.

    One of ``credentials`` or ``check_credentials`` must be provided but not
    both.

    By default, ``basic_auth_protocol_factory`` creates a factory for building
    :class:`BasicAuthWebSocketServerProtocol` instances. You can override this
    with the ``create_protocol`` parameter.

    :param realm: scope of protection
    :param credentials: hard coded credentials
    :param check_credentials: coroutine that verifies credentials
    :param create_protocol: class or callable used to build the protocol
    :return: a callable that builds protocol instances with ``realm`` and
        ``check_credentials`` already bound
    :raises TypeError: if the credentials argument has the wrong type, or if
        ``credentials`` and ``check_credentials`` are both provided or both
        omitted

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        # A bare string is iterable and, when two characters long, even
        # unpacks like a pair; it is never a valid credentials argument.
        if isinstance(credentials, str):
            raise TypeError(f"invalid credentials argument: {credentials}")

        if is_credentials(credentials):
            # Normalize a single pair, whatever its container type, so that a
            # list such as ["user", "pass"] works like the equivalent tuple.
            credentials_list = [tuple(credentials)]
        elif isinstance(credentials, Iterable):
            credentials_list = list(credentials)
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        credentials_dict = dict(credentials_list)

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            # ``==`` stops at the first differing character, which lets an
            # attacker recover a password by measuring response times.
            # compare_digest only accepts ASCII str, hence the encoding.
            return hmac.compare_digest(
                expected_password.encode("utf-8"), password.encode("utf-8")
            )

    return functools.partial(
        create_protocol, realm=realm, check_credentials=check_credentials
    )