GSoC-Super-Rapid-Annotator / src /text_processor.py
ManishThota's picture
Update src/text_processor.py
57cdbdf verified
raw
history blame
1.84 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import warnings
from typing import Dict
import spaces
# Target device; handed to the model loader as its ``device_map``.
device = "cuda"

# Silence every warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Fix the torch RNG seed.
torch.random.manual_seed(0)

# Model identifier and the keyword arguments forwarded to each pipeline call.
model_path = "microsoft/Phi-3-mini-4k-instruct"
generation_args = dict(
    max_new_tokens=50,
    return_full_text=False,
    temperature=0.1,
    do_sample=True,
)
# Load the model and pipeline once and keep it in memory
def load_model_pipeline(model_path: str):
    """Return a text-generation pipeline for ``model_path``, loading it at most once.

    Pipelines are cached per model path on the function object. Repeated
    calls with the same path reuse the already-loaded pipeline, while a call
    with a different path loads that model instead of silently returning
    whichever pipeline happened to be built first.

    Args:
        model_path: Model identifier or path passed to ``from_pretrained``
            for both the model and the tokenizer.

    Returns:
        The ``transformers`` text-generation pipeline built for ``model_path``.
    """
    # The cache lives on the function itself so it persists across calls
    # without introducing another module-level global.
    cache = load_model_pipeline.__dict__.setdefault("_pipes", {})
    if model_path not in cache:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype="auto",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        cache[model_path] = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
    # Kept for backward compatibility: the previous implementation exposed
    # the cached pipeline as ``load_model_pipeline.pipe``.
    load_model_pipeline.pipe = cache[model_path]
    return cache[model_path]
# Build the pipeline once at import time and keep it in memory;
# generate_logic reads this module-level `pipe` on every call.
pipe = load_model_pipeline(model_path)
# Generate output from LLM
@spaces.GPU(duration=50)
def generate_logic(llm_output: str) -> str:
    """Generate a detailed response for a description with the shared pipeline.

    The description is embedded in a fixed instruction and sent, together
    with a system message, to the module-level ``pipe`` using
    ``generation_args``. The ``generated_text`` of the first result is
    printed and returned.

    Args:
        llm_output: Description text to embed in the user prompt.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    # Leading and trailing newlines are part of the prompt sent to the model.
    user_prompt = (
        "\nProvide a detailed response based on the description: "
        f"'{llm_output}'.\n"
    )
    chat = [
        {"role": "system", "content": "Please provide a detailed response."},
        {"role": "user", "content": user_prompt},
    ]

    outputs = pipe(chat, **generation_args)
    answer = outputs[0]["generated_text"]

    # Log the generated text
    print(f"Generated Text: {answer}")
    return answer
# Main function to process LLM output and return raw text
def process_description(description: str) -> str:
    """Return the raw text that ``generate_logic`` produces for ``description``.

    Args:
        description: Description text forwarded unchanged to ``generate_logic``.

    Returns:
        Whatever ``generate_logic`` returns, without post-processing.
    """
    return generate_logic(description)