diff --git a/cms/server/admin/handlers/main.py b/cms/server/admin/handlers/main.py index e3ce0760aa..4557e28a0b 100644 --- a/cms/server/admin/handlers/main.py +++ b/cms/server/admin/handlers/main.py @@ -31,6 +31,7 @@ from cms import ServiceCoord, get_service_shards, get_service_address from cms.db import Admin, Contest, Question from cms.server.jinja2_toolbox import markdown_filter +from cms.server.util import normalize_login_next_page from cmscommon.crypto import validate_password from cmscommon.datetime import make_datetime, make_timestamp from .base import BaseHandler, SimpleHandler, require_permission @@ -48,12 +49,7 @@ def post(self): next_page: str = self.get_argument("next", None) if next_page is not None: error_args["next"] = next_page - if next_page != "/": - next_page = self.url(*next_page.strip("/").split("/")) - else: - next_page = self.url() - else: - next_page = self.url() + next_page = normalize_login_next_page(next_page, self.url, self.url()) error_page = self.url("login", **error_args) username: str = self.get_argument("username", "") diff --git a/cms/server/contest/handlers/main.py b/cms/server/contest/handlers/main.py index 93402e2b0c..2b3e60f160 100644 --- a/cms/server/contest/handlers/main.py +++ b/cms/server/contest/handlers/main.py @@ -52,6 +52,7 @@ from cms.grading.languagemanager import get_language from cms.grading.steps import COMPILATION_MESSAGES, EVALUATION_MESSAGES from cms.server import multi_contest +from cms.server.util import normalize_login_next_page from cms.server.contest.authentication import validate_login from cms.server.contest.communication import get_communications from cmscommon.crypto import hash_password, validate_password @@ -217,12 +218,7 @@ def post(self): next_page: str | None = self.get_argument("next", None) if next_page is not None: error_args["next"] = next_page - if next_page != "/": - next_page = self.url(*next_page.strip("/").split("/")) - else: - next_page = self.url() - else: - next_page = self.contest_url() + next_page = normalize_login_next_page(next_page, self.url, self.contest_url()) error_page = self.contest_url(**error_args) username: str = self.get_argument("username", "") diff --git a/cms/server/util.py b/cms/server/util.py index 3251a1fc24..902e521352 100644 --- a/cms/server/util.py +++ b/cms/server/util.py @@ -28,7 +28,7 @@ import logging from functools import wraps -from urllib.parse import quote, urlencode +from urllib.parse import quote, urlencode, urlsplit import collections try: @@ -168,6 +168,38 @@ def __getitem__(self, component: object) -> typing.Self: return self.__class__(self.__call__(component)) +def normalize_login_next_page(next_page: str | None, url: Url, default_url: str) -> str: + """Normalize a login redirection target. + + Accept only local absolute-path targets (plus optional query), and + rebase them through the provided URL builder. Query-only values are + treated as "/" and preserved. + + next_page: raw value of the "next" parameter. + url: URL builder for local paths. + default_url: fallback when next_page is missing or invalid. + + return: normalized redirect target. + """ + if next_page is None: + return default_url + + split = urlsplit(next_page) + path = split.path or "/" + if split.scheme or split.netloc or not path.startswith("/"): + return default_url + + if path != "/": + normalized = url(*path.strip("/").split("/")) + else: + normalized = url() + + if split.query: + normalized += "?" + split.query + + return normalized + + class CommonRequestHandler(RequestHandler): """Encapsulates shared RequestHandler functionality.