My problem was this: I had a number of databases generated on different machines, and I wanted to query them as if they were one, using the database each row came from as a field, both when filtering and when showing results. The databases are SQLite3 files, generated using SQLAlchemy in a Python program.

I solved this by using SQLAlchemy, which was good because I could reuse the same ORM mapping that the program used. I noticed that SQLAlchemy's Horizontal Sharding extension fit the problem well, although not perfectly. I had to make some changes to some classes of this extension, and now it works fine.

It was possible to filter the data using the database as a criterion, but I couldn't get the database information for each row of a query result. I made a simple patch to SQLAlchemy, which was unlikely to be accepted into the distribution but worked for me, and sent it to its bug tracker. As expected, the change was included in SQLAlchemy in a very different form. Since I'm using a released version of SQLAlchemy, I kept using my version of the patch. I don't want to make direct changes to the SQLAlchemy source code, so I made the change in my program:

class ShardedSessionShardId(ShardedSession):
    """ShardedSession whose queries are built with ShardedQueryShardId."""

    def __init__(self, *positional, **options):
        # Run the normal ShardedSession setup first, then install the
        # shard-tagging query class so it is the one used by query().
        super(ShardedSessionShardId, self).__init__(*positional, **options)
        self._query_cls = ShardedQueryShardId

class ShardedQueryShardId(ShardedQuery):
    """ShardedQuery that records, on each loaded object, the shard it came from.

    Every object returned by the query gets a ``shard_id`` attribute holding
    the id of the shard whose connection produced it. Results merged from
    several databases can therefore still be told apart.
    """

    def _execute_and_instances(self, context):
        """Execute the query on the selected shard(s); return an iterator.

        A shard fixed on the query (``self._shard_id``) takes precedence.
        Otherwise ``self.query_chooser(self)`` decides which shards to
        search. Results are concatenated in shard order.
        """
        if self._shard_id is not None:
            shard_ids = [self._shard_id]
        else:
            shard_ids = self.query_chooser(self)

        partial = []
        for shard_id in shard_ids:
            # extend() instead of ``partial = partial + news`` avoids
            # re-copying every previously loaded row for each shard.
            partial.extend(self._load_from_shard(context, shard_id))

        # if some kind of in memory 'sorting'
        # were done, this is where it would happen
        return iter(partial)

    def _load_from_shard(self, context, shard_id):
        """Run the compiled statement on one shard; return its tagged rows."""
        result = self.session.connection(
            mapper=self._mapper_zero(),
            shard_id=shard_id).execute(context.statement, self._params)

        # instances() is consumed eagerly so that every object can be
        # tagged before it is handed back to the caller.
        loaded = list(self.instances(result, context))
        for obj in loaded:
            self._tag_with_shard(obj, shard_id)
        return loaded

    @staticmethod
    def _tag_with_shard(obj, shard_id):
        """Best-effort: set ``obj.shard_id``; skip objects that refuse it."""
        try:
            obj.shard_id = shard_id
        except AttributeError:
            # Column-only or multi-entity queries yield plain values or
            # tuples, which cannot carry an extra attribute. Return them
            # untagged instead of aborting the whole query.
            pass

# Session factory: every session it creates is a ShardedSessionShardId.
create_session = sessionmaker(
    class_=ShardedSessionShardId,
)

Another problem was that I needed every row to be included in the query result, even when two rows from different databases have the same primary key. I achieved this by changing two classes: WeakInstanceDict and Mapper. To use the new WeakInstanceDict, I again had to change my ShardedSession subclass:

class WeakInstanceDictNoIdentity(WeakInstanceDict):
    """WeakInstanceDict whose add() tolerates identity-key conflicts.

    The stock add() raises AssertionError when a different state is already
    stored under the same key, and does nothing when the same state is
    already present. This variant skips both checks. The incoming state
    always becomes the entry for its key, replacing any previous one, and is
    then registered with the map.
    """

    def add(self, state):
        key = state.key
        # Bypass any overridden __setitem__ and write straight into the dict.
        dict.__setitem__(self, key, state)
        self._manage_incoming_state(state)

