class JaxWrapper:
    """Give an immutable JAX-style array a mutable ``obj[key] = val`` interface.

    JAX arrays cannot be modified in place: ``arr.at[key].set(val)`` returns a
    *new* array and leaves ``arr`` untouched. This wrapper stores the updated
    array back on ``self.arr``, so subscript assignment behaves like an
    in-place write.
    """

    def __init__(self, arr):
        # Presumably a jax.Array (anything exposing ``.at[key].set(val)``);
        # TODO confirm against callers.
        self.arr = arr

    def __setitem__(self, key, val):
        """Replace ``self.arr`` with a copy where ``[key]`` is set to ``val``.

        Returns the new array. Python discards this value for ``w[key] = val``
        syntax, but it is kept for callers that invoke ``__setitem__``
        directly.
        """
        # Python ignores __setitem__'s return value, so the functional update
        # must be stored on the instance; otherwise the assignment is lost.
        self.arr = self.arr.at[key].set(val)
        return self.arr

    # NOTE(review): the original snippet ended with "...." -- any further
    # members were elided there and are not reproduced here.