Custom Codecs¶
This tutorial covers extending DataJoint's type system. You'll learn:
- Codec basics — Encoding and decoding
- Creating codecs — Domain-specific types
- Codec chaining — Composing codecs
In [1]:
Copied!
import datajoint as dj
import numpy as np
# Schema handle for this tutorial; the cells below use it as the @schema
# decorator on table classes and call schema.drop() at the end.
schema = dj.Schema('tutorial_codecs')
import datajoint as dj
import numpy as np
# Schema handle for this tutorial; the cells below use it as the @schema
# decorator on table classes and call schema.drop() at the end.
schema = dj.Schema('tutorial_codecs')
[2026-02-06 11:44:17] DataJoint 2.1.0 connected to datajoint@127.0.0.1:5432
Creating a Custom Codec¶
In [2]:
Copied!
import networkx as nx
class GraphCodec(dj.Codec):
    """Store NetworkX graphs as a dict of node and edge lists.

    Registered under the name ``graph``; declare attributes as ``<graph>``.
    The encoded value is a plain dict with ``nodes``, ``edges`` and
    ``directed`` entries, stored through the ``<blob>`` type.

    NOTE(review): parallel edges of multigraphs and graph-level attributes
    (``value.graph``) are not captured by this encoding.
    """

    name = "graph"  # Use as <graph>

    def get_dtype(self, is_store: bool) -> str:
        """Return the storage type for encoded graphs (always ``<blob>``)."""
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        """Convert a graph into a dict of node and edge lists.

        ``key`` and ``store_name`` are accepted but not used.
        """
        return {
            'nodes': list(value.nodes(data=True)),
            'edges': list(value.edges(data=True)),
            # nx.DiGraph subclasses nx.Graph, so directed graphs pass
            # validate(); record the flag so decode() can restore direction.
            'directed': value.is_directed(),
        }

    def decode(self, stored, *, key=None):
        """Rebuild a graph from the dict produced by ``encode``.

        Values stored without a ``directed`` entry decode as undirected
        ``nx.Graph``, matching the earlier encoding. ``key`` is not used.
        """
        directed = bool(stored.get('directed', False))
        g = nx.DiGraph() if directed else nx.Graph()
        g.add_nodes_from(stored['nodes'])
        g.add_edges_from(stored['edges'])
        return g

    def validate(self, value):
        """Raise ``TypeError`` unless ``value`` is an ``nx.Graph`` instance."""
        if not isinstance(value, nx.Graph):
            raise TypeError(f"Expected nx.Graph, got {type(value).__name__}")
import networkx as nx
class GraphCodec(dj.Codec):
    """Store NetworkX graphs as a dict of node and edge lists.

    Registered under the name ``graph``; declare attributes as ``<graph>``.
    The encoded value is a plain dict with ``nodes``, ``edges`` and
    ``directed`` entries, stored through the ``<blob>`` type.

    NOTE(review): parallel edges of multigraphs and graph-level attributes
    (``value.graph``) are not captured by this encoding.
    """

    name = "graph"  # Use as <graph>

    def get_dtype(self, is_store: bool) -> str:
        """Return the storage type for encoded graphs (always ``<blob>``)."""
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        """Convert a graph into a dict of node and edge lists.

        ``key`` and ``store_name`` are accepted but not used.
        """
        return {
            'nodes': list(value.nodes(data=True)),
            'edges': list(value.edges(data=True)),
            # nx.DiGraph subclasses nx.Graph, so directed graphs pass
            # validate(); record the flag so decode() can restore direction.
            'directed': value.is_directed(),
        }

    def decode(self, stored, *, key=None):
        """Rebuild a graph from the dict produced by ``encode``.

        Values stored without a ``directed`` entry decode as undirected
        ``nx.Graph``, matching the earlier encoding. ``key`` is not used.
        """
        directed = bool(stored.get('directed', False))
        g = nx.DiGraph() if directed else nx.Graph()
        g.add_nodes_from(stored['nodes'])
        g.add_edges_from(stored['edges'])
        return g

    def validate(self, value):
        """Raise ``TypeError`` unless ``value`` is an ``nx.Graph`` instance."""
        if not isinstance(value, nx.Graph):
            raise TypeError(f"Expected nx.Graph, got {type(value).__name__}")
