diff --git a/hc/accounts/tests/test_profile.py b/hc/accounts/tests/test_profile.py index f4d257c9..c0a42668 100644 --- a/hc/accounts/tests/test_profile.py +++ b/hc/accounts/tests/test_profile.py @@ -193,7 +193,7 @@ class ProfileTestCase(BaseTestCase): # to user's default team. self.bobs_profile.refresh_from_db() self.assertEqual(self.bobs_profile.current_team, self.bobs_profile) - self.assertEqual(self.bobs_profile.current_project, None) + self.assertEqual(self.bobs_profile.current_project, self.bobs_project) def test_it_sends_change_email_link(self): self.client.login(username="alice@example.org", password="password") diff --git a/hc/api/decorators.py b/hc/api/decorators.py index a6d0ffae..67e2a358 100644 --- a/hc/api/decorators.py +++ b/hc/api/decorators.py @@ -4,6 +4,7 @@ from functools import wraps from django.contrib.auth.models import User from django.db.models import Q from django.http import HttpResponse, JsonResponse +from hc.accounts.models import Project from hc.lib.jsonschema import ValidationError, validate @@ -23,9 +24,8 @@ def authorize(f): return error("missing api key", 401) try: - request.user = User.objects.get(profile__api_key=api_key) - request.project = request.user.project_set.first() - except User.DoesNotExist: + request.project = Project.objects.get(api_key=api_key) + except Project.DoesNotExist: return error("wrong api key", 401) return f(request, *args, **kwds) @@ -43,12 +43,11 @@ def authorize_read(f): if len(api_key) != 32: return error("missing api key", 401) - write_key_match = Q(profile__api_key=api_key) - read_key_match = Q(profile__api_key_readonly=api_key) + write_key_match = Q(api_key=api_key) + read_key_match = Q(api_key_readonly=api_key) try: - request.user = User.objects.get(write_key_match | read_key_match) - request.project = request.user.project_set.first() - except User.DoesNotExist: + request.project = Project.objects.get(write_key_match | read_key_match) + except Project.DoesNotExist: return error("wrong api key", 401) return f(request, *args, **kwds) diff --git a/hc/api/tests/test_badge.py b/hc/api/tests/test_badge.py index 6a7039ce..25a75910 100644 --- a/hc/api/tests/test_badge.py +++ b/hc/api/tests/test_badge.py @@ -12,7 +12,8 @@ class BadgeTestCase(BaseTestCase): def setUp(self): super(BadgeTestCase, self).setUp() - self.check = Check.objects.create(user=self.alice, tags="foo bar") + self.check = Check.objects.create(user=self.alice, project=self.project, + tags="foo bar") sig = base64_hmac(str(self.alice.username), "foo", settings.SECRET_KEY) sig = sig[:8] diff --git a/hc/api/tests/test_create_check.py b/hc/api/tests/test_create_check.py index 6bc36c8c..e3725e82 100644 --- a/hc/api/tests/test_create_check.py +++ b/hc/api/tests/test_create_check.py @@ -87,7 +87,7 @@ class CreateCheckTestCase(BaseTestCase): self.assertEqual(check.channel_set.get(), channel) def test_it_supports_unique(self): - existing = Check(user=self.alice, name="Foo") + existing = Check(user=self.alice, name="Foo", project=self.project) existing.save() r = self.post({ diff --git a/hc/api/tests/test_delete_check.py b/hc/api/tests/test_delete_check.py index d6680ac4..0d094d11 100644 --- a/hc/api/tests/test_delete_check.py +++ b/hc/api/tests/test_delete_check.py @@ -6,7 +6,7 @@ class DeleteCheckTestCase(BaseTestCase): def setUp(self): super(DeleteCheckTestCase, self).setUp() - self.check = Check(user=self.alice) + self.check = Check(user=self.alice, project=self.project) self.check.save() def test_it_works(self): diff --git a/hc/api/tests/test_list_channels.py b/hc/api/tests/test_list_channels.py index 7d755097..92dd9ff7 100644 --- a/hc/api/tests/test_list_channels.py +++ b/hc/api/tests/test_list_channels.py @@ -9,7 +9,7 @@ class ListChannelsTestCase(BaseTestCase): def setUp(self): super(ListChannelsTestCase, self).setUp() - self.c1 = Channel(user=self.alice) + self.c1 = Channel(user=self.alice, project=self.project) self.c1.kind = "email" self.c1.name = "Email to Alice" self.c1.save() @@ -36,7 +36,8 @@ class ListChannelsTestCase(BaseTestCase): self.assertIn("GET", r["Access-Control-Allow-Methods"]) def test_it_shows_only_users_channels(self): - Channel.objects.create(user=self.bob, kind="email", name="Bob") + Channel.objects.create(user=self.bob, kind="email", name="Bob", + project=self.bobs_project) r = self.get() data = r.json() @@ -53,8 +54,8 @@ class ListChannelsTestCase(BaseTestCase): self.assertContains(r, "Email to Alice") def test_readonly_key_works(self): - self.profile.api_key_readonly = "R" * 32 - self.profile.save() + self.project.api_key_readonly = "R" * 32 + self.project.save() r = self.client.get("/api/v1/channels/", HTTP_X_API_KEY="R" * 32) self.assertEqual(r.status_code, 200) diff --git a/hc/api/tests/test_list_checks.py b/hc/api/tests/test_list_checks.py index d2543d97..fcadf291 100644 --- a/hc/api/tests/test_list_checks.py +++ b/hc/api/tests/test_list_checks.py @@ -14,7 +14,7 @@ class ListChecksTestCase(BaseTestCase): self.now = now().replace(microsecond=0) - self.a1 = Check(user=self.alice, name="Alice 1") + self.a1 = Check(user=self.alice, name="Alice 1", project=self.project) self.a1.timeout = td(seconds=3600) self.a1.grace = td(seconds=900) self.a1.n_pings = 0 @@ -22,7 +22,7 @@ class ListChecksTestCase(BaseTestCase): self.a1.tags = "a1-tag a1-additional-tag" self.a1.save() - self.a2 = Check(user=self.alice, name="Alice 2") + self.a2 = Check(user=self.alice, name="Alice 2", project=self.project) self.a2.timeout = td(seconds=86400) self.a2.grace = td(seconds=3600) self.a2.last_ping = self.now @@ -79,7 +79,8 @@ class ListChecksTestCase(BaseTestCase): self.assertIn("GET", r["Access-Control-Allow-Methods"]) def test_it_shows_only_users_checks(self): - bobs_check = Check(user=self.bob, name="Bob 1") + bobs_check = Check(user=self.bob, name="Bob 1", + project=self.bobs_project) bobs_check.save() r = self.get() @@ -139,8 +140,8 @@ class ListChecksTestCase(BaseTestCase): self.assertEqual(len(doc["checks"]), 0) def test_readonly_key_works(self): - self.profile.api_key_readonly = "R" * 32 - self.profile.save() + self.project.api_key_readonly = "R" * 32 + self.project.save() r = self.client.get("/api/v1/checks/", HTTP_X_API_KEY="R" * 32) self.assertEqual(r.status_code, 200) diff --git a/hc/api/tests/test_pause.py b/hc/api/tests/test_pause.py index eb9e8d0d..6f97ccff 100644 --- a/hc/api/tests/test_pause.py +++ b/hc/api/tests/test_pause.py @@ -8,7 +8,7 @@ from hc.test import BaseTestCase class PauseTestCase(BaseTestCase): def test_it_works(self): - check = Check(user=self.alice, status="up") + check = Check(user=self.alice, status="up", project=self.project) check.save() url = "/api/v1/checks/%s/pause" % check.code @@ -22,7 +22,7 @@ class PauseTestCase(BaseTestCase): self.assertEqual(check.status, "paused") def test_it_handles_options(self): - check = Check(user=self.alice, status="up") + check = Check(user=self.alice, status="up", project=self.project) check.save() r = self.client.options("/api/v1/checks/%s/pause" % check.code) @@ -60,7 +60,7 @@ class PauseTestCase(BaseTestCase): self.assertEqual(r.status_code, 404) def test_it_clears_last_start_alert_after(self): - check = Check(user=self.alice, status="up") + check = Check(user=self.alice, status="up", project=self.project) check.last_start = now() check.alert_after = check.last_start + td(hours=1) check.save() diff --git a/hc/api/tests/test_update_check.py b/hc/api/tests/test_update_check.py index 24b41c2d..a3c28861 100644 --- a/hc/api/tests/test_update_check.py +++ b/hc/api/tests/test_update_check.py @@ -8,7 +8,7 @@ class UpdateCheckTestCase(BaseTestCase): def setUp(self): super(UpdateCheckTestCase, self).setUp() - self.check = Check(user=self.alice) + self.check = Check(user=self.alice, project=self.project) self.check.save() def post(self, code, data): diff --git a/hc/api/views.py b/hc/api/views.py index e9a384af..4be9f157 100644 --- a/hc/api/views.py +++ b/hc/api/views.py @@ -37,10 +37,10 @@ def ping(request, code, action="success"): return response -def _lookup(user, spec): +def _lookup(project, spec): unique_fields = spec.get("unique", []) if unique_fields: - existing_checks = Check.objects.filter(user=user) + existing_checks = Check.objects.filter(project=project) if "name" in unique_fields: existing_checks = existing_checks.filter(name=spec.get("name")) if "tags" in unique_fields: @@ -105,7 +105,7 @@ def _update(check, spec): @validate_json() @authorize_read def get_checks(request): - q = Check.objects.filter(user=request.user) + q = Check.objects.filter(project=request.project) q = q.prefetch_related("channel_set") tags = set(request.GET.getlist("tag")) @@ -126,13 +126,14 @@ def get_checks(request): @authorize def create_check(request): created = False - check = _lookup(request.user, request.json) + check = _lookup(request.project, request.json) if check is None: - num_checks = Check.objects.filter(user=request.user).count() - if num_checks >= request.user.profile.check_limit: + user = request.project.owner + num_checks = Check.objects.filter(project__owner=user).count() + if num_checks >= user.profile.check_limit: return HttpResponseForbidden() - check = Check(user=request.user, project=request.project) + check = Check(user=request.project.owner, project=request.project) created = True _update(check, request.json) @@ -152,7 +153,7 @@ def checks(request): @validate_json() @authorize_read def channels(request): - q = Channel.objects.filter(user=request.user) + q = Channel.objects.filter(project=request.project) channels = [ch.to_dict() for ch in q] return JsonResponse({"channels": channels}) @@ -163,7 +164,7 @@ def channels(request): @authorize def update(request, code): check = get_object_or_404(Check, code=code) - if check.user != request.user: + if check.project != request.project: return HttpResponseForbidden() if request.method == "POST": @@ -185,7 +186,7 @@ def update(request, code): @authorize def pause(request, code): check = get_object_or_404(Check, code=code) - if check.user != request.user: + if check.project != request.project: return HttpResponseForbidden() check.status = "paused" @@ -202,7 +203,7 @@ def badge(request, username, signature, tag, format="svg"): return HttpResponseNotFound() status = "up" - q = Check.objects.filter(user__username=username) + q = Check.objects.filter(project__owner__username=username) if tag != "*": q = q.filter(tags__contains=tag) label = tag diff --git a/hc/test.py b/hc/test.py index f0e45c2c..29bfe226 100644 --- a/hc/test.py +++ b/hc/test.py @@ -27,6 +27,9 @@ class BaseTestCase(TestCase): self.bob.set_password("password") self.bob.save() + self.bobs_project = Project(owner=self.bob) + self.bobs_project.save() + self.bobs_profile = Profile(user=self.bob) self.bobs_profile.current_team = self.profile self.bobs_profile.current_project = self.project