Hi @TylerHilbert, thanks for the comment. I think it would be possible to use some Pallas kernels to further optimize inference, such as flash attention. I haven't tried it myself, but they should bring good performance improvements. Also, you could use a bigger batch size; I think it should fit on this TPU without issues (I did not explore all configurations).
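
In case it helps, here is a minimal, untested sketch of what calling the Pallas flash-attention kernel could look like. The import path and argument names are assumptions based on the kernel shipped under `jax.experimental.pallas.ops.tpu.flash_attention`; they may differ across JAX versions, so treat this as illustrative rather than a drop-in implementation:

```python
# Untested sketch: calling the Pallas TPU flash-attention kernel directly.
# The import path below is an assumption and may move between JAX versions.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

# Example shapes (not tuned): (batch, num_heads, seq_len, head_dim)
batch, heads, seq_len, head_dim = 8, 16, 2048, 64

key = jax.random.PRNGKey(0)
qk, kk, vk = jax.random.split(key, 3)
q = jax.random.normal(qk, (batch, heads, seq_len, head_dim), dtype=jnp.float32)
k = jax.random.normal(kk, (batch, heads, seq_len, head_dim), dtype=jnp.float32)
v = jax.random.normal(vk, (batch, heads, seq_len, head_dim), dtype=jnp.float32)

# causal=True for decoder-style inference; sm_scale is the usual 1/sqrt(head_dim).
out = flash_attention(q, k, v, causal=True, sm_scale=1.0 / head_dim**0.5)
print(out.shape)  # (batch, heads, seq_len, head_dim)
```

The idea would be to swap the model's reference attention for a call like this and then benchmark with a few larger batch sizes to see what still fits in the TPU's memory.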