The updated version gives me the following error (setup completes successfully; the failure happens when generating the example "alien life" image):
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-2-0e20e3adf861> in <module>()
      2
----> 3 image = generate_image_from_text("alien life", seed=7)
      4 display(image)
UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
38 keys_state,
39 self.k_proj(decoder_state).reshape(shape_split),
---> 40 state_index
41 )
     42         values_state = lax.dynamic_update_slice(

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
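For anyone trying to isolate this: the message says the cached state and the new `k_proj` output reaching `lax.dynamic_update_slice` have different dtypes (presumably the float16 is the cache and the float32 is the projection output), and `lax` primitives do not promote dtypes the way `jnp` functions do. Below is a minimal JAX sketch that reproduces the mismatch with made-up shapes and shows one possible workaround, casting the update to the cache's dtype before writing it in. The shapes, the `update_cache` helper and the index tuple are hypothetical stand-ins, not min-dalle's actual code; the proper fix in the repo may instead be to allocate the cache in the same dtype as the model parameters.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical shapes standing in for the decoder's key cache and one new step.
keys_state = jnp.zeros((1, 4, 2, 3), dtype=jnp.float16)  # cache in half precision
new_keys = jnp.ones((1, 1, 2, 3), dtype=jnp.float32)     # projection output in float32
state_index = 2

def update_cache(cache, update, index):
    # Hypothetical helper: lax.dynamic_update_slice requires cache and update
    # to share a dtype, so cast the update to the cache's dtype first.
    update = update.astype(cache.dtype)
    return lax.dynamic_update_slice(cache, update, (0, index, 0, 0))

# Writing the float32 update straight into the float16 cache raises the
# same TypeError as in the traceback.
try:
    lax.dynamic_update_slice(keys_state, new_keys, (0, state_index, 0, 0))
except TypeError as err:
    print("reproduced:", err)

updated = update_cache(keys_state, new_keys, state_index)
print(updated.dtype)  # float16
print(updated[0, :, 0, 0])
```

Casting down to float16 keeps the cache small but loses precision in the new keys; casting the cache up to float32 instead avoids that at the cost of memory.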