f8b3b5c19517 draft — Jelmer Vernooij 3 years ago
Add some functions for dealing with bundles.
5 files changed, 125 insertions(+), 10 deletions(-)

M NEWS
M dulwich/bundle.py
M dulwich/pack.py
M dulwich/tests/__init__.py
A => dulwich/tests/test_bundle.py
M NEWS +5 -0
@@ 1,3 1,8 @@ 
+0.20.15	UNRELEASED
+
+ * Add some functions for parsing and writing bundles.
+   (Jelmer Vernooij)
+
 0.20.14	2020-11-26
 
  * Fix some stash functions on Python 3. (Peter Rowlands)

          
M dulwich/bundle.py +55 -10
@@ 21,18 21,33 @@ 
 """Bundle format support.
 """
 
-from .pack import PackData
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Union, Sequence
+from .pack import PackData, write_pack_data
 
 
 class Bundle(object):
 
-    version: int
+    version: Optional[int] = None
+
+    capabilities: Dict[str, str] = {}
+    prerequisites: List[Tuple[bytes, str]] = []
+    references: Dict[str, bytes] = {}
+    pack_data: Union[PackData, Sequence[bytes]] = []
 
-    capabilities: Dict[str, str]
-    prerequisites: List[Tuple[bytes, str]]
-    references: Dict[str, bytes]
-    pack_data: PackData
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return False
+        if self.version != other.version:
+            return False
+        if self.capabilities != other.capabilities:
+            return False
+        if self.prerequisites != other.prerequisites:
+            return False
+        if self.references != other.references:
+            return False
+        if self.pack_data != other.pack_data:
+            return False
+        return True
 
 
 def _read_bundle(f, version):

          
@@ 45,13 60,15 @@ def _read_bundle(f, version):
             line = line[1:].rstrip(b'\n')
             try:
                 key, value = line.split(b'=', 1)
-            except IndexError:
+            except ValueError:
                 key = line
                 value = None
-            capabilities[key] = value
+            else:
+                value = value.decode('utf-8')
+            capabilities[key.decode('utf-8')] = value
             line = f.readline()
     while line.startswith(b'-'):
-        (obj_id, comment) = line[1:].split(b' ', 1)
+        (obj_id, comment) = line[1:].rstrip(b'\n').split(b' ', 1)
         prerequisites.append((obj_id, comment.decode('utf-8')))
         line = f.readline()
     while line != b'\n':

          
@@ 64,6 81,7 @@ def _read_bundle(f, version):
     ret.capabilities = capabilities
     ret.prerequisites = prerequisites
     ret.pack_data = pack_data
+    ret.version = version
     return ret
 
 

          
@@ 76,3 94,30 @@ def read_bundle(f):
         return _read_bundle(f, 3)
     raise AssertionError(
         'unsupported bundle format header: %r' % firstline)
+
+
+def write_bundle(f, bundle):
+    version = bundle.version
+    if version is None:
+        if bundle.capabilities:
+            version = 3
+        else:
+            version = 2
+    if version == 2:
+        f.write(b'# v2 git bundle\n')
+    elif version == 3:
+        f.write(b'# v3 git bundle\n')
+    else:
+        raise AssertionError('unknown version %d' % version)
+    if version == 3:
+        for key, value in bundle.capabilities.items():
+            f.write(b'@' + key.encode('utf-8'))
+            if value is not None:
+                f.write(b'=' + value.encode('utf-8'))
+            f.write(b'\n')
+    for (obj_id, comment) in bundle.prerequisites:
+        f.write(b'-%s %s\n' % (obj_id, comment.encode('utf-8')))
+    for ref, obj_id in bundle.references.items():
+        f.write(b'%s %s\n' % (obj_id, ref))
+    f.write(b'\n')
+    write_pack_data(f, len(bundle.pack_data), iter(bundle.pack_data))

          
M dulwich/pack.py +12 -0
@@ 1072,6 1072,18 @@ class PackData(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
+    def __eq__(self, other):
+        if isinstance(other, PackData):
+            return self.get_stored_checksum() == other.get_stored_checksum()
+        if isinstance(other, list):
+            if len(self) != len(other):
+                return False
+            for o1, o2 in zip(self.iterobjects(), other):
+                if o1 != o2:
+                    return False
+            return True
+        return False
+
     def _get_size(self):
         if self._size is not None:
             return self._size

          
M dulwich/tests/__init__.py +1 -0
@@ 102,6 102,7 @@ def self_test_suite():
     names = [
         'archive',
         'blackbox',
+        'bundle',
         'client',
         'config',
         'diff_tree',

          
A => dulwich/tests/test_bundle.py +52 -0
@@ 0,0 1,52 @@ 
+# test_bundle.py -- tests for bundle
+# Copyright (C) 2020 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Tests for bundle support."""
+
+import os
+import tempfile
+
+from dulwich.tests import (
+    TestCase,
+    )
+
+from dulwich.bundle import (
+    Bundle,
+    read_bundle,
+    write_bundle,
+    )
+
+
+class BundleTests(TestCase):
+
+    def test_roundtrip_bundle(self):
+        origbundle = Bundle()
+        origbundle.version = 3
+        origbundle.capabilities = {'foo': None}
+        origbundle.references = {b'refs/heads/master': b'ab' * 20}
+        origbundle.prerequisites = [(b'cc' * 20, 'comment')]
+        with tempfile.TemporaryDirectory() as td:
+            with open(os.path.join(td, 'foo'), 'wb') as f:
+                write_bundle(f, origbundle)
+
+            with open(os.path.join(td, 'foo'), 'rb') as f:
+                newbundle = read_bundle(f)
+
+                self.assertEqual(origbundle, newbundle)