mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
fix: support multiple SSHPubKey records per user
There was one blazing issue with the previous implementation regardless of the multiple records: we were generating fingerprints by storing the key into a file and reading it with ssh-keygen. This is absolutely terrible and was not meant to be left around (it was forgotten, my bad). Took this opportunity to clean up a few things: - simplify pubkey validation - centralize things a bit better Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
660d57340a
commit
4c14a10b91
11 changed files with 162 additions and 108 deletions
|
@ -1,6 +1,3 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
from subprocess import PIPE, Popen
|
||||
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
|
@ -15,28 +12,17 @@ class SSHPubKey(Base):
|
|||
__mapper_args__ = {"primary_key": [__table__.c.Fingerprint]}
|
||||
|
||||
User = relationship(
|
||||
"User", backref=backref("ssh_pub_key", uselist=False),
|
||||
"User", backref=backref("ssh_pub_keys", lazy="dynamic"),
|
||||
foreign_keys=[__table__.c.UserID])
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def get_fingerprint(pubkey):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pk = os.path.join(tmpdir, "ssh.pub")
|
||||
|
||||
with open(pk, "w") as f:
|
||||
f.write(pubkey)
|
||||
|
||||
proc = Popen(["ssh-keygen", "-l", "-f", pk], stdout=PIPE, stderr=PIPE)
|
||||
out, err = proc.communicate()
|
||||
|
||||
# Invalid SSH Public Key. Return None to the caller.
|
||||
if proc.returncode != 0:
|
||||
return None
|
||||
|
||||
parts = out.decode().split()
|
||||
fp = parts[1].replace("SHA256:", "")
|
||||
|
||||
return fp
|
||||
def get_fingerprint(pubkey: str) -> str:
|
||||
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
|
||||
stderr=PIPE)
|
||||
out, _ = proc.communicate(pubkey.encode())
|
||||
if proc.returncode:
|
||||
raise ValueError("The SSH public key is invalid.")
|
||||
return out.decode().split()[1].split(":", 1)[1]
|
||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
|||
import typing
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
|
@ -105,7 +106,8 @@ async def passreset_post(request: Request,
|
|||
status_code=HTTPStatus.SEE_OTHER)
|
||||
|
||||
|
||||
def process_account_form(request: Request, user: models.User, args: dict):
|
||||
def process_account_form(request: Request, user: models.User,
|
||||
args: Dict[str, Any]):
|
||||
""" Process an account form. All fields are optional and only checks
|
||||
requirements in the case they are present.
|
||||
|
||||
|
@ -193,8 +195,8 @@ def make_account_form_context(context: dict,
|
|||
context["pgp"] = args.get("K", user.PGPKey or str())
|
||||
context["lang"] = args.get("L", user.LangPreference)
|
||||
context["tz"] = args.get("TZ", user.Timezone)
|
||||
ssh_pk = user.ssh_pub_key.PubKey if user.ssh_pub_key else str()
|
||||
context["ssh_pk"] = args.get("PK", ssh_pk)
|
||||
ssh_pks = [pk.PubKey for pk in user.ssh_pub_keys]
|
||||
context["ssh_pks"] = args.get("PK", ssh_pks)
|
||||
context["cn"] = args.get("CN", user.CommentNotify)
|
||||
context["un"] = args.get("UN", user.UpdateNotify)
|
||||
context["on"] = args.get("ON", user.OwnershipNotify)
|
||||
|
@ -212,7 +214,7 @@ def make_account_form_context(context: dict,
|
|||
context["pgp"] = args.get("K", str())
|
||||
context["lang"] = args.get("L", context.get("language"))
|
||||
context["tz"] = args.get("TZ", context.get("timezone"))
|
||||
context["ssh_pk"] = args.get("PK", str())
|
||||
context["ssh_pks"] = args.get("PK", str())
|
||||
context["cn"] = args.get("CN", True)
|
||||
context["un"] = args.get("UN", False)
|
||||
context["on"] = args.get("ON", True)
|
||||
|
@ -314,16 +316,13 @@ async def account_register_post(request: Request,
|
|||
# PK mismatches the existing user's SSHPubKey.PubKey.
|
||||
if PK:
|
||||
# Get the second element in the PK, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
with db.begin():
|
||||
user.ssh_pub_key = models.SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
keys = util.parse_ssh_keys(PK.strip())
|
||||
for k in keys:
|
||||
pk = " ".join(k)
|
||||
fprint = get_fingerprint(pk)
|
||||
with db.begin():
|
||||
db.create(models.SSHPubKey, UserID=user.ID,
|
||||
PubKey=pk, Fingerprint=fprint)
|
||||
|
||||
# Send a reset key notification to the new user.
|
||||
WelcomeNotification(user.ID).send()
|
||||
|
@ -409,6 +408,9 @@ async def account_edit_post(request: Request,
|
|||
context = make_account_form_context(context, request, user, args)
|
||||
ok, errors = process_account_form(request, user, args)
|
||||
|
||||
if PK:
|
||||
context["ssh_pks"] = [PK]
|
||||
|
||||
if not passwd:
|
||||
context["errors"] = ["Invalid password."]
|
||||
return render_template(request, "account/edit.html", context,
|
||||
|
|
|
@ -2,7 +2,8 @@ from typing import Any, Dict
|
|||
|
||||
from fastapi import Request
|
||||
|
||||
from aurweb import cookies, db, models, time
|
||||
from aurweb import cookies, db, models, time, util
|
||||
from aurweb.models import SSHPubKey
|
||||
from aurweb.models.ssh_pub_key import get_fingerprint
|
||||
from aurweb.util import strtobool
|
||||
|
||||
|
@ -52,32 +53,35 @@ def timezone(TZ: str = str(),
|
|||
context["language"] = TZ
|
||||
|
||||
|
||||
def ssh_pubkey(PK: str = str(),
|
||||
user: models.User = None,
|
||||
**kwargs) -> None:
|
||||
# If a PK is given, compare it against the target user's PK.
|
||||
if PK:
|
||||
# Get the second token in the public key, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
if not user.ssh_pub_key:
|
||||
# No public key exists, create one.
|
||||
with db.begin():
|
||||
db.create(models.SSHPubKey, UserID=user.ID,
|
||||
PubKey=pubkey, Fingerprint=fingerprint)
|
||||
elif user.ssh_pub_key.PubKey != pubkey:
|
||||
# A public key already exists, update it.
|
||||
with db.begin():
|
||||
user.ssh_pub_key.PubKey = pubkey
|
||||
user.ssh_pub_key.Fingerprint = fingerprint
|
||||
elif user.ssh_pub_key:
|
||||
# Else, if the user has a public key already, delete it.
|
||||
def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
|
||||
if not PK:
|
||||
# If no pubkey is provided, wipe out any pubkeys the user
|
||||
# has and return out early.
|
||||
with db.begin():
|
||||
db.delete(user.ssh_pub_key)
|
||||
db.delete_all(user.ssh_pub_keys)
|
||||
return
|
||||
|
||||
# Otherwise, parse ssh keys and their fprints out of PK.
|
||||
keys = util.parse_ssh_keys(PK.strip())
|
||||
fprints = [get_fingerprint(" ".join(k)) for k in keys]
|
||||
|
||||
with db.begin():
|
||||
# Delete any existing keys we can't find.
|
||||
to_remove = user.ssh_pub_keys.filter(
|
||||
~SSHPubKey.Fingerprint.in_(fprints))
|
||||
db.delete_all(to_remove)
|
||||
|
||||
# For each key, if it does not yet exist, create it.
|
||||
for i, full_key in enumerate(keys):
|
||||
prefix, key = full_key
|
||||
exists = user.ssh_pub_keys.filter(
|
||||
SSHPubKey.Fingerprint == fprints[i]
|
||||
).exists()
|
||||
if not db.query(exists).scalar():
|
||||
# No public key exists, create one.
|
||||
db.create(models.SSHPubKey, UserID=user.ID,
|
||||
PubKey=" ".join([prefix, key]),
|
||||
Fingerprint=fprints[i])
|
||||
|
||||
|
||||
def account_type(T: int = None,
|
||||
|
|
|
@ -107,14 +107,16 @@ def invalid_pgp_key(K: str = str(), **kwargs) -> None:
|
|||
|
||||
def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
|
||||
_: l10n.Translator = None, **kwargs) -> None:
|
||||
if PK:
|
||||
invalid_exc = ValidationError(["The SSH public key is invalid."])
|
||||
if not util.valid_ssh_pubkey(PK):
|
||||
raise invalid_exc
|
||||
if not PK:
|
||||
return
|
||||
|
||||
fingerprint = get_fingerprint(PK.strip().rstrip())
|
||||
if not fingerprint:
|
||||
raise invalid_exc
|
||||
try:
|
||||
keys = util.parse_ssh_keys(PK.strip())
|
||||
except ValueError as exc:
|
||||
raise ValidationError([str(exc)])
|
||||
|
||||
for prefix, key in keys:
|
||||
fingerprint = get_fingerprint(f"{prefix} {key}")
|
||||
|
||||
exists = db.query(models.SSHPubKey).filter(
|
||||
and_(models.SSHPubKey.UserID != user.ID,
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import base64
|
||||
import math
|
||||
import re
|
||||
import secrets
|
||||
|
@ -7,7 +6,8 @@ import string
|
|||
from datetime import datetime
|
||||
from distutils.util import strtobool as _strtobool
|
||||
from http import HTTPStatus
|
||||
from typing import Callable, Iterable, Tuple, Union
|
||||
from subprocess import PIPE, Popen
|
||||
from typing import Callable, Iterable, List, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
|
@ -82,25 +82,6 @@ def valid_pgp_fingerprint(fp):
|
|||
return len(fp) == 40
|
||||
|
||||
|
||||
def valid_ssh_pubkey(pk):
|
||||
valid_prefixes = aurweb.config.get("auth", "valid-keytypes")
|
||||
valid_prefixes = set(valid_prefixes.split(" "))
|
||||
|
||||
has_valid_prefix = False
|
||||
for prefix in valid_prefixes:
|
||||
if "%s " % prefix in pk:
|
||||
has_valid_prefix = True
|
||||
break
|
||||
if not has_valid_prefix:
|
||||
return False
|
||||
|
||||
tokens = pk.strip().rstrip().split(" ")
|
||||
if len(tokens) < 2:
|
||||
return False
|
||||
|
||||
return base64.b64encode(base64.b64decode(tokens[1])).decode() == tokens[1]
|
||||
|
||||
|
||||
def jsonify(obj):
|
||||
""" Perform a conversion on obj if it's needed. """
|
||||
if isinstance(obj, datetime):
|
||||
|
@ -191,3 +172,29 @@ async def error_or_result(next: Callable, *args, **kwargs) \
|
|||
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return JSONResponse({"error": str(exc)}, status_code=status_code)
|
||||
return response
|
||||
|
||||
|
||||
def parse_ssh_key(string: str) -> Tuple[str, str]:
|
||||
""" Parse an SSH public key. """
|
||||
invalid_exc = ValueError("The SSH public key is invalid.")
|
||||
parts = re.sub(r'\s\s+', ' ', string.strip()).split()
|
||||
if len(parts) < 2:
|
||||
raise invalid_exc
|
||||
|
||||
prefix, key = parts[:2]
|
||||
prefixes = set(aurweb.config.get("auth", "valid-keytypes").split(" "))
|
||||
if prefix not in prefixes:
|
||||
raise invalid_exc
|
||||
|
||||
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
|
||||
stderr=PIPE)
|
||||
out, _ = proc.communicate(f"{prefix} {key}".encode())
|
||||
if proc.returncode:
|
||||
raise invalid_exc
|
||||
|
||||
return (prefix, key)
|
||||
|
||||
|
||||
def parse_ssh_keys(string: str) -> List[Tuple[str, str]]:
|
||||
""" Parse a list of SSH public keys. """
|
||||
return [parse_ssh_key(e) for e in string.splitlines()]
|
||||
|
|
|
@ -262,7 +262,7 @@
|
|||
|
||||
<!-- Only set PK auto-fill when we've got a NewAccount form. -->
|
||||
<textarea id="id_ssh" name="PK"
|
||||
rows="5" cols="30">{{ ssh_pk }}</textarea>
|
||||
rows="5" cols="30">{{ ssh_pks | join("\n") }}</textarea>
|
||||
</p>
|
||||
</fieldset>
|
||||
|
||||
|
|
|
@ -577,10 +577,13 @@ def test_post_register_error_ssh_pubkey_taken(client: TestClient, user: User):
|
|||
# Read in the public key, then delete the temp dir we made.
|
||||
pk = open(f"{tmpdir}/test.ssh.pub").read().rstrip()
|
||||
|
||||
prefix, key, loc = pk.split()
|
||||
norm_pk = prefix + " " + key
|
||||
|
||||
# Take the sha256 fingerprint of the ssh public key, create it.
|
||||
fp = get_fingerprint(pk)
|
||||
fp = get_fingerprint(norm_pk)
|
||||
with db.begin():
|
||||
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
|
||||
create(SSHPubKey, UserID=user.ID, PubKey=norm_pk, Fingerprint=fp)
|
||||
|
||||
with client as request:
|
||||
response = post_register(request, PK=pk)
|
||||
|
@ -1080,22 +1083,16 @@ def test_post_account_edit_missing_ssh_pubkey(client: TestClient, user: User):
|
|||
def test_post_account_edit_invalid_ssh_pubkey(client: TestClient, user: User):
|
||||
pubkey = "ssh-rsa fake key"
|
||||
|
||||
request = Request()
|
||||
sid = user.login(request, "testPassword")
|
||||
|
||||
post_data = {
|
||||
data = {
|
||||
"U": "test",
|
||||
"E": "test@example.org",
|
||||
"P": "newPassword",
|
||||
"C": "newPassword",
|
||||
"PK": pubkey,
|
||||
"passwd": "testPassword"
|
||||
}
|
||||
|
||||
cookies = {"AURSID": user.login(Request(), "testPassword")}
|
||||
with client as request:
|
||||
response = request.post("/account/test/edit", cookies={
|
||||
"AURSID": sid
|
||||
}, data=post_data, allow_redirects=False)
|
||||
response = request.post("/account/test/edit", data=data,
|
||||
cookies=cookies, allow_redirects=False)
|
||||
|
||||
assert response.status_code == int(HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
|
|
@ -53,4 +53,4 @@ def test_adduser_ssh_pk():
|
|||
"--ssh-pubkey", TEST_SSH_PUBKEY])
|
||||
test = db.query(User).filter(User.Username == "test").first()
|
||||
assert test is not None
|
||||
assert TEST_SSH_PUBKEY.startswith(test.ssh_pub_key.PubKey)
|
||||
assert TEST_SSH_PUBKEY.startswith(test.ssh_pub_keys.first().PubKey)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from subprocess import PIPE, Popen
|
||||
|
||||
import pytest
|
||||
|
||||
from aurweb import db
|
||||
|
@ -61,8 +63,12 @@ def test_pubkey_cs(user: User):
|
|||
|
||||
|
||||
def test_pubkey_fingerprint():
|
||||
assert get_fingerprint(TEST_SSH_PUBKEY) is not None
|
||||
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE)
|
||||
out, _ = proc.communicate(TEST_SSH_PUBKEY.encode())
|
||||
expected = out.decode().split()[1].split(":", 1)[1]
|
||||
assert get_fingerprint(TEST_SSH_PUBKEY) == expected
|
||||
|
||||
|
||||
def test_pubkey_invalid_fingerprint():
|
||||
assert get_fingerprint("ssh-rsa fake and invalid") is None
|
||||
with pytest.raises(ValueError):
|
||||
get_fingerprint("invalid-prefix some-fake-content")
|
||||
|
|
|
@ -183,14 +183,14 @@ def test_user_has_credential(user: User):
|
|||
|
||||
|
||||
def test_user_ssh_pub_key(user: User):
|
||||
assert user.ssh_pub_key is None
|
||||
assert user.ssh_pub_keys.first() is None
|
||||
|
||||
with db.begin():
|
||||
ssh_pub_key = db.create(SSHPubKey, UserID=user.ID,
|
||||
Fingerprint="testFingerprint",
|
||||
PubKey="testPubKey")
|
||||
|
||||
assert user.ssh_pub_key == ssh_pub_key
|
||||
assert user.ssh_pub_keys.first() == ssh_pub_key
|
||||
|
||||
|
||||
def test_user_credential_types(user: User):
|
||||
|
|
|
@ -60,3 +60,53 @@ def test_valid_homepage():
|
|||
assert not util.valid_homepage("https://[google.com/broken-ipv6")
|
||||
|
||||
assert not util.valid_homepage("gopher://gopher.hprc.utoronto.ca/")
|
||||
|
||||
|
||||
def test_parse_ssh_key():
|
||||
# Test a valid key.
|
||||
pk = """ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyN\
|
||||
TYAAABBBEURnkiY6JoLyqDE8Li1XuAW+LHmkmLDMW/GL5wY7k4/A+Ta7bjA3MOKrF9j4EuUTvCuNX\
|
||||
ULxvpfSqheTFWZc+g="""
|
||||
prefix, key = util.parse_ssh_key(pk)
|
||||
e_prefix, e_key = pk.split()
|
||||
assert prefix == e_prefix
|
||||
assert key == e_key
|
||||
|
||||
# Test an invalid key with just one word in it.
|
||||
with pytest.raises(ValueError):
|
||||
util.parse_ssh_key("ssh-rsa")
|
||||
|
||||
# Test a valid key with extra words in it (after the PK).
|
||||
pk = pk + " blah blah"
|
||||
prefix, key = util.parse_ssh_key(pk)
|
||||
assert prefix == e_prefix
|
||||
assert key == e_key
|
||||
|
||||
# Test an invalid prefix.
|
||||
with pytest.raises(ValueError):
|
||||
util.parse_ssh_key("invalid-prefix fake-content")
|
||||
|
||||
|
||||
def test_parse_ssh_keys():
|
||||
pks = """ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyN\
|
||||
TYAAABBBEURnkiY6JoLyqDE8Li1XuAW+LHmkmLDMW/GL5wY7k4/A+Ta7bjA3MOKrF9j4EuUTvCuNX\
|
||||
ULxvpfSqheTFWZc+g=
|
||||
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDmqEapFMh/ajPHnm1dBweYPeLOUjC0Ydp6uw7rB\
|
||||
S5KCggUVQR8WfIm+sRYTj2+smGsK6zHMBjFnbzvV11vnMqcnY+Sa4LhIAdwkbt/b8HlGaLj1hCWSh\
|
||||
a5b5/noeK7L+CECGHdvfJhpxBbhq38YEdFnCGbslk/4NriNcUp/DO81CXb1RzJ9GBFH8ivPW1mbe9\
|
||||
YbxDwGimZZslg0OZu9UzoAT6xEGyiZsqJkTMbRp1ZYIOv9jHCJxRuxxuN3fzxyT3xE69+vhq2/NJX\
|
||||
8aRsxGPL9G/XKcaYGS6y6LW4quIBCz/XsTZfx1GmkQeZPYHH8FeE+XC/+toXL/kamxdOQKFYEEpWK\
|
||||
vTNJCD6JtMClxbIXW9q74nNqG+2SD/VQNMUz/505TK1PbY/4uyFfq5HquHJXQVCBll03FRerNHH2N\
|
||||
schFne6BFHpa48PCoZNH45wLjFXwUyrGU1HrNqh6ZPdRfBTrTOkgs+BKBxGNeV45aYUPu/cFBSPcB\
|
||||
fRSo6OFcejKc="""
|
||||
keys = util.parse_ssh_keys(pks)
|
||||
assert len(keys) == 2
|
||||
|
||||
pfx1, key1, pfx2, key2 = pks.split()
|
||||
k1, k2 = keys
|
||||
|
||||
assert pfx1 == k1[0]
|
||||
assert key1 == k1[1]
|
||||
|
||||
assert pfx2 == k2[0]
|
||||
assert key2 == k2[1]
|
||||
|
|
Loading…
Add table
Reference in a new issue