From 4c14a10b916635953599c6522e0da28399ba2393 Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Tue, 8 Feb 2022 07:50:15 -0800 Subject: [PATCH] fix: support multiple SSHPubKey records per user There was one blazing issue with the previous implementation regardless of the multiple records: we were generating fingerprints by storing the key into a file and reading it with ssh-keygen. This is absolutely terrible and was not meant to be left around (it was forgotten, my bad). Took this opportunity to clean up a few things: - simplify pubkey validation - centralize things a bit better Signed-off-by: Kevin Morris --- aurweb/models/ssh_pub_key.py | 30 ++++----------- aurweb/routers/accounts.py | 30 ++++++++------- aurweb/users/update.py | 56 +++++++++++++++------------- aurweb/users/validate.py | 16 ++++---- aurweb/util.py | 49 +++++++++++++----------- templates/partials/account_form.html | 2 +- test/test_accounts_routes.py | 21 +++++------ test/test_adduser.py | 2 +- test/test_ssh_pub_key.py | 10 ++++- test/test_user.py | 4 +- test/test_util.py | 50 +++++++++++++++++++++++++ 11 files changed, 162 insertions(+), 108 deletions(-) diff --git a/aurweb/models/ssh_pub_key.py b/aurweb/models/ssh_pub_key.py index 789be629..53c8c3ac 100644 --- a/aurweb/models/ssh_pub_key.py +++ b/aurweb/models/ssh_pub_key.py @@ -1,6 +1,3 @@ -import os -import tempfile - from subprocess import PIPE, Popen from sqlalchemy.orm import backref, relationship @@ -15,28 +12,17 @@ class SSHPubKey(Base): __mapper_args__ = {"primary_key": [__table__.c.Fingerprint]} User = relationship( - "User", backref=backref("ssh_pub_key", uselist=False), + "User", backref=backref("ssh_pub_keys", lazy="dynamic"), foreign_keys=[__table__.c.UserID]) def __init__(self, **kwargs): super().__init__(**kwargs) -def get_fingerprint(pubkey): - with tempfile.TemporaryDirectory() as tmpdir: - pk = os.path.join(tmpdir, "ssh.pub") - - with open(pk, "w") as f: - f.write(pubkey) - - proc = Popen(["ssh-keygen", "-l", "-f", pk], stdout=PIPE, stderr=PIPE) - out, err = proc.communicate() - - # Invalid SSH Public Key. Return None to the caller. - if proc.returncode != 0: - return None - - parts = out.decode().split() - fp = parts[1].replace("SHA256:", "") - - return fp +def get_fingerprint(pubkey: str) -> str: + proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, + stderr=PIPE) + out, _ = proc.communicate(pubkey.encode()) + if proc.returncode: + raise ValueError("The SSH public key is invalid.") + return out.decode().split()[1].split(":", 1)[1] diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index d1b9d428..36ac48d2 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -2,6 +2,7 @@ import copy import typing from http import HTTPStatus +from typing import Any, Dict from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse @@ -105,7 +106,8 @@ async def passreset_post(request: Request, status_code=HTTPStatus.SEE_OTHER) -def process_account_form(request: Request, user: models.User, args: dict): +def process_account_form(request: Request, user: models.User, + args: Dict[str, Any]): """ Process an account form. All fields are optional and only checks requirements in the case they are present. @@ -193,8 +195,8 @@ def make_account_form_context(context: dict, context["pgp"] = args.get("K", user.PGPKey or str()) context["lang"] = args.get("L", user.LangPreference) context["tz"] = args.get("TZ", user.Timezone) - ssh_pk = user.ssh_pub_key.PubKey if user.ssh_pub_key else str() - context["ssh_pk"] = args.get("PK", ssh_pk) + ssh_pks = [pk.PubKey for pk in user.ssh_pub_keys] + context["ssh_pks"] = args.get("PK", ssh_pks) context["cn"] = args.get("CN", user.CommentNotify) context["un"] = args.get("UN", user.UpdateNotify) context["on"] = args.get("ON", user.OwnershipNotify) @@ -212,7 +214,7 @@ def make_account_form_context(context: dict, context["pgp"] = args.get("K", str()) context["lang"] = args.get("L", context.get("language")) context["tz"] = args.get("TZ", context.get("timezone")) - context["ssh_pk"] = args.get("PK", str()) + context["ssh_pks"] = args.get("PK", str()) context["cn"] = args.get("CN", True) context["un"] = args.get("UN", False) context["on"] = args.get("ON", True) @@ -314,16 +316,13 @@ async def account_register_post(request: Request, # PK mismatches the existing user's SSHPubKey.PubKey. if PK: # Get the second element in the PK, which is the actual key. - pubkey = PK.strip().rstrip() - parts = pubkey.split(" ") - if len(parts) == 3: - # Remove the host part. - pubkey = parts[0] + " " + parts[1] - fingerprint = get_fingerprint(pubkey) - with db.begin(): - user.ssh_pub_key = models.SSHPubKey(UserID=user.ID, - PubKey=pubkey, - Fingerprint=fingerprint) + keys = util.parse_ssh_keys(PK.strip()) + for k in keys: + pk = " ".join(k) + fprint = get_fingerprint(pk) + with db.begin(): + db.create(models.SSHPubKey, UserID=user.ID, + PubKey=pk, Fingerprint=fprint) # Send a reset key notification to the new user. WelcomeNotification(user.ID).send() @@ -409,6 +408,9 @@ async def account_edit_post(request: Request, context = make_account_form_context(context, request, user, args) ok, errors = process_account_form(request, user, args) + if PK: + context["ssh_pks"] = [PK] + if not passwd: context["errors"] = ["Invalid password."] return render_template(request, "account/edit.html", context, diff --git a/aurweb/users/update.py b/aurweb/users/update.py index 685dfd80..8e42765e 100644 --- a/aurweb/users/update.py +++ b/aurweb/users/update.py @@ -2,7 +2,8 @@ from typing import Any, Dict from fastapi import Request -from aurweb import cookies, db, models, time +from aurweb import cookies, db, models, time, util +from aurweb.models import SSHPubKey from aurweb.models.ssh_pub_key import get_fingerprint from aurweb.util import strtobool @@ -52,32 +53,35 @@ def timezone(TZ: str = str(), context["language"] = TZ -def ssh_pubkey(PK: str = str(), - user: models.User = None, - **kwargs) -> None: - # If a PK is given, compare it against the target user's PK. - if PK: - # Get the second token in the public key, which is the actual key. - pubkey = PK.strip().rstrip() - parts = pubkey.split(" ") - if len(parts) == 3: - # Remove the host part. - pubkey = parts[0] + " " + parts[1] - fingerprint = get_fingerprint(pubkey) - if not user.ssh_pub_key: - # No public key exists, create one. - with db.begin(): - db.create(models.SSHPubKey, UserID=user.ID, - PubKey=pubkey, Fingerprint=fingerprint) - elif user.ssh_pub_key.PubKey != pubkey: - # A public key already exists, update it. - with db.begin(): - user.ssh_pub_key.PubKey = pubkey - user.ssh_pub_key.Fingerprint = fingerprint - elif user.ssh_pub_key: - # Else, if the user has a public key already, delete it. +def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None: + if not PK: + # If no pubkey is provided, wipe out any pubkeys the user + # has and return out early. with db.begin(): - db.delete(user.ssh_pub_key) + db.delete_all(user.ssh_pub_keys) + return + + # Otherwise, parse ssh keys and their fprints out of PK. + keys = util.parse_ssh_keys(PK.strip()) + fprints = [get_fingerprint(" ".join(k)) for k in keys] + + with db.begin(): + # Delete any existing keys we can't find. + to_remove = user.ssh_pub_keys.filter( + ~SSHPubKey.Fingerprint.in_(fprints)) + db.delete_all(to_remove) + + # For each key, if it does not yet exist, create it. + for i, full_key in enumerate(keys): + prefix, key = full_key + exists = user.ssh_pub_keys.filter( + SSHPubKey.Fingerprint == fprints[i] + ).exists() + if not db.query(exists).scalar(): + # No public key exists, create one. + db.create(models.SSHPubKey, UserID=user.ID, + PubKey=" ".join([prefix, key]), + Fingerprint=fprints[i]) def account_type(T: int = None, diff --git a/aurweb/users/validate.py b/aurweb/users/validate.py index bbd6082a..26f6eec6 100644 --- a/aurweb/users/validate.py +++ b/aurweb/users/validate.py @@ -107,14 +107,16 @@ def invalid_pgp_key(K: str = str(), **kwargs) -> None: def invalid_ssh_pubkey(PK: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs) -> None: - if PK: - invalid_exc = ValidationError(["The SSH public key is invalid."]) - if not util.valid_ssh_pubkey(PK): - raise invalid_exc + if not PK: + return - fingerprint = get_fingerprint(PK.strip().rstrip()) - if not fingerprint: - raise invalid_exc + try: + keys = util.parse_ssh_keys(PK.strip()) + except ValueError as exc: + raise ValidationError([str(exc)]) + + for prefix, key in keys: + fingerprint = get_fingerprint(f"{prefix} {key}") exists = db.query(models.SSHPubKey).filter( and_(models.SSHPubKey.UserID != user.ID, diff --git a/aurweb/util.py b/aurweb/util.py index 7ed4d1d3..6759794f 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -1,4 +1,3 @@ -import base64 import math import re import secrets @@ -7,7 +6,8 @@ import string from datetime import datetime from distutils.util import strtobool as _strtobool from http import HTTPStatus -from typing import Callable, Iterable, Tuple, Union +from subprocess import PIPE, Popen +from typing import Callable, Iterable, List, Tuple, Union from urllib.parse import urlparse import fastapi @@ -82,25 +82,6 @@ def valid_pgp_fingerprint(fp): return len(fp) == 40 -def valid_ssh_pubkey(pk): - valid_prefixes = aurweb.config.get("auth", "valid-keytypes") - valid_prefixes = set(valid_prefixes.split(" ")) - - has_valid_prefix = False - for prefix in valid_prefixes: - if "%s " % prefix in pk: - has_valid_prefix = True - break - if not has_valid_prefix: - return False - - tokens = pk.strip().rstrip().split(" ") - if len(tokens) < 2: - return False - - return base64.b64encode(base64.b64decode(tokens[1])).decode() == tokens[1] - - def jsonify(obj): """ Perform a conversion on obj if it's needed. """ if isinstance(obj, datetime): @@ -191,3 +172,29 @@ async def error_or_result(next: Callable, *args, **kwargs) \ status_code = HTTPStatus.INTERNAL_SERVER_ERROR return JSONResponse({"error": str(exc)}, status_code=status_code) return response + + +def parse_ssh_key(string: str) -> Tuple[str, str]: + """ Parse an SSH public key. """ + invalid_exc = ValueError("The SSH public key is invalid.") + parts = re.sub(r'\s\s+', ' ', string.strip()).split() + if len(parts) < 2: + raise invalid_exc + + prefix, key = parts[:2] + prefixes = set(aurweb.config.get("auth", "valid-keytypes").split(" ")) + if prefix not in prefixes: + raise invalid_exc + + proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, + stderr=PIPE) + out, _ = proc.communicate(f"{prefix} {key}".encode()) + if proc.returncode: + raise invalid_exc + + return (prefix, key) + + +def parse_ssh_keys(string: str) -> List[Tuple[str, str]]: + """ Parse a list of SSH public keys. """ + return [parse_ssh_key(e) for e in string.splitlines()] diff --git a/templates/partials/account_form.html b/templates/partials/account_form.html index 9136ee7a..007fb389 100644 --- a/templates/partials/account_form.html +++ b/templates/partials/account_form.html @@ -262,7 +262,7 @@ + rows="5" cols="30">{{ ssh_pks | join("\n") }}

diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index 92b33730..e532e341 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -577,10 +577,13 @@ def test_post_register_error_ssh_pubkey_taken(client: TestClient, user: User): # Read in the public key, then delete the temp dir we made. pk = open(f"{tmpdir}/test.ssh.pub").read().rstrip() + prefix, key, loc = pk.split() + norm_pk = prefix + " " + key + # Take the sha256 fingerprint of the ssh public key, create it. - fp = get_fingerprint(pk) + fp = get_fingerprint(norm_pk) with db.begin(): - create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp) + create(SSHPubKey, UserID=user.ID, PubKey=norm_pk, Fingerprint=fp) with client as request: response = post_register(request, PK=pk) @@ -1080,22 +1083,16 @@ def test_post_account_edit_missing_ssh_pubkey(client: TestClient, user: User): def test_post_account_edit_invalid_ssh_pubkey(client: TestClient, user: User): pubkey = "ssh-rsa fake key" - request = Request() - sid = user.login(request, "testPassword") - - post_data = { + data = { "U": "test", "E": "test@example.org", - "P": "newPassword", - "C": "newPassword", "PK": pubkey, "passwd": "testPassword" } - + cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: - response = request.post("/account/test/edit", cookies={ - "AURSID": sid - }, data=post_data, allow_redirects=False) + response = request.post("/account/test/edit", data=data, + cookies=cookies, allow_redirects=False) assert response.status_code == int(HTTPStatus.BAD_REQUEST) diff --git a/test/test_adduser.py b/test/test_adduser.py index 6c71a519..c6210e74 100644 --- a/test/test_adduser.py +++ b/test/test_adduser.py @@ -53,4 +53,4 @@ def test_adduser_ssh_pk(): "--ssh-pubkey", TEST_SSH_PUBKEY]) test = db.query(User).filter(User.Username == "test").first() assert test is not None - assert TEST_SSH_PUBKEY.startswith(test.ssh_pub_key.PubKey) + assert TEST_SSH_PUBKEY.startswith(test.ssh_pub_keys.first().PubKey) diff --git a/test/test_ssh_pub_key.py b/test/test_ssh_pub_key.py index 68b6e7a0..93298a11 100644 --- a/test/test_ssh_pub_key.py +++ b/test/test_ssh_pub_key.py @@ -1,3 +1,5 @@ +from subprocess import PIPE, Popen + import pytest from aurweb import db @@ -61,8 +63,12 @@ def test_pubkey_cs(user: User): def test_pubkey_fingerprint(): - assert get_fingerprint(TEST_SSH_PUBKEY) is not None + proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE) + out, _ = proc.communicate(TEST_SSH_PUBKEY.encode()) + expected = out.decode().split()[1].split(":", 1)[1] + assert get_fingerprint(TEST_SSH_PUBKEY) == expected def test_pubkey_invalid_fingerprint(): - assert get_fingerprint("ssh-rsa fake and invalid") is None + with pytest.raises(ValueError): + get_fingerprint("invalid-prefix some-fake-content") diff --git a/test/test_user.py b/test/test_user.py index 7871cd61..5f25f3c9 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -183,14 +183,14 @@ def test_user_has_credential(user: User): def test_user_ssh_pub_key(user: User): - assert user.ssh_pub_key is None + assert user.ssh_pub_keys.first() is None with db.begin(): ssh_pub_key = db.create(SSHPubKey, UserID=user.ID, Fingerprint="testFingerprint", PubKey="testPubKey") - assert user.ssh_pub_key == ssh_pub_key + assert user.ssh_pub_keys.first() == ssh_pub_key def test_user_credential_types(user: User): diff --git a/test/test_util.py b/test/test_util.py index 51d978fb..ae1de81b 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -60,3 +60,53 @@ def test_valid_homepage(): assert not util.valid_homepage("https://[google.com/broken-ipv6") assert not util.valid_homepage("gopher://gopher.hprc.utoronto.ca/") + + +def test_parse_ssh_key(): + # Test a valid key. + pk = """ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyN\ +TYAAABBBEURnkiY6JoLyqDE8Li1XuAW+LHmkmLDMW/GL5wY7k4/A+Ta7bjA3MOKrF9j4EuUTvCuNX\ +ULxvpfSqheTFWZc+g=""" + prefix, key = util.parse_ssh_key(pk) + e_prefix, e_key = pk.split() + assert prefix == e_prefix + assert key == e_key + + # Test an invalid key with just one word in it. + with pytest.raises(ValueError): + util.parse_ssh_key("ssh-rsa") + + # Test a valid key with extra words in it (after the PK). + pk = pk + " blah blah" + prefix, key = util.parse_ssh_key(pk) + assert prefix == e_prefix + assert key == e_key + + # Test an invalid prefix. + with pytest.raises(ValueError): + util.parse_ssh_key("invalid-prefix fake-content") + + +def test_parse_ssh_keys(): + pks = """ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyN\ +TYAAABBBEURnkiY6JoLyqDE8Li1XuAW+LHmkmLDMW/GL5wY7k4/A+Ta7bjA3MOKrF9j4EuUTvCuNX\ +ULxvpfSqheTFWZc+g= +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDmqEapFMh/ajPHnm1dBweYPeLOUjC0Ydp6uw7rB\ +S5KCggUVQR8WfIm+sRYTj2+smGsK6zHMBjFnbzvV11vnMqcnY+Sa4LhIAdwkbt/b8HlGaLj1hCWSh\ +a5b5/noeK7L+CECGHdvfJhpxBbhq38YEdFnCGbslk/4NriNcUp/DO81CXb1RzJ9GBFH8ivPW1mbe9\ +YbxDwGimZZslg0OZu9UzoAT6xEGyiZsqJkTMbRp1ZYIOv9jHCJxRuxxuN3fzxyT3xE69+vhq2/NJX\ +8aRsxGPL9G/XKcaYGS6y6LW4quIBCz/XsTZfx1GmkQeZPYHH8FeE+XC/+toXL/kamxdOQKFYEEpWK\ +vTNJCD6JtMClxbIXW9q74nNqG+2SD/VQNMUz/505TK1PbY/4uyFfq5HquHJXQVCBll03FRerNHH2N\ +schFne6BFHpa48PCoZNH45wLjFXwUyrGU1HrNqh6ZPdRfBTrTOkgs+BKBxGNeV45aYUPu/cFBSPcB\ +fRSo6OFcejKc=""" + keys = util.parse_ssh_keys(pks) + assert len(keys) == 2 + + pfx1, key1, pfx2, key2 = pks.split() + k1, k2 = keys + + assert pfx1 == k1[0] + assert key1 == k1[1] + + assert pfx2 == k2[0] + assert key2 == k2[1]