Source code for crossbar.bridge.rest.common

#####################################################################################
#
#  Copyright (c) typedef int GmbH
#  SPDX-License-Identifier: EUPL-1.2
#
#####################################################################################

import base64
import binascii
import datetime
import hashlib
import hmac
import json
import os
from ipaddress import ip_address, ip_network
from typing import Any, Dict

from autobahn.twisted.wamp import ApplicationSession
from autobahn.wamp.exception import ApplicationError
from autobahn.websocket.utf8validator import Utf8Validator
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import hmac as hazmat_hmac
from twisted.web import server
from twisted.web.resource import Resource
from txaio import make_logger

from crossbar._compat import native_string
from crossbar._log_categories import log_categories
from crossbar._util import dump_json

# Shared UTF-8 validator for request bodies; it is reset() right before each
# validate() call by the code that uses it.
_validator = Utf8Validator()
# Content types (lower-cased bytes, without parameters) accepted for JSON bodies.
_ALLOWED_CONTENT_TYPES = {b"application/json"}
class _InvalidUnicode(BaseException):
    """
    Signals that invalid Unicode was found.

    Note that this derives from ``BaseException``, so a plain
    ``except Exception`` handler does not catch it.
    """
# 32 random bytes drawn once when the module is imported; used for
# constant-time compares
_nonce = os.urandom(32)
def _hmac_sha256(key, data):
    """
    Compute an HMAC-SHA256 message authentication code.

    :param key: The key to authenticate with (bytes).
    :param data: The message to authenticate (bytes).

    :returns: the HMAC-SHA256 of 'data' using 'key'
    """
    mac = hazmat_hmac.HMAC(key, hashes.SHA256(), default_backend())
    mac.update(data)
    digest = mac.finalize()
    return digest
def _constant_compare(a, b):
    """
    Compare 'a' and 'b' using a constant-time algorithm.

    The comparison is delegated to :func:`hmac.compare_digest`, so the time
    taken does not depend on *where* the two values differ. The values
    should be the same length (a length difference may still be observable).

    Text is encoded as UTF-8 before comparing. Unlike a strict ASCII
    encode, this never raises for non-ASCII input -- which matters because
    one side typically comes straight from a client-supplied header; such
    input simply compares as "not equal".

    :param a: First value (``str`` or bytes-like).
    :param b: Second value (``str`` or bytes-like).

    :returns: ``True`` if both values are equal, ``False`` otherwise.
    """

    def _to_bytes(value):
        # str -> UTF-8 bytes; anything else is assumed to be bytes-like
        if isinstance(value, str):
            return value.encode("utf8")
        return bytes(value)

    return hmac.compare_digest(_to_bytes(a), _to_bytes(b))
def _confirm_github_signature(request, secret_token, raw_body):
    """
    Confirm that the signature header from GitHub is present and valid.

    :param request: The HTTP request; its ``X-Hub-Signature`` header is read.
    :param secret_token: The shared secret (``str`` is encoded as ASCII).
    :param raw_body: The raw request body (must be bytes).

    :returns: ``False`` if the header is missing, otherwise the result of
        comparing the header value with the signature computed locally.
    """
    # the underlying crypto primitives only work on bytes objects
    token = secret_token if isinstance(secret_token, bytes) else secret_token.encode("ascii")
    assert isinstance(raw_body, bytes)

    # without the header there is nothing to verify
    header_values = request.requestHeaders.getRawHeaders("X-Hub-Signature")
    if not header_values:
        return False
    claimed = str(header_values[0]).lower()

    # NOTE: never use SHA1 for new code ... but GitHub signatures are
    # SHA1, so there is no choice here
    mac = hazmat_hmac.HMAC(token, hashes.SHA1(), default_backend())  # nosec
    mac.update(raw_body)
    hex_digest = binascii.b2a_hex(mac.finalize()).decode("ascii")
    expected = "sha1={}".format(hex_digest)

    return _constant_compare(expected, claimed)
