Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit

Permalink
Make SDK models pickleable (#8746)
Browse files Browse the repository at this point in the history
Currently, they aren't (or rather, they can be pickled, but unpickling
fails). This is due to a small quirk of how the model classes work, and
is easily worked around.

In addition, don't create a new Configuration object for each model.
These objects are pretty beefy, and they increase the size of each
pickle by a full kilobyte (and of course they increase memory usage even
when pickle is not involved). AFAICS, these objects are only used when
assigning values to file-type fields, and it's easy enough to rewrite
the logic so that it still works when the model's `_configuration` field
is None.
  • Loading branch information
SpecLad authored Nov 28, 2024
1 parent 9091be8 commit 86deaff
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 30 deletions.
4 changes: 4 additions & 0 deletions changelog.d/20241127_132256_roman_pickle_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- \[SDK\] Model instances can now be pickled
(<https://github.com/cvat-ai/cvat/pull/8746>)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
{{name}} ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}} [optional]{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}} # noqa: E501
{{/optionalVars}}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{#defaultValue}}
Expand All @@ -32,7 +31,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

self = super(OpenApiModel, cls).__new__(cls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{#defaultValue}}
Expand All @@ -37,7 +36,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', True)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

self = super(OpenApiModel, cls).__new__(cls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
Expand All @@ -39,7 +38,7 @@

_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{^isReadOnly}}
Expand All @@ -42,7 +41,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
Expand All @@ -45,7 +44,7 @@

_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
41 changes: 21 additions & 20 deletions cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ class OpenApiModel(object):
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
return new_inst

def __setstate__(self, state):
# This is the same as the default implementation. We override it,
# because unpickling attempts to access `obj.__setstate__` on an uninitialized
# object, and if this method is not defined, it results in a call to `__getattr__`.
# This fails, because `__getattr__` relies on `self._data_store`, which doesn't
# exist in an uninitialized object.
self.__dict__.update(state)

class ModelSimple(OpenApiModel):
"""the parent class of models whose type != object in their
Expand Down Expand Up @@ -1084,7 +1091,7 @@ def deserialize_file(response_data, configuration, content_disposition=None):
(file_type): the deserialized file which is open
The user is responsible for closing and reading the file
"""
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path)
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path if configuration else None)
os.close(fd)
os.remove(path)

Expand Down Expand Up @@ -1263,27 +1270,21 @@ def validate_and_convert_types(input_value, required_types_mixed, path_to_item,
input_class_simple = get_simple_class(input_value)
valid_type = is_valid_type(input_class_simple, valid_classes)
if not valid_type:
if (configuration
or (input_class_simple == dict
and dict not in valid_classes)):
# if input_value is not valid_type try to convert it
converted_instance = attempt_convert_item(
input_value,
valid_classes,
path_to_item,
configuration,
spec_property_naming,
key_type=False,
must_convert=True,
check_type=_check_type
)
return converted_instance
else:
raise get_type_error(input_value, path_to_item, valid_classes,
key_type=False)
# if input_value is not valid_type try to convert it
converted_instance = attempt_convert_item(
input_value,
valid_classes,
path_to_item,
configuration,
spec_property_naming,
key_type=False,
must_convert=True,
check_type=_check_type
)
return converted_instance

# input_value's type is in valid_classes
if len(valid_classes) > 1 and configuration:
if len(valid_classes) > 1:
# there are valid classes which are not the current class
valid_classes_coercible = remove_uncoercible(
valid_classes, input_value, spec_property_naming, must_convert=False)
Expand Down
10 changes: 10 additions & 0 deletions tests/python/sdk/test_api_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT

import pickle
from copy import deepcopy

from cvat_sdk import models
Expand Down Expand Up @@ -112,3 +113,12 @@ def test_models_do_not_return_internal_collections():
model_data2 = model.to_dict()

assert DeepDiff(model_data1_original, model_data2) == {}


def test_models_are_pickleable():
model = models.PatchedLabelRequest(id=5, name="person")
pickled_model = pickle.dumps(model)
unpickled_model = pickle.loads(pickled_model)

assert unpickled_model.id == model.id
assert unpickled_model.name == model.name

0 comments on commit 86deaff

Please sign in to comment.