marstin's picture
[martin-dev] add demo v1 test
d425e71
"""janus.py.
File for providing the Janus model implementation.
"""
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor
from src.models.base import ModelBase
from src.models.config import Config
class JanusModel(ModelBase):
    """Janus model implementation.

    Wraps Hugging Face ``JanusForConditionalGeneration`` and
    ``JanusProcessor`` behind the ``ModelBase`` interface, running the
    model in ``torch.bfloat16``.
    """

    def __init__(self, config: Config) -> None:
        """Initialize the Janus model.

        Args:
            config (Config): Parsed config.
        """
        super().__init__(config)

    def _load_specific_model(self) -> None:
        """Populate ``self.model`` with the specified Janus model.

        The weights are loaded from ``self.model_path``. If the config has a
        ``model`` section, its entries are forwarded to ``from_pretrained``
        as keyword arguments. A missing or empty/``None`` section falls back
        to the library defaults. The loaded model is then cast to
        ``torch.bfloat16``.
        """
        # An absent or None ``model`` section means "no extra kwargs";
        # unpacking ``**None`` would otherwise raise a TypeError.
        model_kwargs = getattr(self.config, 'model', None) or {}
        self.model = JanusForConditionalGeneration.from_pretrained(
            self.model_path,
            **model_kwargs,
        )
        self.model.to(torch.bfloat16)

    def _init_processor(self) -> None:
        """Initialize the Janus processor from ``self.model_path``."""
        self.processor = JanusProcessor.from_pretrained(self.model_path)

    def _generate_prompt(self, prompt: str) -> str:
        """Generates the prompt string with the input messages.

        Janus formatting is done by the processor's chat template in
        ``_generate_processor_output``, so no extra formatting happens here.

        Args:
            prompt (str): prompt content.

        Returns:
            str: Returns the prompt content as is.
        """
        return prompt

    def _generate_processor_output(self, prompt: str, img_path: str) -> dict:
        """Override the base function to produce processor arguments for Janus.

        Builds a single-turn user message that holds the image and the text
        prompt. The processor's chat template then formats and tokenizes it
        for text generation.

        Args:
            prompt (str): The input prompt to be processed.
            img_path (str): The path to the image to be processed.

        Returns:
            dict: The tokenized processor inputs (``apply_chat_template`` with
                ``return_dict=True``), moved to ``self.config.device`` with
                bfloat16 requested as the dtype.
        """
        # The chat template below assembles the full prompt, so the raw
        # ``prompt`` is used directly rather than via ``_generate_prompt``.
        messages = [
            {
                'role': 'user',
                'content': [
                    {'type': 'image', 'image': img_path},
                    {'type': 'text', 'text': prompt},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            generation_mode='text',
            tokenize=True,
            return_dict=True,
            return_tensors='pt',
        )
        # NOTE(review): relies on the returned object's ``.to`` accepting a
        # dtype. transformers' BatchFeature.to is documented to cast only
        # floating-point tensors, leaving integer token ids as-is — confirm.
        return inputs.to(self.config.device, dtype=torch.bfloat16)

    def _forward(self, data: dict) -> None:
        """Given some input data, performs a single forward pass.

        This function itself can be overridden, while _hook_and_eval
        should be left intact.

        Args:
            data (dict): Keyword inputs for ``self.model.generate``
                (presumably the output of ``_generate_processor_output``).
        """
        # Generation kwargs come from the config's ``forward`` section. A
        # missing or None section means library defaults, mirroring how the
        # ``model`` section is handled when loading.
        generate_kwargs = getattr(self.config, 'forward', None) or {}
        # The generated output is discarded. Presumably the caller observes
        # the model through hooks (see _hook_and_eval) — confirm.
        _ = self.model.generate(**data, **generate_kwargs)