In [3]:
Copied!
@schema
class Connectivity(dj.Manual):
    # Manual table keyed by the integer `conn_id`; its single dependent
    # attribute `network` is declared with the custom <graph> codec type.
    definition = """
    conn_id : int32
    ---
    network : <graph>
    """
@schema
class Connectivity(dj.Manual):
    # Manual table keyed by the integer `conn_id`; its single dependent
    # attribute `network` is declared with the custom <graph> codec type.
    definition = """
    conn_id : int32
    ---
    network : <graph>
    """
In [4]:
Copied!
# Build a triangle graph on nodes 1-3 and store it under conn_id 1
triangle_edges = [(1, 2), (2, 3), (1, 3)]
g = nx.Graph()
g.add_edges_from(triangle_edges)
Connectivity.insert1({'conn_id': 1, 'network': g})

# Read the same row back and show what came out
stored_row = Connectivity & {'conn_id': 1}
result = stored_row.fetch1('network')
print(f"Type: {type(result)}")
print(f"Edges: {list(result.edges())}")
# Build a triangle graph on nodes 1-3 and store it under conn_id 1
triangle_edges = [(1, 2), (2, 3), (1, 3)]
g = nx.Graph()
g.add_edges_from(triangle_edges)
Connectivity.insert1({'conn_id': 1, 'network': g})

# Read the same row back and show what came out
stored_row = Connectivity & {'conn_id': 1}
result = stored_row.fetch1('network')
print(f"Type: {type(result)}")
print(f"Edges: {list(result.edges())}")
Type: <class 'networkx.classes.graph.Graph'>
Edges: [(1, 2), (1, 3), (2, 3)]
Codec Structure¶
class MyCodec(dj.Codec):
    # Skeleton of a custom codec: a `name`, the three methods
    # get_dtype/encode/decode, and an optional validate hook.
    name = "mytype"  # Use as <mytype>

    def get_dtype(self, is_store: bool) -> str:
        # Names the type the encoded value is stored as.
        return "<blob>"  # Storage type

    def encode(self, value, *, key=None, store_name=None):
        # `serializable_data` is a placeholder for the storable form of `value`.
        return serializable_data

    def decode(self, stored, *, key=None):
        # `python_object` is a placeholder for the object rebuilt from `stored`.
        return python_object

    def validate(self, value):  # Optional
        pass
Example: Spike Train¶
In [5]:
Copied!
from dataclasses import dataclass
@dataclass(eq=False)
class SpikeTrain:
    """Spike times of one unit together with its id and a quality label.

    Equality is defined explicitly because the dataclass-generated ``__eq__``
    compares ``times`` with ``==``, which yields an array and raises
    ``ValueError`` when two distinct multi-element arrays are compared.
    Instances remain unhashable, as with the default dataclass behaviour.
    """

    times: np.ndarray  # spike times
    unit_id: int       # identifier of the unit
    quality: str       # quality label, e.g. 'good'

    def __eq__(self, other):
        """Return True when ids, labels and spike-time arrays all match."""
        # Same class check the generated dataclass __eq__ performs.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.unit_id == other.unit_id
            and self.quality == other.quality
            and np.array_equal(self.times, other.times)
        )
