diff --git a/hc/accounts/middleware.py b/hc/accounts/middleware.py index dd20e6de..7e84dc46 100644 --- a/hc/accounts/middleware.py +++ b/hc/accounts/middleware.py @@ -1,4 +1,4 @@ -from hc.accounts.models import Profile +from hc.accounts.models import Profile, Project class TeamAccessMiddleware(object): @@ -9,15 +9,16 @@ class TeamAccessMiddleware(object): if not request.user.is_authenticated: return self.get_response(request) - teams_q = Profile.objects.filter(member__user_id=request.user.id) - teams_q = teams_q.select_related("user") - request.get_teams = lambda: list(teams_q) + projects_q = Project.objects.filter(member__user_id=request.user.id) + projects_q = projects_q.select_related("owner") + request.get_projects = lambda: list(projects_q) - request.profile = Profile.objects.for_user(request.user) - request.team = request.profile.team() + profile = Profile.objects.for_user(request.user) + if profile.current_project is None: + profile.current_project = profile.get_own_project() + profile.save() - request.project = request.profile.current_project - if request.project is None: - request.project = request.team.user.project_set.first() + request.profile = profile + request.project = profile.current_project return self.get_response(request) diff --git a/hc/accounts/models.py b/hc/accounts/models.py index 0d57324c..91d7fd12 100644 --- a/hc/accounts/models.py +++ b/hc/accounts/models.py @@ -233,6 +233,13 @@ class Profile(models.Model): q.update(next_nag_date=timezone.now() + models.F("nag_period")) + def get_own_project(self): + project = self.user.project_set.first() + if project is None: + project = Project.objects.create(owner=self.user) + + return project + class Project(models.Model): code = models.UUIDField(default=uuid.uuid4, editable=False, unique=True) @@ -242,12 +249,17 @@ class Project(models.Model): api_key_readonly = models.CharField(max_length=128, blank=True) badge_key = models.CharField(max_length=150, unique=True) - def num_checks_available(self): - owner_profile = Profile.objects.for_user(self.owner) + def __str__(self): + return self.name or self.owner.email + @property + def owner_profile(self): + return Profile.objects.for_user(self.owner) + + def num_checks_available(self): from hc.api.models import Check num_used = Check.objects.filter(project__owner=self.owner).count() - return owner_profile.check_limit - num_used + return self.owner_profile.check_limit - num_used class Member(models.Model): diff --git a/hc/accounts/views.py b/hc/accounts/views.py index 4b68a59b..809430af 100644 --- a/hc/accounts/views.py +++ b/hc/accounts/views.py @@ -76,9 +76,8 @@ def _make_user(email): def _ensure_own_team(request): """ Make sure user is switched to their own team. """ - if request.team != request.profile: - request.team = request.profile - request.project = request.user.project_set.first() + if request.project.owner != request.user: + request.project = request.profile.get_own_project() request.profile.current_team = request.profile request.profile.current_project = request.project @@ -271,9 +270,8 @@ def profile(request): profile.team_name = form.cleaned_data["team_name"] profile.save() - for project in request.user.project_set.all(): - project.name = form.cleaned_data["team_name"] - project.save() + request.project.name = form.cleaned_data["team_name"] + request.project.save() ctx["team_name_updated"] = True ctx["team_status"] = "success" @@ -454,7 +452,7 @@ def switch_team(request, target_username): return HttpResponseForbidden() request.profile.current_team = target_team - request.profile.current_project = target_team.user.project_set.first() + request.profile.current_project = target_team.get_own_project() request.profile.save() return redirect("hc-checks") diff --git a/hc/api/management/commands/sendalerts.py b/hc/api/management/commands/sendalerts.py index 1217a077..bf3755a3 100644 --- a/hc/api/management/commands/sendalerts.py +++ b/hc/api/management/commands/sendalerts.py @@ -20,8 +20,8 @@ def notify(flip_id, stdout): stdout.write(tmpl % (flip.new_status, check.code)) # Set dates for followup nags - if flip.new_status == "down" and check.user.profile: - check.user.profile.set_next_nag_date() + if flip.new_status == "down": + check.project.owner_profile.set_next_nag_date() # Send notifications errors = flip.send_alerts() diff --git a/hc/api/transports.py b/hc/api/transports.py index 95834303..29421de6 100644 --- a/hc/api/transports.py +++ b/hc/api/transports.py @@ -42,7 +42,7 @@ class Transport(object): return False def checks(self): - return self.channel.user.check_set.order_by("created") + return self.channel.project.check_set.order_by("created") class Email(Transport): diff --git a/hc/front/views.py b/hc/front/views.py index efd040c5..30d7903d 100644 --- a/hc/front/views.py +++ b/hc/front/views.py @@ -97,14 +97,14 @@ def my_checks(request): request.profile.sort = request.GET["sort"] request.profile.save() - checks = list(Check.objects.filter(user=request.team.user).prefetch_related("channel_set")) + checks = list(Check.objects.filter(project=request.project).prefetch_related("channel_set")) sortchecks(checks, request.profile.sort) tags_statuses, num_down = _tags_statuses(checks) pairs = list(tags_statuses.items()) pairs.sort(key=lambda pair: pair[0].lower()) - channels = Channel.objects.filter(user=request.team.user) + channels = Channel.objects.filter(project=request.project) channels = list(channels.order_by("created")) hidden_checks = set() @@ -173,7 +173,7 @@ def switch_channel(request, code, channel_code): check = _get_check_for_user(request, code) channel = get_object_or_404(Channel, code=channel_code) - if channel.user_id != check.user_id: + if channel.project_id != check.project_id: return HttpResponseBadRequest() if request.POST.get("state") == "on": @@ -248,7 +248,7 @@ def add_check(request): if request.project.num_checks_available() <= 0: return HttpResponseBadRequest() - check = Check(user=request.team.user, project=request.project) + check = Check(user=request.project.owner, project=request.project) check.save() check.assign_all_channels() @@ -411,7 +411,7 @@ def _get_events(check, limit): def log(request, code): check = _get_check_for_user(request, code) - limit = check.user.profile.ping_log_limit + limit = request.project.owner_profile.ping_log_limit ctx = { "check": check, "events": _get_events(check, limit), @@ -426,7 +426,7 @@ def log(request, code): def details(request, code): check = _get_check_for_user(request, code) - channels = Channel.objects.filter(user=check.user) + channels = Channel.objects.filter(project=check.project) channels = list(channels.order_by("created")) ctx = { @@ -470,7 +470,7 @@ def channels(request): channel = Channel.objects.get(code=code) except Channel.DoesNotExist: return HttpResponseBadRequest() - if channel.user_id != request.team.user.id: + if channel.project_id != request.project.id: return HttpResponseForbidden() new_checks = [] @@ -481,20 +481,20 @@ def channels(request): check = Check.objects.get(code=code) except Check.DoesNotExist: return HttpResponseBadRequest() - if check.user_id != request.team.user.id: + if check.project_id != request.project.id: return HttpResponseForbidden() new_checks.append(check) channel.checks.set(new_checks) return redirect("hc-channels") - channels = Channel.objects.filter(user=request.team.user) + channels = Channel.objects.filter(project=request.project) channels = channels.order_by("created") channels = channels.annotate(n_checks=Count("checks")) ctx = { "page": "channels", - "profile": request.team, + "profile": request.project.owner_profile, "channels": channels, "enable_pushbullet": settings.PUSHBULLET_CLIENT_ID is not None, "enable_pushover": settings.PUSHOVER_API_TOKEN is not None, @@ -512,11 +512,11 @@ def channels(request): @login_required def channel_checks(request, code): channel = get_object_or_404(Channel, code=code) - if channel.user_id != request.team.user.id: + if channel.project_id != request.project.id: return HttpResponseForbidden() assigned = set(channel.checks.values_list('code', flat=True).distinct()) - checks = Check.objects.filter(user=request.team.user).order_by("created") + checks = Check.objects.filter(project=request.project).order_by("created") ctx = { "checks": checks, @@ -531,7 +531,7 @@ def channel_checks(request, code): @login_required def update_channel_name(request, code): channel = get_object_or_404(Channel, code=code) - if channel.user_id != request.team.user.id: + if channel.project_id != request.project.id: return HttpResponseForbidden() form = ChannelNameForm(request.POST) @@ -575,7 +575,7 @@ def remove_channel(request, code): # user may refresh the page during POST and cause two deletion attempts channel = Channel.objects.filter(code=code).first() if channel: - if channel.user != request.team.user: + if channel.project_id != request.project.id: return HttpResponseForbidden() channel.delete() @@ -587,7 +587,7 @@ def add_email(request): if request.method == "POST": form = AddEmailForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="email") + channel = Channel(user=request.project.owner, kind="email") channel.project = request.project channel.value = form.cleaned_data["value"] channel.save() @@ -607,7 +607,7 @@ def add_webhook(request): if request.method == "POST": form = AddWebhookForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="webhook") + channel = Channel(user=request.project.owner, kind="webhook") channel.project = request.project channel.value = form.get_value() channel.save() @@ -660,7 +660,7 @@ def add_pd(request, state=None): return redirect("hc-channels") channel = Channel(kind="pd", project=request.project) - channel.user = request.team.user + channel.user = request.project.owner channel.value = json.dumps({ "service_key": request.GET.get("service_key"), "account": request.GET.get("account") @@ -686,7 +686,7 @@ def add_pagertree(request): if request.method == "POST": form = AddUrlForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="pagertree") + channel = Channel(user=request.project.owner, kind="pagertree") channel.project = request.project channel.value = form.cleaned_data["value"] channel.save() @@ -707,7 +707,7 @@ def add_slack(request): if request.method == "POST": form = AddUrlForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="slack") + channel = Channel(user=request.project.owner, kind="slack") channel.project = request.project channel.value = form.cleaned_data["value"] channel.save() @@ -744,7 +744,7 @@ def add_slack_btn(request): doc = result.json() if doc.get("ok"): channel = Channel(kind="slack", project=request.project) - channel.user = request.team.user + channel.user = request.project.owner channel.value = result.text channel.save() channel.assign_all_checks() @@ -767,7 +767,7 @@ def add_hipchat(request): return redirect("hc-channels") channel = Channel(kind="hipchat", project=request.project) - channel.user = request.team.user + channel.user = request.project.owner channel.value = response.text channel.save() @@ -812,7 +812,7 @@ def add_pushbullet(request): doc = result.json() if "access_token" in doc: channel = Channel(kind="pushbullet", project=request.project) - channel.user = request.team.user + channel.user = request.project.owner channel.value = doc["access_token"] channel.save() channel.assign_all_checks() @@ -860,7 +860,7 @@ def add_discord(request): doc = result.json() if "access_token" in doc: channel = Channel(kind="discord", project=request.project) - channel.user = request.team.user + channel.user = request.project.owner channel.value = result.text channel.save() channel.assign_all_checks() @@ -933,7 +933,7 @@ def add_pushover(request): return redirect("hc-channels") # Subscription - channel = Channel(user=request.team.user, kind="po") + channel = Channel(user=request.project.owner, kind="po") channel.project = request.project channel.value = "%s|%s|%s" % (key, prio, prio_up) channel.save() @@ -956,7 +956,7 @@ def add_opsgenie(request): if request.method == "POST": form = AddOpsGenieForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="opsgenie") + channel = Channel(user=request.project.owner, kind="opsgenie") channel.project = request.project channel.value = form.cleaned_data["value"] channel.save() @@ -975,7 +975,7 @@ def add_victorops(request): if request.method == "POST": form = AddUrlForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="victorops") + channel = Channel(user=request.project.owner, kind="victorops") channel.project = request.project channel.value = form.cleaned_data["value"] channel.save() @@ -1024,7 +1024,7 @@ def add_telegram(request): chat_id, chat_type, chat_name = signing.loads(qs, max_age=600) if request.method == "POST": - channel = Channel(user=request.team.user, kind="telegram") + channel = Channel(user=request.project.owner, kind="telegram") channel.project = request.project channel.value = json.dumps({ "id": chat_id, @@ -1055,7 +1055,7 @@ def add_sms(request): if request.method == "POST": form = AddSmsForm(request.POST) if form.is_valid(): - channel = Channel(user=request.team.user, kind="sms") + channel = Channel(user=request.project.owner, kind="sms") channel.project = request.project channel.name = form.cleaned_data["label"] channel.value = json.dumps({ @@ -1071,7 +1071,7 @@ def add_sms(request): ctx = { "page": "channels", "form": form, - "profile": request.team + "profile": request.project.owner_profile } return render(request, "integrations/add_sms.html", ctx) @@ -1082,7 +1082,7 @@ def add_trello(request): raise Http404("trello integration is not available") if request.method == "POST": - channel = Channel(user=request.team.user, kind="trello") + channel = Channel(user=request.project.owner, kind="trello") channel.value = request.POST["settings"] channel.save() diff --git a/hc/payments/views.py b/hc/payments/views.py index 3fbc7d49..c4acfd6b 100644 --- a/hc/payments/views.py +++ b/hc/payments/views.py @@ -21,7 +21,7 @@ def get_client_token(request): def pricing(request): - if request.user.is_authenticated and request.profile != request.team: + if request.user.is_authenticated and request.user != request.project.owner: ctx = {"page": "pricing"} return render(request, "payments/pricing_not_owner.html", ctx) @@ -35,9 +35,11 @@ def pricing(request): @login_required def billing(request): - if request.team != request.profile: - request.team = request.profile + if request.project.owner != request.user: + request.project = request.profile.get_own_project() + request.profile.current_team = request.profile + request.profile.current_project = request.project request.profile.save() # Don't use Subscription.objects.for_user method here, so a diff --git a/templates/base.html b/templates/base.html index 922d3efb..349c2b31 100644 --- a/templates/base.html +++ b/templates/base.html @@ -116,15 +116,15 @@