Right now, if I want the model state after generation, PIPELINE.generate gives me no way to get it unless I already had a state to pass in: the state argument defaults to None.
So I need to do something silly, like generating a single token with model.forward and passing the state from there, or writing a model wrapper that has forward and remembers the state.
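For reference, the wrapper workaround looks roughly like the sketch below. It assumes that forward takes (tokens, state) and returns (out, state). StateRecordingModel and ToyModel are made-up names for illustration, and ToyModel only stands in for the real model so the snippet runs on its own.

```python
class ToyModel:
    """Hypothetical stand-in for the real model (assumed forward signature)."""

    def forward(self, tokens, state=None):
        # Toy "state": the number of tokens seen so far.
        seen = (state or 0) + len(tokens)
        return [0.0], seen


class StateRecordingModel:
    """Wraps a model and remembers the state returned by the last forward call."""

    def __init__(self, model):
        self.model = model
        self.last_state = None

    def forward(self, tokens, state=None):
        out, state = self.model.forward(tokens, state)
        self.last_state = state  # keep the state so the caller can read it later
        return out, state


wrapped = StateRecordingModel(ToyModel())
wrapped.forward([1, 2, 3])
wrapped.forward([4], wrapped.last_state)
print(wrapped.last_state)
```

It works, but every caller who wants the state has to carry this boilerplate around, which is why I'd rather have it supported by generate itself.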
Possible solutions for getting the model state out of generate:
- Add a return_state argument, so that generate ends with:
  if return_state: return out_str, state
  return out_str
  Major drawback: it breaks the clean, simple API, because the function now has two return types, which will make LSPs, linters, etc. complain.
- Add a callback_with_state: bool argument and, if it is true, call callback(tmp, state) rather than callback(tmp).
  Drawback: the callback is not always called, but I'm not sure that's a big deal.
- Rename the function to generate_with_state, which returns out_str and state, and replace generate with a thin wrapper that returns generate_with_state(...)[0].
  This keeps the existing API.
- Store the state: add a PIPELINE.last_state field and store the state there.
  Drawback: this wastes memory, since keeping the state around is not always desirable.
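To make the generate_with_state option more concrete, here is a minimal sketch of the shape I have in mind. ToyPipeline and CountingModel are hypothetical stand-ins, not the real PIPELINE or model, and the parameters (tokens, token_count, state, callback) are simplified assumptions. The only point is that generate keeps returning a plain string while generate_with_state exposes the state.

```python
class CountingModel:
    """Hypothetical model: the state is the number of tokens seen so far."""

    def forward(self, tokens, state=None):
        seen = (state or 0) + len(tokens)
        return seen % 10, seen  # (next token, new state)


class ToyPipeline:
    """Hypothetical stand-in for PIPELINE, reduced to the part that matters here."""

    def __init__(self, model):
        self.model = model

    def generate_with_state(self, tokens, token_count=4, state=None, callback=None):
        out_str = ""
        for _ in range(token_count):
            token, state = self.model.forward(tokens, state)
            tmp = str(token)
            if callback is not None:
                callback(tmp)
            out_str += tmp
            tokens = [token]  # feed the generated token back in
        return out_str, state

    def generate(self, *args, **kwargs):
        # Existing API: same arguments, still returns only the string.
        return self.generate_with_state(*args, **kwargs)[0]


pipeline = ToyPipeline(CountingModel())
text, state = pipeline.generate_with_state([1, 2, 3])
print(text, state)
print(pipeline.generate([1, 2, 3]))
```

Each function has a single return type, so type checkers stay happy, and nothing is kept in memory unless the caller holds on to the returned state.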