Generating Sequences Pre-conditioned on Embeddings from MSA Transformer

#5
by hunarbatra - opened

Hi,

Is it possible to condition ProtGPT2 sequence generation on pre-computed embeddings from another model (e.g. MSA Transformer), so that we could generate logits conditioned on the embeddings produced by that other model? If yes, could you please guide me on how to implement this?

Thank you so much :)

Hi!
I know someone who tried something similar. Let me see if I can get in touch with them and find out whether, in the end, they managed to do it :)

Sure, it would be great if you could check on that and get back to me when you can :)
Thank you so much! :)

Not sure when they'll reply, but I found this issue on GitHub that might be of interest to you: https://github.com/huggingface/transformers/issues?q=is%3Aissue+protgpt2+is%3Aclosed
I think you can also pass inputs_embeds to the model and get the logits from there. I don't know whether it would be possible to generate new sequences that way, but I hope this helps a bit for now.
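For what it's worth, here is a minimal sketch of passing inputs_embeds through the forward pass to obtain logits. It uses a tiny randomly initialised GPT-2 as a stand-in for ProtGPT2 (which shares the GPT-2 architecture), so nothing needs to be downloaded; the external embeddings are random placeholders for what a projection of MSA Transformer embeddings might look like. Whether generate() accepts them is a separate question:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialised GPT-2 for illustration only
# (ProtGPT2 uses the same architecture at a larger size).
config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2, n_positions=64)
model = GPT2LMHeadModel(config)
model.eval()

# Placeholder for embeddings coming from another model,
# projected to GPT-2's hidden size (config.n_embd).
external_embeds = torch.randn(1, 5, config.n_embd)

with torch.no_grad():
    out = model(inputs_embeds=external_embeds)

print(out.logits.shape)  # torch.Size([1, 5, 100]) = (batch, seq_len, vocab_size)
```

Note that any real embeddings from MSA Transformer would first need a learned or fixed projection to GPT-2's hidden size before being usable here.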

Thank you so much for sharing this :).
Can the returned logits be passed back to ProtGPT2 to generate sequences somehow?

(update from what I found out: inputs_embeds can go through the forward pass, but cannot be passed to generate() for decoder-only (autoregressive) models like GPT-2)

Good question. I've been reading for a while, and everything points to LogitsProcessors being what you need.
From what I've read, a logits_processor argument can be passed to the sample(), beam_search(), etc. functions, but not directly to generate(), so I'd recommend using sample(): https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.sample.example.

I've had a quick look, but I haven't managed to pass the logits with any of the logits processors they specify in the documentation. Something like this errors:

>>> logits_processor = LogitsProcessorList([LogitsProcessor()])
>>> model.sample(input_ids,pad_token_id=0,logits_processor=out['logits'])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/path/hface/lib/python3.6/site-packages/transformers/generation_utils.py", line 1906, in sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
TypeError: 'Tensor' object is not callable

I don't have the time right now, but I hope I can have another look later tonight. Maybe in the meantime this can point you in the right direction, I hope! Let me know if you find a better way :)

Best,
Noelia

Thank you so much, Noelia, for sharing this! I'll try it out later today and update you if I'm able to make it work :)