class ShardedSessionShardId(ShardedSession):
    """ShardedSession that tags query results with their shard id and uses
    an identity map that tolerates identity-key conflicts.
    """

    def __init__(self, *positional, **options):
        super(ShardedSessionShardId, self).__init__(*positional, **options)
        self._query_cls = ShardedQueryShardId
        # Store the class as well as a fresh instance, so that any later
        # rebuild of the identity map that goes through self._identity_cls
        # (presumably on close/expunge_all -- confirm) keeps this variant.
        identity_cls = WeakInstanceDictNoIdentity
        self._identity_cls = identity_cls
        self.identity_map = identity_cls()

To start using the new Mapper, I simply replaced each call to `mapper` with a call to `MapperNoIdentity`:

class MapperNoIdentity(Mapper):
    """Mapper whose row processor does not reuse instances from the
    session identity map.

    The identity-map lookup that would normally happen inside
    ``_instance_processor`` is commented out below. As a result, every row
    loaded through this mapper (other than during a refresh) becomes a
    brand-new instance, even when an instance with the same identity key
    was already loaded. This lets rows from different shards that share a
    primary key all show up in a query result.

    NOTE(review): the method appears to be a modified copy of SQLAlchemy's
    private ``Mapper._instance_processor``. It depends on names such as
    ``util``, ``_none_set``, ``EXT_CONTINUE`` and ``attributes`` being in
    scope, and presumably must be kept in sync with the installed
    SQLAlchemy version -- confirm when upgrading.
    """
    def _instance_processor(self, context, path, adapter,
                                polymorphic_from=None, extension=None,
                                only_load_props=None, refresh_state=None,
                                polymorphic_discriminator=None):

        """Produce a mapper level row processor callable
           which processes rows into mapped instances.

           The returned ``_instance(row, result)`` callable handles each row
           in one of four ways:

           - delegates polymorphic rows to a subclass mapper's processor;
           - returns None for rows whose primary key is considered absent;
           - otherwise either reuses ``refresh_state`` or builds a new
             instance and adds its state to ``context.session.identity_map``;
           - then populates the instance, appends it to ``result`` (when
             ``result`` is not None and no extension vetoes it) and
             returns it.
           """

        pk_cols = self.primary_key

        if polymorphic_from or refresh_state:
            polymorphic_on = None
        else:
            if polymorphic_discriminator is not None:
                polymorphic_on = polymorphic_discriminator
            else:
                polymorphic_on = self.polymorphic_on
            # Lazily builds (and caches) one row processor per
            # discriminator value.
            polymorphic_instances = util.PopulateDict(
                                        self._configure_subclass_mapper(
                                                context, path, adapter)
                                        )

        # version_id_col is now only referenced by the version check that is
        # commented out inside _instance below; it is still adapted here.
        version_id_col = self.version_id_col

        if adapter:
            # Translate columns to their aliased/adapted equivalents so the
            # row lookups below hit the right result columns.
            pk_cols = [adapter.columns[c] for c in pk_cols]
            if polymorphic_on is not None:
                polymorphic_on = adapter.columns[polymorphic_on]
            if version_id_col is not None:
                version_id_col = adapter.columns[version_id_col]

        identity_class = self._identity_class
        # Identity key = (identity class, tuple of primary-key values in row).
        def identity_key(row):
            return identity_class, tuple([row[column] for column in pk_cols])

        # Filled on first use (from the first row seen) and then shared by
        # every later call through the closure.
        new_populators = []
        existing_populators = []
        load_path = context.query._current_path + path

        def populate_state(state, dict_, row, isnew, only_load_props):
            # Copy attribute values from ``row`` into ``dict_`` using the
            # populators; only_load_props, when given, restricts the keys.
            if isnew:
                if context.propagate_options:
                    state.load_options = context.propagate_options
                if state.load_options:
                    state.load_path = load_path

            if not new_populators:
                new_populators[:], existing_populators[:] = \
                                    self._populators(context, path, row,
                                                        adapter)

            if isnew:
                populators = new_populators
            else:
                populators = existing_populators

            if only_load_props:
                populators = [p for p in populators
                                if p[0] in only_load_props]

            for key, populator in populators:
                populator(state, dict_, row)

        # New instances are added to this map in _instance; with the lookup
        # disabled, it is never read from here.
        session_identity_map = context.session.identity_map

        if not extension:
            extension = self.extension

        translate_row = extension.get('translate_row', None)
        create_instance = extension.get('create_instance', None)
        populate_instance = extension.get('populate_instance', None)
        append_result = extension.get('append_result', None)
        populate_existing = context.populate_existing or self.always_refresh
        # Decides when a row's primary-key tuple means "no entity":
        # with allow_partial_pks only an all-NULL key is rejected,
        # otherwise a single NULL value is enough to reject the row.
        if self.allow_partial_pks:
            is_not_primary_key = _none_set.issuperset
        else:
            is_not_primary_key = _none_set.issubset

        def _instance(row, result):
            if translate_row:
                ret = translate_row(self, context, row)
                if ret is not EXT_CONTINUE:
                    row = ret

            if polymorphic_on is not None:
                discriminator = row[polymorphic_on]
                if discriminator is not None:
                    _instance = polymorphic_instances[discriminator]
                    if _instance:
                        return _instance(row, result)

            # determine identity key
            if refresh_state:
                identitykey = refresh_state.key
                if identitykey is None:
                    # super-rare condition; a refresh is being called
                    # on a non-instance-key instance; this is meant to only
                    # occur within a flush()
                    identitykey = self._identity_key_from_state(refresh_state)
            else:
                identitykey = identity_key(row)

            # Identity-map lookup intentionally disabled: the original code
            # below would reuse an instance already present under
            # ``identitykey``. With it commented out, every row gets a new
            # instance, so same-key rows from different shards stay distinct.
            # instance = session_identity_map.get(identitykey)
            # if instance is not None:
            #     state = attributes.instance_state(instance)
            #     dict_ = attributes.instance_dict(instance)

            #     isnew = state.runid != context.runid
            #     currentload = not isnew
            #     loaded_instance = False

            #     if not currentload and \
            #             version_id_col is not None and \
            #             context.version_check and \
            #             self._get_state_attr_by_column(
            #                     state,
            #                     dict_,
            #                     self.version_id_col) != \
            #                             row[version_id_col]:

            #         raise orm_exc.ConcurrentModificationError(
            #                 "Instance '%s' version of %s does not match %s"
            #                 % (state_str(state),
            #                     self._get_state_attr_by_column(
            #                                 state, dict_,
            #                                 self.version_id_col),
            #                         row[version_id_col]))
            # elif refresh_state:
            # (was ``elif refresh_state:`` while the lookup above was active)
            if refresh_state:
                # out of band refresh_state detected (i.e. its not in the
                # session.identity_map) honor it anyway.  this can happen
                # if a _get() occurs within save_obj(), such as
                # when eager_defaults is True.
                state = refresh_state
                instance = state.obj()
                dict_ = attributes.instance_dict(instance)
                isnew = state.runid != context.runid
                currentload = True
                loaded_instance = False
            else:
                # check for non-NULL values in the primary key columns,
                # else no entity is returned for the row
                if is_not_primary_key(identitykey[1]):
                    return None

                # Every non-refresh row is treated as a fresh load.
                isnew = True
                currentload = True
                loaded_instance = True

                if create_instance:
                    instance = create_instance(self,
                                                context,
                                                row, self.class_)
                    if instance is EXT_CONTINUE:
                        instance = self.class_manager.new_instance()
                    else:
                        manager = attributes.manager_of_class(
                                                instance.__class__)
                        # TODO: if manager is None, raise a friendly error
                        # about returning instances of unmapped types
                        manager.setup_instance(instance)
                else:
                    instance = self.class_manager.new_instance()

                dict_ = attributes.instance_dict(instance)
                state = attributes.instance_state(instance)
                state.key = identitykey

                # manually adding instance to session.  for a complete add,
                # session._finalize_loaded() must be called.
                # NOTE(review): two shards can yield the same identitykey, so
                # the session's identity map presumably has to accept a
                # second, different state under an existing key -- confirm
                # the session is configured with such a map.
                state.session_id = context.session.hash_key
                session_identity_map.add(state)

            # Always true here: both branches above set currentload = True.
            if currentload or populate_existing:
                if isnew:
                    state.runid = context.runid
                    context.progress[state] = dict_

                if not populate_instance or \
                        populate_instance(self, context, row, instance,
                            only_load_props=only_load_props,
                            instancekey=identitykey, isnew=isnew) is \
                            EXT_CONTINUE:
                    populate_state(state, dict_, row, isnew, only_load_props)

            else:
                # NOTE(review): unreachable as written, since currentload is
                # always True above; kept from the original implementation.
                # populate attributes on non-loading instances which have
                # been expired
                # TODO: apply eager loads to un-lazy loaded collections ?
                if state in context.partials or state.unloaded:

                    if state in context.partials:
                        isnew = False
                        (d_, attrs) = context.partials[state]
                    else:
                        isnew = True
                        attrs = state.unloaded
                        # allow query.instances to commit the subset of attrs
                        context.partials[state] = (dict_, attrs)

                    if not populate_instance or \
                            populate_instance(self, context, row, instance,
                                only_load_props=attrs,
                                instancekey=identitykey, isnew=isnew) is \
                                EXT_CONTINUE:
                        populate_state(state, dict_, row, isnew, attrs)

            # Only for instances created above (False on the refresh path).
            if loaded_instance:
                state._run_on_load(instance)

            if result is not None and \
                        (not append_result or
                            append_result(self, context, row, instance,
                                    result, instancekey=identitykey,
                                    isnew=isnew)
                                    is EXT_CONTINUE):
                result.append(instance)

            return instance
        return _instance

