diff --git a/app/auth/forms.py b/app/auth/forms.py index 3d70054d8..bfb8676b5 100644 --- a/app/auth/forms.py +++ b/app/auth/forms.py @@ -4,9 +4,10 @@ TODO: à revoir complètement pour reprendre ZScoUsers et les pages d'authentification """ - +from urllib.parse import urlparse, urljoin +from flask import request, url_for, redirect from flask_wtf import FlaskForm -from wtforms import StringField, PasswordField, BooleanField, SubmitField +from wtforms import BooleanField, HiddenField, PasswordField, StringField, SubmitField from wtforms.validators import ValidationError, DataRequired, Email, EqualTo from app.auth.models import User, is_valid_password @@ -14,13 +15,45 @@ from app.auth.models import User, is_valid_password _ = lambda x: x # sans babel _l = _ +# See http://flask.pocoo.org/snippets/63/ +def is_safe_url(target): + ref_url = urlparse(request.host_url) + test_url = urlparse(urljoin(request.host_url, target)) + return test_url.scheme in ("http", "https") and ref_url.netloc == test_url.netloc -class LoginForm(FlaskForm): + +def get_redirect_target(): + for target in request.args.get("next"), request.referrer: + if not target: + continue + if is_safe_url(target): + return target + + +class RedirectForm(FlaskForm): + next = HiddenField() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.next.data: + self.next.data = get_redirect_target() or "" + + def redirect(self, endpoint="index", **values): + if is_safe_url(self.next.data): + return redirect(self.next.data) + target = get_redirect_target() + return redirect(target or url_for(endpoint, **values)) + + +class LoginForm(RedirectForm): user_name = StringField(_l("Nom d'utilisateur"), validators=[DataRequired()]) password = PasswordField(_l("Mot de passe"), validators=[DataRequired()]) remember_me = BooleanField(_l("mémoriser la connexion")) submit = SubmitField(_l("Suivant")) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + class UserCreationForm(FlaskForm): user_name = StringField(_l("Nom d'utilisateur"), validators=[DataRequired()]) diff --git a/app/auth/routes.py b/app/auth/routes.py index 1f0259ab5..16df31354 100644 --- a/app/auth/routes.py +++ b/app/auth/routes.py @@ -42,10 +42,10 @@ def login(): return redirect(url_for("auth.login")) login_user(user, remember=form.remember_me.data) current_app.logger.info("login: success (%s)", form.user_name.data) - next_page = request.args.get("next") - if not next_page or url_parse(next_page).netloc != "": - next_page = url_for("scodoc.index") - return redirect(next_page) + # next_page = request.args.get("next") + # if not next_page or url_parse(next_page).netloc != "": + # next_page = url_for("scodoc.index") + return form.redirect("scodoc.index") message = request.args.get("message", "") return render_template( "auth/login.html", title=_("Sign In"), form=form, message=message