diff --git a/.gitignore b/.gitignore index 34c73b2..17c6cc4 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ keys +htmlcov/ diff --git a/Makefile b/Makefile index 7fab726..55f1726 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,7 @@ default: ## Run app coverage: ## Generate test coverage report . ./venv/bin/activate && coverage run manage.py test . ./venv/bin/activate && coverage report -m + . ./venv/bin/activate && coverage html doc: ## Generates documentation $(call unimplemented) @@ -41,4 +42,4 @@ migrate: ## Run migrations $(call run_migrations) test: ## Run tests - @. ./venv/bin/activate && python manage.py test --parallel --with-coverage + @. ./venv/bin/activate && python manage.py test --parallel diff --git a/accounts/tests.py b/accounts/tests.py index 3a35a75..e2665b5 100644 --- a/accounts/tests.py +++ b/accounts/tests.py @@ -28,21 +28,36 @@ from .management.commands.rm_unverified_users import ( ) +def register_util(t: TestCase, username: str): + t.password = "password121231" + t.username = username + t.email = f"{t.username}@example.org" + t.user = get_user_model().objects.create( + username=t.username, + email=t.email, + ) + t.user.set_password(t.password) + t.user.save() + + +def login_util(t: TestCase, c: Client, redirect_to: str): + payload = { + "login": t.username, + "password": t.password, + } + resp = c.post(reverse("accounts.login"), payload) + t.assertEqual(resp.status_code, 302) + t.assertEqual(resp.headers["location"], reverse(redirect_to)) + + class LoginTest(TestCase): """ Tests create new app view """ def setUp(self): - self.password = "password121231" self.username = "create_new_app_tests" - self.email = f"{self.username}@example.org" - self.user = get_user_model().objects.create( - username=self.username, - email=self.email, - ) - self.user.set_password(self.password) - self.user.save() + register_util(t=self, username=self.username) def test_login_template_works(self): """ @@ -58,16 +73,10 @@ class LoginTest(TestCase): c = Client() # username login works - payload = { - "login": self.username, - "password": self.password, - } - resp = c.post(reverse("accounts.login"), payload) - self.assertEqual(resp.status_code, 302) - self.assertEqual(resp.headers["location"], reverse("accounts.home")) + login_util(t=self, c=c, redirect_to="accounts.home") # email login works - paylaod = { + payload = { "login": self.email, "password": self.password, } @@ -81,10 +90,24 @@ class LoginTest(TestCase): "password": self.user.email, } - resp = self.client.post(reverse("accounts.login"), paylaod) + resp = self.client.post(reverse("accounts.login"), payload) self.assertEqual(resp.status_code, 401) self.assertEqual(b"Login Failed" in resp.content, True) + # protected view works + resp = c.get(reverse("accounts.home")) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("dash.home")) + + def test_default_login_uri_works(self): + """ + /accounts/login should redirect_to /login + """ + c = Client() + resp = c.get(reverse("accounts.default_login_url")) + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["location"], reverse("accounts.login")) + class RegistrationTest(TestCase): def test_register_template_works(self): @@ -121,11 +144,24 @@ class RegistrationTest(TestCase): pending_url = challenge.pending_url() self.assertEqual(resp.headers["location"], pending_url) + resend_url = reverse("accounts.verify.resend", args=(challenge.public_ref,)) - resp = c.post(reverse("accounts.verify.resend", args=(challenge.public_ref,))) + # visit pending URL + resp = c.get(pending_url) + self.assertEqual(resp.status_code, 200) + self.assertEqual(str.encode(msg["email"]) in resp.content, True) + self.assertEqual(str.encode(resend_url) in resp.content, True) + + resp = c.post(resend_url) self.assertEqual(resp.status_code, 302) self.assertEqual(resp.headers["location"], pending_url) + resp = c.get(challenge.verification_link()) + self.assertEqual(resp.status_code, 200) + self.assertEqual( + str.encode(challenge.verification_link()) in resp.content, True + ) + resp = c.post(challenge.verification_link()) self.assertEqual(resp.status_code, 302) self.assertEqual(resp.headers["location"], reverse("accounts.login")) diff --git a/accounts/views.py b/accounts/views.py index 03489db..0851005 100644 --- a/accounts/views.py +++ b/accounts/views.py @@ -43,10 +43,12 @@ def login_view(request): user = None if "@" in login_cred: - user = authenticate( - email=login_cred, - password=request.POST["password"], - ) + user = get_user_model().objects.get(email=login_cred) + if user is not None: + user = authenticate( + username=user.username, + password=request.POST["password"], + ) else: user = authenticate( username=login_cred, @@ -75,8 +77,10 @@ def protected_view(request): def default_login_url(request): - ctx = {"next": request.GET["next"]} - return redirect(f"{reverse('accounts.login')}?{urlencode(ctx)}") + if "next" in request.GET: + ctx = {"next": request.GET["next"]} + return redirect(f"{reverse('accounts.login')}?{urlencode(ctx)}") + return redirect(reverse("accounts.login")) @login_required