From 96e71d7c3290c732644e6f0dfa182acd8d14e57c Mon Sep 17 00:00:00 2001 From: realaravinth Date: Mon, 20 Jun 2022 17:00:25 +0530 Subject: [PATCH] fix: auto-redirect authenticated user when visiting login page --- accounts/tests.py | 26 ++++++++++++++++++++++++++ accounts/views.py | 6 ++++++ 2 files changed, 32 insertions(+) diff --git a/accounts/tests.py b/accounts/tests.py index 79564b2..4b49f1c 100644 --- a/accounts/tests.py +++ b/accounts/tests.py @@ -131,6 +131,32 @@ class LoginTest(TestCase): self.assertEqual(resp.status_code, 302) self.assertEqual(resp.headers["location"], reverse("accounts.login")) + def test_login_view_redirects_if_user_is_loggedin(self): + """ + Automatically redirect authenticated users that are visiting login view + """ + c = Client() + login_util(t=self, c=c, redirect_to="accounts.home") + + resp = c.get(reverse("accounts.login")) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("accounts.home")) + + resp = c.post(reverse("accounts.login")) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("accounts.home")) + + resp = c.get( + f"{reverse('accounts.login')}?next={reverse('dash.instances.list')}" + ) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("dash.instances.list")) + + ctx = {"next": reverse("dash.instances.list")} + resp = c.post(reverse("accounts.login"), ctx) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("dash.instances.list")) + class RegistrationTest(TestCase): def setUp(self): diff --git a/accounts/views.py b/accounts/views.py index 67d9d8f..813a78f 100644 --- a/accounts/views.py +++ b/accounts/views.py @@ -35,6 +35,12 @@ def login_view(request): "title": "Login", } + if request.user.is_authenticated: + data = request.GET if request.method == "GET" else request.POST + if "next" in data: + return redirect(data["next"]) + return redirect(reverse("accounts.home")) + if request.method == "GET": ctx = default_login_ctx() if "next" in request.GET: