mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
324 lines
8.1 KiB
Python
Executable file
324 lines
8.1 KiB
Python
Executable file
#!/usr/bin/python -O
|
|
#
|
|
# Description:
|
|
# ------------
|
|
# This is the server-side portion of the Trusted User package
|
|
# manager. This program will receive uploads from its client-side
|
|
# counterpart, tupkg. Once a package is received and verified, it
|
|
# is placed in a specified temporary incoming directory where
|
|
# a separate script will handle migrating it to the AUR. For
|
|
# more information, see the ../README.txt file.
|
|
#
|
|
# Python Indentation:
|
|
# -------------------
|
|
# For a vim: line to be effective, it must be at the end of the
|
|
# file. See the end of the file for more information.
|
|
|
|
import cgi
import ConfigParser
import getopt
import hashlib
import md5
import os
import os.path
import select
import socket
import struct
import sys
import threading
import time
import urllib

import MySQLdb
import MySQLdb.connections
|
|
|
|
# Default location of the server configuration; main() accepts -c/--config
# to point somewhere else.
CONFIGFILE = "/etc/tupkgs.conf"

# Shared parser: main() reads the config file into it, and
# ClientSocket.auth() pulls the [mysql] connection settings from it.
config = ConfigParser.ConfigParser()
|
|
|
|
class ClientFile:
	"""A package file being uploaded by a client.

	The file is created (or reopened for appending) under
	confdict['incomingdir'], so an interrupted upload of the same name is
	resumed: orig_size is the number of bytes already on disk when this
	object was created.  finishDownload() moves the file into
	confdict['cachedir']; delete() discards it.
	"""

	def __init__(self, filename, actual_size, actual_md5):
		"""Open (or create) the incoming file for *filename*.

		filename    -- name sent by the client.  Only its final path
		               component is used, so a name such as "../../x" or
		               "/etc/x" cannot place the file outside incomingdir.
		actual_size -- size in bytes the client claims for the file
		actual_md5  -- hex MD5 digest the client claims for the file

		Raises ValueError if *filename* has no usable final component.
		"""
		safename = os.path.basename(filename)
		if safename in ('', '.', '..'):
			raise ValueError("invalid package filename: %r" % (filename,))
		self.pathname = os.path.join(confdict['incomingdir'], safename)
		self.filename = safename
		# Append mode keeps any bytes from an earlier partial upload.
		self.fd = open(self.pathname, "a+b")
		self.actual_size = actual_size
		self.actual_md5 = actual_md5
		self.getSize()
		self.orig_size = self.size

	def getSize(self):
		"""Store the current file length in self.size without moving the file position."""
		pos = self.fd.tell()
		self.fd.seek(0, os.SEEK_END)
		self.size = self.fd.tell()
		self.fd.seek(pos)

	def makeMd5(self):
		"""Store the hex MD5 of the whole file in self.md5 (also refreshes self.size).

		The file position is restored afterwards.  Reading stops at EOF rather
		than at a precomputed size, so a file that shrinks meanwhile cannot
		cause an endless loop.
		"""
		digest = hashlib.md5()
		pos = self.fd.tell()
		self.getSize()
		self.fd.seek(0)
		while True:
			chunk = self.fd.read(65536)
			if not chunk:
				break
			digest.update(chunk)
		self.fd.seek(pos)
		self.md5 = digest.hexdigest()

	def finishDownload(self):
		"""Move the completed file from incomingdir into confdict['cachedir']."""
		self.fd.close()
		newpathname = os.path.join(confdict['cachedir'], self.filename)
		os.rename(self.pathname, newpathname)
		self.pathname = newpathname
		# Reopened so self.fd stays a valid handle on the moved file.
		self.fd = open(self.pathname, "a+b")

	def delete(self):
		"""Close and remove the file (used when verification fails)."""
		self.fd.close()
		os.remove(self.pathname)
