diff --git a/flask_cas/routing.py b/flask_cas/routing.py index dc2ee5d6..cb2b532c 100644 --- a/flask_cas/routing.py +++ b/flask_cas/routing.py @@ -8,7 +8,7 @@ from urllib.error import URLError from urllib.request import urlopen import flask -from flask import current_app +from flask import current_app, request from xmltodict import parse from .cas_urls import create_cas_login_url @@ -86,23 +86,27 @@ def logout(): flask.session.pop(cas_attributes_session_key, None) flask.session.pop(cas_token_session_key, None) # added by EV flask.session.pop("CAS_EDT_ID", None) # added by EV - cas_after_logout = current_app.config.get("CAS_AFTER_LOGOUT") - if cas_after_logout: - # If config starts with http, use it as dest URL. - # Else, build Flask URL - dest_url = ( - cas_after_logout - if cas_after_logout.startswith("http") - else flask.url_for(cas_after_logout, _external=True) - ) - redirect_url = create_cas_logout_url( - current_app.config["CAS_SERVER"], - current_app.config["CAS_LOGOUT_ROUTE"], - dest_url, - ) + cas_logout_route = current_app.config.get("CAS_LOGOUT_ROUTE") + cas_server = current_app.config.get("CAS_SERVER") + if cas_server: + if cas_after_logout and cas_logout_route: + # If config starts with http, use it as dest URL. + # Else, build Flask URL + dest_url = ( + cas_after_logout + if cas_after_logout.startswith("http") + else flask.url_for(cas_after_logout, _external=True) + ) + redirect_url = create_cas_logout_url( + cas_server, + cas_logout_route, + dest_url, + ) + else: + redirect_url = create_cas_logout_url(cas_server, None) else: - redirect_url = create_cas_logout_url(current_app.config["CAS_SERVER"], None) + redirect_url = request.root_url current_app.logger.debug(f"cas.logout: redirecting to {redirect_url}") return flask.redirect(redirect_url) @@ -134,10 +138,10 @@ def validate(ticket): ticket, ) - current_app.logger.debug("Making GET request to {0}".format(cas_validate_url)) + current_app.logger.debug(f"Making GET request to {cas_validate_url}") xml_from_dict = {} - isValid = False + is_valid = False if current_app.config.get("CAS_SSL_VERIFY"): ssl_context = ssl.SSLContext() @@ -161,7 +165,7 @@ def validate(ticket): .decode("utf8", "ignore") ) xml_from_dict = parse(xmldump) - isValid = ( + is_valid = ( True if "cas:authenticationSuccess" in xml_from_dict["cas:serviceResponse"] else False @@ -176,7 +180,7 @@ def validate(ticket): "erreur connexion au serveur CAS: vérifiez le certificat SSL" ) - if isValid: + if is_valid: current_app.logger.debug("valid") xml_from_dict = xml_from_dict["cas:serviceResponse"][ "cas:authenticationSuccess" @@ -207,4 +211,4 @@ def validate(ticket): else: current_app.logger.debug("invalid") - return isValid + return is_valid