Initial commit
This commit is contained in:
159
ProtuSAML/__init__.py
Normal file
159
ProtuSAML/__init__.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Dict, Set, Tuple
|
||||
|
||||
import attr
|
||||
import saml2
|
||||
import saml2.response
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import (
|
||||
UserID,
|
||||
map_username_to_mxid_localpart,
|
||||
mxid_localpart_allowed_characters,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DOT_REPLACE_PATTERN = re.compile(
|
||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
||||
)
|
||||
|
||||
|
||||
def dot_replace_for_mxid(username: str) -> str:
|
||||
"""Replace any characters which are not allowed in Matrix IDs with a dot."""
|
||||
username = username.lower()
|
||||
username = DOT_REPLACE_PATTERN.sub(".", username)
|
||||
|
||||
# regular mxids aren't allowed to start with an underscore either
|
||||
username = re.sub("^_", "", username)
|
||||
return username
|
||||
|
||||
|
||||
MXID_MAPPER_MAP = {
|
||||
"hexencode": map_username_to_mxid_localpart,
|
||||
"dotreplace": dot_replace_for_mxid,
|
||||
} # type: Dict[str, Callable[[str], str]]
|
||||
|
||||
|
||||
@attr.s
|
||||
class SamlConfig:
|
||||
mxid_source_attribute = attr.ib()
|
||||
mxid_mapper = attr.ib()
|
||||
|
||||
|
||||
class SAMLMappingProviderProtuEdition:
|
||||
__version__ = "0.0.1"
|
||||
|
||||
def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
|
||||
"""The default SAML user mapping provider
|
||||
Args:
|
||||
parsed_config: Module configuration
|
||||
module_api: module api proxy
|
||||
"""
|
||||
self._mxid_source_attribute = parsed_config.mxid_source_attribute
|
||||
self._mxid_mapper = parsed_config.mxid_mapper
|
||||
|
||||
self._grandfathered_mxid_source_attribute = (
|
||||
module_api._hs.config.saml2_grandfathered_mxid_source_attribute
|
||||
)
|
||||
|
||||
def get_remote_user_id(
|
||||
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
|
||||
) -> str:
|
||||
"""Extracts the remote user id from the SAML response"""
|
||||
try:
|
||||
return saml_response.ava["uid"][0]
|
||||
except KeyError:
|
||||
logger.warning("SAML2 response lacks a 'uid' attestation")
|
||||
raise MappingException("'uid' not in SAML2 response")
|
||||
|
||||
def saml_response_to_user_attributes(
|
||||
self,
|
||||
saml_response: saml2.response.AuthnResponse,
|
||||
failures: int,
|
||||
client_redirect_url: str,
|
||||
) -> dict:
|
||||
"""Maps some text from a SAML response to attributes of a new user
|
||||
Args:
|
||||
saml_response: A SAML auth response object
|
||||
failures: How many times a call to this function with this
|
||||
saml_response has resulted in a failure
|
||||
client_redirect_url: where the client wants to redirect to
|
||||
Returns:
|
||||
dict: A dict containing new user attributes. Possible keys:
|
||||
* mxid_localpart (str): Required. The localpart of the user's mxid
|
||||
* displayname (str): The displayname of the user
|
||||
* emails (list[str]): Any emails for the user
|
||||
"""
|
||||
try:
|
||||
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"SAML2 response lacks a '%s' attestation",
|
||||
self._mxid_source_attribute,
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
||||
)
|
||||
|
||||
# Use the configured mapper for this mxid_source
|
||||
localpart = self._mxid_mapper(mxid_source)
|
||||
|
||||
# Append suffix integer if last call to this function failed to produce
|
||||
# a usable mxid.
|
||||
localpart += str(failures) if failures else ""
|
||||
|
||||
# Retrieve the display name from the saml response
|
||||
# If displayname is None, the mxid_localpart will be used instead
|
||||
displayname = saml_response.ava.get("displayName", [None])[0]
|
||||
|
||||
# Retrieve any emails present in the saml response
|
||||
emails = saml_response.ava.get("email", [])
|
||||
|
||||
return {
|
||||
"mxid_localpart": None,
|
||||
"displayname": displayname,
|
||||
"emails": emails,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config: dict) -> SamlConfig:
|
||||
"""Parse the dict provided by the homeserver's config
|
||||
Args:
|
||||
config: A dictionary containing configuration options for this provider
|
||||
Returns:
|
||||
SamlConfig: A custom config object for this module
|
||||
"""
|
||||
# Parse config options and use defaults where necessary
|
||||
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
|
||||
mapping_type = config.get("mxid_mapping", "hexencode")
|
||||
|
||||
# Retrieve the associating mapping function
|
||||
try:
|
||||
mxid_mapper = MXID_MAPPER_MAP[mapping_type]
|
||||
except KeyError:
|
||||
raise ConfigError(
|
||||
"saml2_config.user_mapping_provider.config: '%s' is not a valid "
|
||||
"mxid_mapping value" % (mapping_type,)
|
||||
)
|
||||
|
||||
return SamlConfig(mxid_source_attribute, mxid_mapper)
|
||||
|
||||
@staticmethod
|
||||
def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
|
||||
"""Returns the required attributes of a SAML
|
||||
Args:
|
||||
config: A SamlConfig object containing configuration params for this provider
|
||||
Returns:
|
||||
The first set equates to the saml auth response
|
||||
attributes that are required for the module to function, whereas the
|
||||
second set consists of those attributes which can be used if
|
||||
available, but are not necessary
|
||||
"""
|
||||
return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
|
||||
Reference in New Issue
Block a user