|
|
|
|
class ClientSocket(threading.Thread):
	"""Thread that serves a single tupkg client connection.

	Wire format: every message is a 2-byte network-order length followed by
	that many bytes of payload.  Dict payloads are urlencoded (then
	unquoted); readMsg(1) parses an incoming payload with cgi.parse_qs.
	Session order: auth -> file metadata -> raw file data -> per-file MD5
	verdict -> one final message from the client, which is printed.
	"""

	def __init__(self, sock, **other):
		"""Wrap the accepted socket *sock*; extra keyword args go to Thread."""
		threading.Thread.__init__(self, **other)
		self.socket = sock
		self.running = 1
		self.files = []  # ClientFile objects announced by the client

	def close(self):
		"""Mark the handler as stopped (run() closes the socket itself)."""
		self.running = 0

	def reliableRead(self, size):
		"""Return exactly *size* bytes read from the socket.

		Raises RuntimeError("socket connection broken") if the peer closes
		the connection before *size* bytes have arrived.
		"""
		chunks = []
		remaining = size
		while remaining > 0:
			data = self.socket.recv(remaining)
			if not data:
				raise RuntimeError("socket connection broken")
			chunks.append(data)
			remaining -= len(data)
		return b''.join(chunks)

	def sendMsg(self, msg):
		"""Send *msg* with its 2-byte length prefix.

		A dict is urlencoded (list values expanded) and then unquoted.
		NOTE(review): unquoting undoes the percent-encoding, so values
		containing '&' or '=' would corrupt the message; kept unchanged for
		wire compatibility with the client.

		Raises ValueError if the payload is longer than 65535 bytes, which the
		16-bit length prefix cannot represent.
		"""
		if isinstance(msg, dict):
			msg = urllib.unquote(urllib.urlencode(msg, 1))
		if len(msg) > 0xFFFF:
			raise ValueError("message too long for 16-bit length prefix: %d bytes" % len(msg))
		# "!H" is an unsigned 16-bit integer in network (big-endian) order.
		# Send header and payload together so they are not split into two
		# tiny packets.
		self.socket.sendall(struct.pack("!H", len(msg)) + msg)

	def readMsg(self, format=0):
		"""Read one length-prefixed message.

		Returns cgi.parse_qs(payload) (a dict of lists) when format == 1,
		otherwise the raw payload.
		"""
		(length,) = struct.unpack("!H", self.reliableRead(2))
		data = self.reliableRead(length)
		if format == 1:
			return cgi.parse_qs(data)
		return data

	def auth(self):
		"""Check the client's username/password against the Users table.

		Sends "result=PASS" and returns 1 for an existing, non-suspended
		user whose AccountTypeID is 2 or 3.  Otherwise sends "result=FAIL"
		(or "result=SQLERR" on a database error) and returns 0.
		"""
		authdata = self.readMsg(1)
		if 'username' not in authdata or 'password' not in authdata:
			self.sendMsg("result=FAIL")
			return 0
		username = authdata['username'][0]
		# Log who is connecting, but never the password.
		print("Authentication request for user %r" % (username,))
		# Passwd is compared against the hex MD5 of the supplied password.
		encpw = hashlib.md5(authdata['password'][0]).hexdigest()

		print("Connecting to MySQL database")
		try:
			dbconn = MySQLdb.connect(host=config.get('mysql', 'host'),
				user=config.get('mysql', 'username'),
				passwd=config.get('mysql', 'password'),
				db=config.get('mysql', 'db'))
		except MySQLdb.Error:
			self.sendMsg("result=SQLERR")
			return 0
		try:
			q = dbconn.cursor()
			# Let MySQLdb quote the values instead of splicing them into the SQL text.
			q.execute("SELECT ID, Suspended, AccountTypeID FROM Users "
				"WHERE Username = %s AND Passwd = %s", (username, encpw))
			row = q.fetchone()
		except MySQLdb.Error:
			self.sendMsg("result=SQLERR")
			return 0
		finally:
			dbconn.close()

		# Reject unknown credentials, suspended accounts, and account types
		# other than 2/3 (presumably Trusted User and Developer -- confirm
		# against the account-type table).
		if row is None or row[1] != 0 or row[2] not in (2, 3):
			self.sendMsg("result=FAIL")
			return 0
		self.sendMsg("result=PASS")
		return 1

	def readFileMeta(self):
		"""Receive the package list and tell the client where to resume each file.

		Expects 'numpkgs' and, for each index i, 'name<i>', 'size<i>' and
		'md5sum<i>'.  A ClientFile is created for each package.  The reply
		echoes the client's fields except the md5sum<i> entries, with
		size<i> replaced by the bytes already stored for that file.

		Raises RuntimeError if a field is missing or a count/size is not an
		integer.
		"""
		files = self.readMsg(1)
		print(files)
		# TODO: actually do file checking, et al.
		try:
			for i in range(int(files['numpkgs'][0])):
				idx = str(i)
				self.files.append(ClientFile(files['name' + idx][0],
					int(files['size' + idx][0]), files['md5sum' + idx][0]))
		except (KeyError, ValueError) as err:
			raise RuntimeError("malformed file metadata: %r" % (err,))

		reply = dict((key, value) for key, value in files.items()
			if not key.startswith('md5sum'))
		for i, clientfile in enumerate(self.files):
			reply['size' + str(i)] = str(clientfile.orig_size)
		self.sendMsg(reply)

	def readFiles(self):
		"""Receive the remaining bytes of every file, verify MD5s, and report.

		Files whose MD5 matches are moved to the cache dir ("PASS"); the
		others are deleted ("FAIL").  The client's final message is printed.
		"""
		for clientfile in self.files:
			received = clientfile.orig_size
			while received < clientfile.actual_size:
				chunk = min(1024, clientfile.actual_size - received)
				clientfile.fd.write(self.reliableRead(chunk))
				received += chunk
			clientfile.fd.flush()

		reply = {'numpkgs': len(self.files)}
		for i, clientfile in enumerate(self.files):
			clientfile.makeMd5()
			if clientfile.actual_md5 == clientfile.md5:
				reply['md5sum' + str(i)] = "PASS"
				clientfile.finishDownload()
			else:
				reply['md5sum' + str(i)] = "FAIL"
				clientfile.delete()
		self.sendMsg(reply)
		print(self.readMsg())

	def run(self):
		"""Run one session; the client socket is always closed afterwards."""
		try:
			if not self.auth():
				return
			self.readFileMeta()
			self.readFiles()
		except RuntimeError as err:
			if str(err) == "socket connection broken":
				print("Client disconnected, cleaning up")
			else:
				print("Client session aborted: " + str(err))
		except socket.error as err:
			print("Client socket error: " + str(err))
		finally:
			self.close()
			self.socket.close()
|
|
|
|
class ServerSocket(threading.Thread):
	"""Thread that accepts TCP connections and hands each one to a ClientSocket.

	port     -- TCP port to listen on (all interfaces)
	maxqueue -- backlog passed to listen()
	Extra keyword arguments are passed to threading.Thread.
	"""

	def __init__(self, port, maxqueue, **other):
		threading.Thread.__init__(self, **other)
		self.running = 1
		self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.socket.bind(('', port))
		self.socket.listen(maxqueue)
		self.clients = []  # ClientSocket threads that may still be running

	def _clean(self, client):
		"""Return 1 if *client* is still running, else 0 (used to prune self.clients)."""
		if client.is_alive():
			return 1
		return 0

	def close(self):
		"""Ask the accept loop to stop.

		While the thread is running, the listening socket is closed by run()
		after its current select() times out, so run() never selects on a
		closed descriptor.  If the thread is not running, nothing else will
		close the socket, so it is closed here.
		"""
		self.running = 0
		if not self.is_alive():
			self.socket.close()

	def run(self):
		"""Accept clients until close() is called, then stop and join them."""
		try:
			while self.running:
				# Time out after 1 second so a close() request is noticed.
				readable = select.select([self.socket], [], [], 1)[0]
				if readable:
					clientsocket, address = self.socket.accept()
					print("New connection from " + str(address))
					client = ClientSocket(clientsocket)
					client.start()
					self.clients.append(client)
				# Forget handlers whose threads have finished.
				self.clients = [c for c in self.clients if self._clean(c)]
		finally:
			self.socket.close()
			for client in self.clients:
				client.close()
			for client in self.clients:
				client.join()