I also had to include some auxiliary definitions to make the rewritten code work:

# Set containing only None; MapperNoIdentity compares primary-key tuples
# against it to detect rows whose key is (partly or fully) NULL.
_none_set = frozenset([None])

# Shared counter handed out by _new_runid(). A plain int literal is used
# instead of the Python 2-only ``1L`` (a SyntaxError on Python 3). On
# Python 2 an int is promoted to long automatically on overflow, so the
# counter behaves the same.
_runid = 1
_id_lock = util.threading.Lock()


def _new_runid():
    """Increment the shared run-id counter under a lock and return it."""
    global _runid
    _id_lock.acquire()
    try:
        _runid += 1
        return _runid
    finally:
        _id_lock.release()

It would be nice to be able to configure these identity requirements through a parameter.

My last problem was selecting more than one database to search. `set_shard` only works with a single shard, so I added a new attribute to the query, called `shards`, and checked for it in `query_chooser`:

def query_chooser(query):
    """Return the shard ids that ``query`` should be run against.

    Callers may restrict the search by setting a ``shards`` attribute on the
    query. When it is present and not None, it is returned as-is (an empty
    list therefore searches nothing). Otherwise every key of the
    module-level ``tcs`` mapping is used.
    """
    # getattr with a default replaces the try/except AttributeError/pass
    # probe. Treating ``shards = None`` as "all shards" avoids handing None
    # to a caller that iterates over the result.
    shards = getattr(query, 'shards', None)
    if shards is None:
        return tcs.keys()
    return shards

So, when I want to search only a given list of shards, I set this attribute. I'm aware that this is not a recommended Python idiom, but, well, it works fine.

I did it this way because I didn't know any other way to do it. I'll try it the way you suggested next time. Thanks.
Comment by marcot Ter 01 Mar 2011 11:01:39 UTC
comment 1 c840f348b20e9184f568a22aeff9f90a
[[!comment Error: unsupported page format html]]
Qui 02 Fev 2012 16:28:18 UTC