class _CommonResource(Resource):
    """
    Shared components between PublisherResource and CallerResource.

    This base class performs the checks that are common to all requests:
    response headers, allowed HTTP methods, content type / charset, body
    length, optional GitHub signature, optional signed-request parameters,
    client IP and TLS restrictions, and UTF-8 / JSON decoding of the body.
    The decoded body is then handed to :meth:`_process`, which subclasses
    must implement.
    """

    # Twisted ``Resource.isLeaf`` flag
    isLeaf = True

    # When true, the content type is checked against _ALLOWED_CONTENT_TYPES and
    # the body must decode to a JSON object; otherwise the body is passed on as
    # decoded text.
    decode_as_json = True

    def __init__(self, options: Dict[str, Any], session: ApplicationSession):
        """
        :param options: Options for path service from configuration.
        :param session: WAMP session to be used for forwarding events / calls.
        """
        Resource.__init__(self)
        self._options = options
        self._session = session
        self.log = make_logger()

        self._debug = bool(options.get("debug"))

        # key / secret for signed requests (kept as bytes, since the request
        # parameters they are compared against / mixed with are bytes)
        self._key = None
        if "key" in options:
            self._key = options["key"].encode("utf8")

        self._secret = None
        if "secret" in options:
            self._secret = options["secret"].encode("utf8")

        # 0 disables the respective limit
        self._post_body_limit = int(options.get("post_body_limit", 0))
        self._timestamp_delta_limit = int(options.get("timestamp_delta_limit", 300))

        # optional list of networks a client must be in
        self._require_ip = None
        if "require_ip" in options:
            self._require_ip = [ip_network(net) for net in options["require_ip"]]

        self._require_tls = options.get("require_tls", None)

    def _deny_request(self, request, code, **kwargs):
        """
        Called when client request is denied.

        :param request: The HTTP request being answered.
        :param code: HTTP status code to set on the response.
        :param kwargs: Values for the log message; ``log_category`` selects the
            message template and defaults to ``"AR" + str(code)``.

        :returns: The JSON-encoded error body (bytes) to send to the client.
        """
        if "log_category" not in kwargs:
            kwargs["log_category"] = "AR" + str(code)

        self.log.debug(code=code, **kwargs)

        # the client gets the same message that the log category describes
        error_str = log_categories[kwargs["log_category"]].format(**kwargs)
        body = dump_json({"error": error_str, "args": [], "kwargs": {}}, True).encode("utf8")
        request.setResponseCode(code)
        return body

    def _fail_request(self, request, **kwargs):
        """
        Called when client request fails.

        Writes a JSON error body to the request and finishes it.

        :param request: The HTTP request being answered.
        :param kwargs: Must contain ``failure`` (an object whose ``value`` is
            the raised error); all of ``kwargs`` is also passed to the logger.
        """
        failure = kwargs["failure"]
        error = failure.value

        if isinstance(error, ApplicationError):
            res = {
                "error": error.error,
                "args": error.args if error.args else [],
                "kwargs": error.kwargs if error.kwargs else {},
            }
            # This is a user-level error, not a CB error, so return 200
            code = 200
        else:
            # This is a "CB" error, so return 500 and a generic error
            res = {
                "error": "wamp.error.runtime_error",
                "args": ["Sorry, Crossbar.io has encountered a problem."],
                "kwargs": {},
            }
            code = 500
            self.log.failure(None, failure=failure, log_category="AR500")

        body = json.dumps(res).encode("utf8")

        if "log_category" not in kwargs:
            kwargs["log_category"] = "AR" + str(code)

        self.log.debug(code=code, **kwargs)

        request.setResponseCode(code)
        request.write(body)
        request.finish()

    def _complete_request(self, request, code, body, **kwargs):
        """
        Called when client request is complete.

        Sets the status code and writes ``body``; this method does not call
        ``request.finish()`` itself.

        :param request: The HTTP request being answered.
        :param code: HTTP status code to set on the response.
        :param body: Response body (bytes) to write.
        :param kwargs: Values for the log message; ``log_category`` defaults
            to ``"AR" + str(code)``.
        """
        if "log_category" not in kwargs:
            kwargs["log_category"] = "AR" + str(code)

        self.log.debug(code=code, **kwargs)
        request.setResponseCode(code)
        request.write(body)

    def _set_common_headers(self, request):
        """
        Set common HTTP response headers.

        :param request: The HTTP request whose response headers are set.
        """
        # echo the caller's origin, falling back to the wildcard
        origin = request.getHeader(b"origin")
        if origin is None or origin == b"null":
            origin = b"*"

        request.setHeader(b"access-control-allow-origin", origin)
        request.setHeader(b"access-control-allow-credentials", b"true")
        request.setHeader(b"cache-control", b"no-store,no-cache,must-revalidate,max-age=0")
        request.setHeader(b"content-type", b"application/json; charset=UTF-8")

        # allow whatever headers the client asked for in its preflight
        requested_headers = request.getHeader(b"access-control-request-headers")
        if requested_headers is not None:
            request.setHeader(b"access-control-allow-headers", requested_headers)

    def render(self, request):
        """
        Handle the request. All requests start here.

        :param request: The incoming HTTP request.

        :returns: The response body (bytes), or whatever
            :meth:`_render_request` returns for POST/PUT requests.
        """
        self.log.debug(log_category="AR100", method=request.method, path=request.path)

        self._set_common_headers(request)

        try:
            if request.method not in (b"POST", b"PUT", b"OPTIONS"):
                return self._deny_request(request, 405, method=request.method, allowed="POST, PUT")

            if request.method == b"OPTIONS":
                # http://greenbytes.de/tech/webdav/rfc2616.html#rfc.section.14.7
                request.setHeader(b"allow", b"POST,PUT,OPTIONS")

                # https://www.w3.org/TR/cors/#access-control-allow-methods-response-header
                request.setHeader(b"access-control-allow-methods", b"POST,PUT,OPTIONS")

                request.setResponseCode(200)
                return b""

            return self._render_request(request)
        except Exception as e:
            # top-level boundary: log the problem and answer with a generic 500
            self.log.failure(log_category="CB501", exc=e)
            return self._deny_request(request, 500, log_category="CB500")

    def _render_request(self, request):
        """
        Receives an HTTP/POST|PUT request, and then calls the Publisher/Caller processor.

        :param request: The incoming HTTP request.

        :returns: An error/response body (bytes), or ``server.NOT_DONE_YET``
            when :meth:`_process` returned a Deferred.
        """
        # read HTTP/POST|PUT body
        body = request.content.read()

        # only the first value of each query parameter is considered
        args = {native_string(x): y[0] for x, y in request.args.items()}
        headers = request.requestHeaders

        # check content type + charset encoding
        #
        content_type_header = headers.getRawHeaders(b"content-type", [])

        if len(content_type_header) > 0:
            content_type_elements = [x.strip().lower() for x in content_type_header[0].split(b";")]
        else:
            content_type_elements = []

        if self.decode_as_json:
            # if the client sent a content type, it MUST be one of _ALLOWED_CONTENT_TYPES
            # (but we allow missing content type .. will catch later during JSON
            # parsing anyway)
            if len(content_type_elements) > 0 and content_type_elements[0] not in _ALLOWED_CONTENT_TYPES:
                return self._deny_request(
                    request,
                    400,
                    accepted=list(_ALLOWED_CONTENT_TYPES),
                    given=content_type_elements[0],
                    log_category="AR452",
                )

        encoding_parts = {}

        if len(content_type_elements) > 1:
            try:
                for item in content_type_elements:
                    if b"=" not in item:
                        # Don't bother looking at things "like application/json"
                        continue

                    # Parsing things like:
                    # charset=utf-8
                    pieces = native_string(item).split("=")
                    if len(pieces) != 2:
                        raise ValueError("malformed content-type parameter")

                    # We don't want duplicates
                    name = pieces[0].strip().lower()
                    if name in encoding_parts:
                        raise ValueError("duplicate content-type parameter")

                    encoding_parts[name] = pieces[1].strip().lower()
            except Exception:
                # anything that cannot be parsed counts as a bad encoding
                return self._deny_request(request, 400, log_category="AR450")

        charset_encoding = encoding_parts.get("charset", "utf-8")

        if charset_encoding not in ["utf-8", "utf8"]:
            return self._deny_request(request, 400, log_category="AR450")

        # enforce "post_body_limit"
        #
        body_length = len(body)
        content_length_header = headers.getRawHeaders(b"content-length", [])

        if len(content_length_header) == 1:
            content_length = int(content_length_header[0])
        elif len(content_length_header) > 1:
            return self._deny_request(request, 400, log_category="AR463")
        else:
            content_length = body_length

        if body_length != content_length:
            # Prevent the body length from being different to the given
            # Content-Length. This is so that clients can't lie and bypass
            # length restrictions by giving an incorrect header with a large
            # body.
            return self._deny_request(request, 400, bodylen=body_length, conlen=content_length, log_category="AR465")

        if self._post_body_limit and content_length > self._post_body_limit:
            return self._deny_request(request, 413, length=content_length, accepted=self._post_body_limit)

        #
        # if we were given a GitHub token, check for a valid signature
        # header
        #
        github_secret = self._options.get("github_secret", "")
        if github_secret:
            if not _confirm_github_signature(request, github_secret, body):
                # NOTE(review): assumes "AR467" is defined in log_categories
                # (not visible from this module) -- confirm.
                return self._deny_request(
                    request,
                    400,
                    log_category="AR467",
                )

        #
        # parse/check HTTP/POST|PUT query parameters
        #
        # All signature-related fields are optional unless a secret is
        # configured, in which case each of them is mandatory.

        # key
        #
        if "key" in args:
            key_str = args["key"]
        elif self._secret:
            return self._deny_request(request, 400, reason="'key' field missing", log_category="AR461")

        # timestamp
        #
        if "timestamp" in args:
            timestamp_str = args["timestamp"]
            try:
                ts = datetime.datetime.strptime(native_string(timestamp_str), "%Y-%m-%dT%H:%M:%S.%fZ")
            except ValueError:
                return self._deny_request(
                    request,
                    400,
                    reason="invalid timestamp '{0}' (must be UTC/ISO-8601, e.g. '2011-10-14T16:59:51.123Z')".format(
                        native_string(timestamp_str)
                    ),
                    log_category="AR462",
                )

            # "ts" is naive, so compare against naive UTC "now"
            # (datetime.utcnow() is deprecated)
            now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            delta = abs((ts - now).total_seconds())
            if self._timestamp_delta_limit and delta > self._timestamp_delta_limit:
                return self._deny_request(request, 400, log_category="AR464")
        elif self._secret:
            return self._deny_request(
                request,
                400,
                reason="signed request required, but mandatory 'timestamp' field missing",
                log_category="AR461",
            )

        # seq
        #
        if "seq" in args:
            seq_str = args["seq"]
            try:
                # FIXME: check sequence
                seq = int(seq_str)  # noqa
            except (TypeError, ValueError):
                return self._deny_request(
                    request,
                    400,
                    reason="invalid sequence number '{0}' (must be an integer)".format(native_string(seq_str)),
                    log_category="AR462",
                )
        elif self._secret:
            return self._deny_request(request, 400, reason="'seq' field missing", log_category="AR461")

        # nonce
        #
        if "nonce" in args:
            nonce_str = args["nonce"]
            try:
                # FIXME: check nonce
                nonce = int(nonce_str)  # noqa
            except (TypeError, ValueError):
                return self._deny_request(
                    request,
                    400,
                    reason="invalid nonce '{0}' (must be an integer)".format(native_string(nonce_str)),
                    log_category="AR462",
                )
        elif self._secret:
            return self._deny_request(request, 400, reason="'nonce' field missing", log_category="AR461")

        # signature
        #
        if "signature" in args:
            signature_str = args["signature"]
        elif self._secret:
            return self._deny_request(request, 400, reason="'signature' field missing", log_category="AR461")

        # do more checks if signed requests are required
        #
        if self._secret:
            if key_str != self._key:
                return self._deny_request(
                    request,
                    401,
                    reason="unknown key '{0}' in signed request".format(native_string(key_str)),
                    log_category="AR460",
                )

            # Compute signature: HMAC[SHA256]_{secret} (key | timestamp | seq | nonce | body) => signature
            hm = hmac.new(self._secret, None, hashlib.sha256)
            for part in (key_str, timestamp_str, seq_str, nonce_str, body):
                hm.update(part)
            signature_recomputed = base64.urlsafe_b64encode(hm.digest())

            # constant-time comparison, so response timing does not reveal
            # how much of a forged signature was correct
            if not hmac.compare_digest(signature_str, signature_recomputed):
                return self._deny_request(request, 401, log_category="AR459")

            self.log.debug("REST request signature valid.", log_category="AR203")

        client_ip = request.getClientIP()
        is_secure = request.isSecure()

        # enforce client IP address
        #
        if self._require_ip:
            try:
                ip = ip_address(client_ip)
            except ValueError:
                # no (or an unparseable) client address cannot be in any of
                # the required networks
                return self._deny_request(request, 400, log_category="AR466")

            if not any(ip in net for net in self._require_ip):
                return self._deny_request(request, 400, log_category="AR466")

        # enforce TLS
        #
        if self._require_tls and not is_secure:
            return self._deny_request(request, 400, reason="request denied because not using TLS")

        # FIXME: authorize request
        authorized = True

        if not authorized:
            return self._deny_request(request, 401, reason="not authorized")

        _validator.reset()
        validation_result = _validator.validate(body)

        # validate() returns a 4-tuple, of which item 0 is whether it
        # is valid
        if not validation_result[0]:
            return self._deny_request(request, 400, log_category="AR451")

        event = body.decode("utf8")

        if self.decode_as_json:
            try:
                event = json.loads(event)
            except Exception as e:
                return self._deny_request(request, 400, exc=e, log_category="AR453")

            if not isinstance(event, dict):
                return self._deny_request(request, 400, log_category="AR454")

        d = self._process(request, event)

        if isinstance(d, bytes):
            # If it's bytes, return it directly
            return d

        # If it's a Deferred, let it run and finish the request afterwards.
        d.addCallback(lambda _: request.finish())
        return server.NOT_DONE_YET

    def _process(self, request, event):
        """
        Act on a checked and decoded request; must be implemented by subclasses.

        :param request: The HTTP request being handled.
        :param event: The decoded body (a ``dict`` when ``decode_as_json`` is
            true, otherwise the body as text).

        :returns: Either a response body (bytes) or a Deferred, as expected by
            :meth:`_render_request`.
        """
        raise NotImplementedError()