diff --git a/aurweb/db.py b/aurweb/db.py index b8b49e40..70ad58d1 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -24,42 +24,15 @@ DRIVERS = { "mysql": "mysql+mysqldb" } -# Global introspected object memo. -introspected = dict() - # Some types we don't get access to in this module. Base = NewType("Base", "aurweb.models.declarative_base.Base") -def make_random_value(table: str, column: str): +def make_random_value(table: str, column: str, length: int): """ Generate a unique, random value for a string column in a table. - This can be used to generate for example, session IDs that - align with the properties of the database column with regards - to size. - - Internally, we use SQLAlchemy introspection to look at column - to decide which length to use for random string generation. - :return: A unique string that is not in the database """ - global introspected - - # Make sure column is converted to a string for memo interaction. - scolumn = str(column) - - # If the target column is not yet introspected, store its introspection - # object into our global `introspected` memo. - if scolumn not in introspected: - from sqlalchemy import inspect - target_column = scolumn.split('.')[-1] - col = list(filter(lambda c: c.name == target_column, - inspect(table).columns))[0] - introspected[scolumn] = col - - col = introspected.get(scolumn) - length = col.type.length - string = aurweb.util.make_random_string(length) while query(table).filter(column == string).first(): string = aurweb.util.make_random_string(length) diff --git a/aurweb/models/session.py b/aurweb/models/session.py index 96f88d85..7a06eddc 100644 --- a/aurweb/models/session.py +++ b/aurweb/models/session.py @@ -1,8 +1,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import backref, relationship -from aurweb import schema -from aurweb.db import make_random_value, query +from aurweb import db, schema from aurweb.models.declarative import Base from aurweb.models.user import User as _User @@ -19,8 +18,8 @@ class Session(Base): def __init__(self, **kwargs): super().__init__(**kwargs) - user_exists = query( - query(_User).filter(_User.ID == self.UsersID).exists() + user_exists = db.query( + db.query(_User).filter(_User.ID == self.UsersID).exists() ).scalar() if not user_exists: raise IntegrityError( @@ -31,4 +30,4 @@ class Session(Base): def generate_unique_sid(): - return make_random_value(Session, Session.SessionID) + return db.make_random_value(Session, Session.SessionID, 32) diff --git a/aurweb/models/user.py b/aurweb/models/user.py index 43910db9..03634a36 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -230,3 +230,7 @@ class User(Base): def __repr__(self): return "" % ( self.ID, str(self.AccountType), self.Username) + + +def generate_unique_resetkey(): + return db.make_random_value(User, User.ResetKey, 32) diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 02a7f4c6..ddee1764 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -16,6 +16,7 @@ from aurweb.exceptions import ValidationError from aurweb.l10n import get_translator_for_request from aurweb.models import account_type as at from aurweb.models.ssh_pub_key import get_fingerprint +from aurweb.models.user import generate_unique_resetkey from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification from aurweb.templates import make_context, make_variable_context, render_template from aurweb.users import update, validate @@ -92,7 +93,7 @@ async def passreset_post(request: Request, status_code=HTTPStatus.SEE_OTHER) # If we got here, we continue with issuing a resetkey for the user. - resetkey = db.make_random_value(models.User, models.User.ResetKey) + resetkey = generate_unique_resetkey() with db.begin(): user.ResetKey = resetkey @@ -291,7 +292,7 @@ async def account_register_post(request: Request, # Create a user with no password with a resetkey, then send # an email off about it. - resetkey = db.make_random_value(models.User, models.User.ResetKey) + resetkey = generate_unique_resetkey() # By default, we grab the User account type to associate with. atype = db.query(models.AccountType,