class SpikeTrainCodec(dj.Codec):
    """Codec that stores a SpikeTrain as a dict of its three fields.

    Registered under the name ``spike_train``; the encoded dict is stored
    through the ``<blob>`` type.
    """

    name = "spike_train"

    def get_dtype(self, is_store: bool) -> str:
        """Return the storage type for encoded spike trains."""
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        """Copy the SpikeTrain fields into a dict keyed by field name."""
        payload = {}
        for attr in ('times', 'unit_id', 'quality'):
            payload[attr] = getattr(value, attr)
        return payload

    def decode(self, stored, *, key=None):
        """Build a SpikeTrain from the dict produced by ``encode``."""
        times = stored['times']
        unit_id = stored['unit_id']
        quality = stored['quality']
        return SpikeTrain(times=times, unit_id=unit_id, quality=quality)
from dataclasses import dataclass
@dataclass(eq=False)
class SpikeTrain:
    """Spike times of one unit together with its id and a quality label.

    Equality is defined explicitly because the dataclass-generated ``__eq__``
    compares ``times`` with ``==``, which yields an array and raises
    ``ValueError`` when two distinct multi-element arrays are compared.
    Instances remain unhashable, as with the default dataclass behaviour.
    """

    times: np.ndarray  # spike times
    unit_id: int       # identifier of the unit
    quality: str       # quality label, e.g. 'good'

    def __eq__(self, other):
        """Return True when ids, labels and spike-time arrays all match."""
        # Same class check the generated dataclass __eq__ performs.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.unit_id == other.unit_id
            and self.quality == other.quality
            and np.array_equal(self.times, other.times)
        )
class SpikeTrainCodec(dj.Codec):
    """Codec that stores a SpikeTrain as a dict of its three fields.

    Registered under the name ``spike_train``; the encoded dict is stored
    through the ``<blob>`` type.
    """

    name = "spike_train"

    def get_dtype(self, is_store: bool) -> str:
        """Return the storage type for encoded spike trains."""
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        """Copy the SpikeTrain fields into a dict keyed by field name.

        ``key`` and ``store_name`` are accepted but not used.
        """
        return {'times': value.times, 'unit_id': value.unit_id, 'quality': value.quality}

    def decode(self, stored, *, key=None):
        """Build a SpikeTrain from the dict produced by ``encode``."""
        return SpikeTrain(times=stored['times'], unit_id=stored['unit_id'], quality=stored['quality'])
In [6]:
Copied!
@schema
class Unit(dj.Manual):
    # Manual table keyed by the integer `unit_id`; its single dependent
    # attribute `spikes` is declared with the custom <spike_train> codec type.
    definition = """
    unit_id : int32
    ---
    spikes : <spike_train>
    """
# Store one unit with 50 sorted random spike times drawn from [0, 100)
spike_times = np.sort(np.random.uniform(0, 100, 50))
train = SpikeTrain(times=spike_times, unit_id=1, quality='good')
Unit.insert1({'unit_id': 1, 'spikes': train})

# Read it back and report the decoded type and spike count
unit_row = Unit & {'unit_id': 1}
result = unit_row.fetch1('spikes')
print(f"Type: {type(result)}, Spikes: {len(result.times)}")
@schema
class Unit(dj.Manual):
    # Manual table keyed by the integer `unit_id`; its single dependent
    # attribute `spikes` is declared with the custom <spike_train> codec type.
    definition = """
    unit_id : int32
    ---
    spikes : <spike_train>
    """
# Store one unit with 50 sorted random spike times drawn from [0, 100)
spike_times = np.sort(np.random.uniform(0, 100, 50))
train = SpikeTrain(times=spike_times, unit_id=1, quality='good')
Unit.insert1({'unit_id': 1, 'spikes': train})

# Read it back and report the decoded type and spike count
unit_row = Unit & {'unit_id': 1}
result = unit_row.fetch1('spikes')
print(f"Type: {type(result)}, Spikes: {len(result.times)}")
Type: <class '__main__.SpikeTrain'>, Spikes: 50
In [7]:
Copied!
# Clean up: drop the tutorial schema; prompt=False skips the confirmation prompt.
schema.drop(prompt=False)
# Clean up: drop the tutorial schema; prompt=False skips the confirmation prompt.
schema.drop(prompt=False)