git.gsnw.org Git - raspbmirror.git/commitdiff
Merge github f4e2a90 (add urllib3 support --ipv4 and --ipv6 flags)
author German Service Network <support@gsnw.de>
Sun, 4 Sep 2022 11:56:50 +0000 (13:56 +0200)
committer German Service Network <support@gsnw.de>
Sun, 4 Sep 2022 11:56:50 +0000 (13:56 +0200)
raspbmirror.py

index 3c516224a6a064c2c7134f22756f3b24dff91e5c..cf0b56df6d5597fdb043f864513345da12ff74ed 100644 (file)
@@ -7,7 +7,7 @@ import os
 import sys
 import hashlib
 import gzip
-import urllib.request
+
 import stat
 from collections import deque
 from collections import OrderedDict
@@ -45,6 +45,14 @@ parser.add_argument("--nolock", help="don't try to lock the target directory", a
 
 parser.add_argument("--repair", help="during mirroring, verify that all on-disk files match the expected sha256", action="store_true")
 
+parser.add_argument("--urllib", help="force usage of the builtin urllib module, even if urllib3 is present", action="store_true")
+
+parser.add_argument("--urllib3", help="force usage of the urllib3 module, panics if the dependency is missing", action="store_true")
+
+parser.add_argument("--ipv4", help="force usage of IPv4 addresses. Requires urllib3", action="store_true")
+
+parser.add_argument("--ipv6", help="force usage of IPv6 addresses. Requires urllib3", action="store_true")
+
 args = parser.parse_args()
 
 if not args.nolock:
@@ -55,6 +63,58 @@ dtNow = datetime.now()
 logpath = os.path.dirname(os.path.realpath(__file__))
 logging.basicConfig(filename=logpath+'/'+dtNow.strftime("%Y-%m-%d")+'_raspbmirror.log',format='%(asctime)s %(levelname)s: %(message)s', level=logging.DEBUG)
 
+if args.urllib and args.urllib3:
+       logging.error('error: flags --urllib and --urllib3 are in conflict')
+       exit(1)
+
+if args.urllib:
+       import urllib.request
+       use_urllib3 = False
+elif args.urllib3:
+       import urllib3
+       use_urllib3 = True
+else:
+       # auto detect urllib3
+       try:
+               import urllib3
+               use_urllib3 = True
+       except:
+               import urllib.request
+               use_urllib3 = False
+
+if args.ipv4 and args.ipv6:
+       logging.error('error: flags --ipv4 and --ipv6 are in conflict')
+       exit(1)
+
+if use_urllib3:
+       # the number of pools should be greater than the number of concurrently used sites.
+       # 10 should be safe.
+       dlmanager = urllib3.PoolManager(num_pools=10)
+       logging.info('info: using urllib3')
+
+       # a fairly hacky way to force the usage of ipv4 or ipv6 addresses
+       # https://stackoverflow.com/questions/33046733/force-requests-to-use-ipv4-ipv6
+       if args.ipv4:
+               import socket
+               import requests.packages.urllib3.util.connection as urllib3_cn
+               def allowed_gai_family():
+                       return socket.AF_INET
+               urllib3_cn.allowed_gai_family = allowed_gai_family
+       elif args.ipv6:
+               import socket
+               import requests.packages.urllib3.util.connection as urllib3_cn
+               def allowed_gai_family():
+                       return socket.AF_INET6
+               urllib3_cn.allowed_gai_family = allowed_gai_family
+else:
+       logging.info('info: using urllib')
+       if args.ipv4:
+               logging.error('error: flag --ipv4 requires the urllib3 package')
+               exit(1)
+       elif args.ipv6:
+               logging.error('error: flag --ipv6 requires the urllib3 package')
+               exit(1)
+
 def addfilefromdebarchive(filestoverify,filequeue,filename,sha256,size):
        size = int(size)
        sha256andsize = [sha256,size,'M']
@@ -111,11 +171,15 @@ def ensuresafepath(path):
                        sys.exit(1)
 
 def geturl(fileurl):
-       with urllib.request.urlopen(fileurl.decode('ascii')) as response:
-               data = response.read()
+       if use_urllib3:
+               response = dlmanager.request("GET", fileurl.decode('ascii'))
                ts = getts(fileurl, response)
-       return (data,ts)
-
+               return (response.data, ts)
+       else:
+               with urllib.request.urlopen(fileurl.decode('ascii')) as response:
+                       data = response.read()
+                       ts = getts(fileurl, response)
+               return (data, ts)
 
 def getts(fileurl, response):
        if fileurl[:7] == b'file://':
@@ -285,7 +349,6 @@ bs = 16 * 1024 * 1024
 def getandcheckfile(fileurl, sha256, size, path, outputpath, errorfromstr, errorsuffix):
        f = None
        try:
-
                sha256hash = hashlib.sha256()
                if path == outputpath:
                        writepath = makenewpath(path)
@@ -298,18 +361,27 @@ def getandcheckfile(fileurl, sha256, size, path, outputpath, errorfromstr, error
                                'ascii') + ' to ' + outputpath.decode(
                                'ascii') + viamsg)
                f = open(writepath, 'wb')
-               with urllib.request.urlopen(fileurl.decode('ascii')) as response:
-                       l = bs
+               if use_urllib3:
+                       response = dlmanager.request("GET", fileurl.decode('ascii'), preload_content=False)
+                       ts = getts(fileurl, response)
                        tl = 0
-                       while l == bs:
-                               data = response.read(bs)
+                       for data in response.stream(bs):
+                               tl += len(data)
                                f.write(data)
-                               l = len(data)
-                               tl += l
                                sha256hash.update(data)
-                       ts = getts(fileurl, response)
-
-                       data = ... #used as a flag to indicate that the data is written to disk rather than stored in memory
+                       response.release_conn()
+               else:
+                       with urllib.request.urlopen(fileurl.decode('ascii')) as response:
+                               l = bs
+                               tl = 0
+                               while l == bs:
+                                       data = response.read(bs)
+                                       f.write(data)
+                                       l = len(data)
+                                       tl += l
+                                       sha256hash.update(data)
+                               ts = getts(fileurl, response)
+               data = ... #used as a flag to indicate that the data is written to disk rather than stored in memory
                f.close()
                if not testandreporthash(sha256hash, sha256, 'hash mismatch while downloading file' + errorfromstr + ' ', path,
                                                         errorsuffix):