|
|
|
|
def usage(name):
	"""Print the command-line help text; *name* is the program name shown."""
	help_lines = (
		"usage: " + name + " [options]",
		"options:",
		" -c, --config Specify an alternate config file (default " + CONFIGFILE + ")",
	)
	for text in help_lines:
		print(text)
|
|
|
|
def getDefaultConfig():
	"""Return a new dict holding the built-in server settings."""
	return {
		'port': 1034,
		'cachedir': '/var/cache/tupkgs/',
		'incomingdir': '/var/cache/tupkgs/incomplete/',
		'maxqueue': 5,
	}


# Module-wide settings; main() overwrites entries with values from the
# [tupkgs] section of the config file.
confdict = getDefaultConfig()
|
|
|
|
def _applyConfigOverrides():
	"""Copy any options present in the [tupkgs] config section into confdict."""
	if not config.has_section('tupkgs'):
		return
	overrides = (
		('port', config.getint),
		('maxqueue', config.getint),
		('cachedir', config.get),
		('incomingdir', config.get),
	)
	for option, getter in overrides:
		if config.has_option('tupkgs', option):
			confdict[option] = getter('tupkgs', option)


def _ensureDir(path):
	"""Create *path*, including missing parent directories, with mode 0755 if absent."""
	if not os.path.isdir(path):
		print("Creating " + path)
		os.makedirs(path, 0o755)


def main(argv=None):
	"""Parse options, load the config, prepare directories and run the server.

	argv defaults to sys.argv.  Returns 1 on a usage or config-file error
	and 0 after the server is shut down with Ctrl-C (KeyboardInterrupt).
	"""
	if argv is None:
		argv = sys.argv

	try:
		optlist, args = getopt.getopt(argv[1:], "c:", ["config="])
	except getopt.GetoptError:
		usage(argv[0])
		return 1

	conffile = CONFIGFILE
	for opt, value in optlist:
		if opt in ('-c', '--config'):
			conffile = value

	if not os.path.isfile(conffile):
		print("Error: cannot access config file ("+conffile+")")
		usage(argv[0])
		return 1

	config.read(conffile)

	print("Parsing config file")
	_applyConfigOverrides()

	print("Verifying "+confdict['cachedir']+" and "+confdict['incomingdir']+" exist")
	# Order matters for the defaults: incomingdir lives inside cachedir.
	_ensureDir(confdict['cachedir'])
	_ensureDir(confdict['incomingdir'])

	print("Starting ServerSocket")
	servsock = ServerSocket(confdict['port'], confdict['maxqueue'])
	servsock.start()

	try:
		while True:
			# Maybe do stuff here?
			time.sleep(10)
	except KeyboardInterrupt:
		pass

	print("Waiting for threads to die")
	servsock.close()
	servsock.join()

	return 0
|
|
|
|
if __name__ == "__main__":
	# Use main()'s return value as the process exit status.
	raise SystemExit(main())
|
|
|
|
# Python Indentation:
|
|
# -------------------
|
|
# Use tabs not spaces. If you use vim, the following comment will
|
|
# configure it to use tabs.
|
|
#
|
|
# vim:noet:ts=2 sw=2 ft=python
|