"""Streamlit demo: continue a user prompt with a language model and CAIF sampling.

The page lets the user pick a classifier checkpoint and a language-model
checkpoint, enter the beginning of a text, and shows the continuation
produced by ``Generator`` configured with a ``CAIFSampler`` and a
``TopKWithTemperatureSampler``.
"""
import os
from typing import Tuple

import streamlit as st
import torch
import transformers
import tokenizers

from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    """Render the page: model pickers, prompt input, and the generated text."""
    st.subheader("CAIF")
    cls_model_name = st.selectbox(
        'Выберите модель классификации',
        (
            'tinkoff-ai/response-quality-classifier-tiny',
            'tinkoff-ai/response-quality-classifier-base',
            'tinkoff-ai/response-quality-classifier-large',
        ),
    )
    lm_model_name = st.selectbox(
        'Выберите языковую модель',
        ('sberbank-ai/rugpt3small_based_on_gpt2',),
    )
    prompt = st.text_input("Начало текста:", "Привет")
    with st.spinner('Running inference...'):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
        )
    # Show the result; without this the generated text was computed but
    # never displayed on the page.
    st.write(text)


# allow_output_mutation=True: the cached Generator is mutated afterwards
# (inference() calls set_caif_sampler / set_ordinary_sampler on it), so
# Streamlit must not treat that mutation as a cache violation.
@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Build the ``Generator`` for ``lm_model_name`` on ``device`` (cached per name)."""
    with st.spinner('Loading language model...'):
        generator = Generator(lm_model_name=lm_model_name, device=device)
    return generator


def load_sampler(cls_model_name: str, lm_tokenizer) -> CAIFSampler:
    """Build a ``CAIFSampler`` for the given classifier and LM tokenizer.

    NOTE(review): unlike ``load_generator`` this function is not cached, so a
    new sampler is constructed on every call that is not served from the
    ``inference`` cache -- confirm whether that is intended.
    """
    with st.spinner('Loading classifier model...'):
        sampler = CAIFSampler(
            classifier_name=cls_model_name,
            lm_tokenizer=lm_tokenizer,
            device=device,
        )
    return sampler


# Cached on the call arguments: repeating the same models and prompt returns
# the previously generated text instead of running generation again.
@st.cache
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
    """Generate one continuation of ``prompt`` and return it.

    Args:
        lm_model_name: Language-model checkpoint name; used for both the
            ``Generator`` and the tokenizer handed to the CAIF sampler.
        cls_model_name: Classifier checkpoint name for ``CAIFSampler``.
        prompt: Beginning of the text to continue.
        fp16: Passed as the ``enabled`` flag of the autocast context that
            wraps generation.

    Returns:
        The first element of the sequences returned by
        ``Generator.sample_sequences``.
    """
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
    caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
    generator.set_caif_sampler(caif_sampler)
    generator.set_ordinary_sampler(TopKWithTemperatureSampler())

    # Extra keyword arguments forwarded verbatim to sample_sequences.
    sampling_kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": 5,
    }

    # Choose the autocast implementation that matches the active device.
    if device == "cpu":
        autocast = torch.cpu.amp.autocast
    else:
        autocast = torch.cuda.amp.autocast

    with autocast(fp16):
        print(f"Generating for prompt: {prompt}")
        # The second returned value (tokens) is not needed here.
        sequences, _tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            caif_tokens_num=100,
            entropy=None,
            **sampling_kwargs
        )
        print(f"Output for prompt: {sequences}")

    return sequences[0]


if __name__ == "__main__":
    main()