File size: 1,786 Bytes
6788682
bc324c0
6788682
 
0afc69c
 
98d0911
6788682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc324c0
2695071
9a228a9
6788682
 
2695071
 
46702b5
2695071
6788682
bc324c0
6788682
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
from time import perf_counter

import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    pipeline,
)

def formatted_prompt(input) -> str:
    """Wrap ``input`` in the chat markup used to prompt the model.

    The result is a single ``user`` turn containing ``input`` followed by
    the opening of an ``assistant`` turn.
    """
    user_turn = "<|im_start|>user\n{}<|im_end|>\n".format(input)
    assistant_header = "<|im_start|>assistant:"
    return user_turn + assistant_header

def generate_response(user_input):
    """Generate a reply for ``user_input`` and render it on the Streamlit page.

    The text is wrapped with ``formatted_prompt``, tokenized with the
    module-level ``tokenizer`` and passed to the module-level ``model``'s
    ``generate`` method. The first generated sequence is decoded with special
    tokens skipped and written to the page, followed by the elapsed time
    (tokenization + generation + decoding) rounded to two decimals.

    Args:
        user_input: Text entered by the user.

    Returns:
        None. Output is emitted through ``st.write``.
    """
    prompt = formatted_prompt(user_input)

    # NOTE(review): penalty_alpha is a contrastive-search setting; combined
    # with do_sample=True it is presumably ignored -- confirm against the
    # installed transformers version before relying on it.
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.5,
        repetition_penalty=1.2,
        max_new_tokens=500,
        pad_token_id=tokenizer.eos_token_id,
    )
    start_time = perf_counter()

    # Tokenize once and put the tensors on the same device as the model.
    # A hard-coded 'cuda' fails when the model lives on the CPU or when no
    # GPU is available.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, generation_config=generation_config)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_time = perf_counter() - start_time
    st.write(response)
    st.write(f"Time taken for inference: {round(output_time, 2)} seconds")

@st.cache_resource
def load_model_and_tokenizer(model_name, token):
    """Load the generative model and its tokenizer, cached across reruns.

    ``st.cache_resource`` keeps a single shared instance of the returned
    objects so the weights are not reloaded every time Streamlit re-executes
    the script. It replaces the deprecated
    ``st.cache(allow_output_mutation=True)``.

    Args:
        model_name: Hugging Face Hub repository id (or local path) of the model.
        token: Hub access token forwarded to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # A causal-LM head is required for ``model.generate``; a
    # sequence-classification head cannot produce text.
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    return model, tokenizer

# Load the fine-tuned model and tokenizer from the Hugging Face Hub.
model_name = "orYx-models/finetuned-tiny-llama-medical-papers"
# "Tinyllama_secret" is the *name* of the Space secret, not the token itself.
# Secrets are presumably exposed to the app as environment variables (verify
# in the Space settings); when the variable is unset the token is None.
token = os.environ.get("Tinyllama_secret")
model, tokenizer = load_model_and_tokenizer(model_name, token)

# Build a generation pipeline around the loaded model.
# NOTE(review): "text-generation" assumes a decoder-only (Llama-style) model;
# "text2text-generation" targets encoder-decoder models -- confirm.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

user_input = st.text_area("Enter some text:")

# Only run inference once the user has typed something.
if user_input:
    generate_response(user_input)