"""janus.py.
File for providing the Janus model implementation.
"""
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor
from src.models.base import ModelBase
from src.models.config import Config
class JanusModel(ModelBase):
"""Janus model implementation."""
def __init__(self, config: Config) -> None:
"""Initialize the Janus model.
Args:
config (Config): Parsed config.
"""
super().__init__(config)

    def _load_specific_model(self) -> None:
        """Populate self.model with the specified Janus model."""
        # Forward optional model kwargs from the config, when provided.
        model_kwargs = self.config.model if hasattr(self.config, 'model') else {}
        self.model = JanusForConditionalGeneration.from_pretrained(
            self.model_path,
            **model_kwargs,
        )
        # Cast the weights to bfloat16 to match the inputs built below.
        self.model.to(torch.bfloat16)

    def _init_processor(self) -> None:
        """Initialize the Janus processor."""
        self.processor = JanusProcessor.from_pretrained(self.model_path)

    def _generate_prompt(self, prompt: str) -> str:
        """Return the prompt content as is.

        Janus assembles its chat messages in _generate_processor_output,
        so no extra templating is needed here.

        Args:
            prompt (str): prompt content.

        Returns:
            str: The unchanged prompt content.
        """
        return prompt

    def _generate_processor_output(self, prompt: str, img_path: str) -> dict:
        """Override the base function to produce processor arguments for Janus.

        Args:
            prompt (str): The input prompt to be processed.
            img_path (str): The path to the image to be processed.

        Returns:
            dict: The formatted inputs for the processor.
        """
        # Wrap the image and prompt into a single chat-style user message.
        messages = [
            {
                'role': 'user',
                'content': [
                    {'type': 'image', 'image': img_path},
                    {'type': 'text', 'text': prompt},
                ],
            },
        ]
        # Tokenize via the chat template and move the batch to the target
        # device in bfloat16, matching the model weights.
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            generation_mode='text',
            tokenize=True,
            return_dict=True,
            return_tensors='pt',
        ).to(self.config.device, dtype=torch.bfloat16)
        return inputs

    def _forward(self, data: dict) -> None:
        """Given some input data, perform a single forward pass.

        This function itself can be overridden, while _hook_and_eval
        should be left intact.

        Args:
            data (dict): The processor outputs to feed to the model.
        """
        _ = self.model.generate(**data, **self.config.forward)
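

if __name__ == '__main__':
    # Usage sketch (illustrative only): drives the model end to end on a
    # single image/prompt pair. This assumes ModelBase.__init__ resolves
    # self.model_path and invokes _load_specific_model / _init_processor,
    # and that Config carries `device` and a `forward` kwargs dict. The
    # Config constructor call and both paths below are hypothetical
    # placeholders; adapt them to how Config is built in this repo.
    config = Config('configs/janus.yaml')  # hypothetical constructor/path
    janus = JanusModel(config)
    data = janus._generate_processor_output(
        prompt='Describe this image.',
        img_path='data/example.jpg',  # hypothetical image path
    )
    janus._forward(data)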