diff --git a/accounts/decorators.py b/accounts/decorators.py index 3f46c84..af66f08 100644 --- a/accounts/decorators.py +++ b/accounts/decorators.py @@ -12,6 +12,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from django.shortcuts import redirect +from django.urls import reverse + from .utils import ConfirmAccess @@ -22,3 +25,19 @@ def confirm_access(function): ) return wrap + + +def redirect_if_authenticated(fn): + """ + Redirect authenticated users visiting sign in/sign up views + """ + + def wrap(request, *args, **kwargs): + if request.user.is_authenticated: + data = request.GET if request.method == "GET" else request.POST + if "next" in data: + return redirect(data["next"]) + return redirect(reverse("accounts.home")) + return fn(request, *args, **kwargs) + + return wrap diff --git a/accounts/views.py b/accounts/views.py index 813a78f..db39aaf 100644 --- a/accounts/views.py +++ b/accounts/views.py @@ -26,8 +26,10 @@ from django.urls import reverse from .models import AccountConfirmChallenge from .utils import send_verification_email, ConfirmAccess +from .decorators import redirect_if_authenticated +@redirect_if_authenticated @csrf_protect def login_view(request): def default_login_ctx(): @@ -35,12 +37,6 @@ def login_view(request): "title": "Login", } - if request.user.is_authenticated: - data = request.GET if request.method == "GET" else request.POST - if "next" in data: - return redirect(data["next"]) - return redirect(reverse("accounts.home")) - if request.method == "GET": ctx = default_login_ctx() if "next" in request.GET: @@ -84,6 +80,7 @@ def protected_view(request): return redirect(reverse("dash.home")) +@redirect_if_authenticated def default_login_url(request): if "next" in request.GET: ctx = {"next": request.GET["next"]} @@ -97,6 +94,8 @@ def logout_view(request): return redirect(reverse("accounts.login")) +@redirect_if_authenticated +@csrf_protect def register_view(request): def default_register_ctx(username=None, email=None): return { @@ -174,6 +173,7 @@ def register_view(request): return redirect(challenge.pending_url()) +@redirect_if_authenticated def verification_pending_view(request, public_ref): challenge = get_object_or_404(AccountConfirmChallenge, public_ref=public_ref) @@ -185,6 +185,7 @@ def verification_pending_view(request, public_ref): return render(request, "accounts/auth/verification-pending.html", context=ctx) +@redirect_if_authenticated def resend_verification_email_view(request, public_ref): challenge = get_object_or_404(AccountConfirmChallenge, public_ref=public_ref) send_verification_email(request, challenge=challenge)