Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline | |
| from threading import Thread | |
# Checkpoint id used for both the tokenizer and the causal language model.
model_id = "rasyosef/Llama-3.2-180M-Amharic-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Pad/EOS ids are read from the tokenizer and handed to the pipeline explicitly.
_special_token_ids = {
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}

# Text-generation pipeline built from the model and tokenizer loaded above.
llama_am = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **_special_token_ids,
)
# Streams a chat reply for a prompt using the Llama 3.2 Amharic text-generation pipeline
def generate(message, chat_history, max_new_tokens=256):
    """Stream the model's reply to ``message`` given the earlier conversation.

    Args:
        message: The newest user message.
        chat_history: Earlier turns. Each entry is either a two-item
            ``(user, assistant)`` pair or a dict carrying ``"role"`` and
            ``"content"`` keys (messages-style history).
        max_new_tokens: Forwarded to the pipeline as ``max_new_tokens``.

    Yields:
        The reply accumulated so far, whitespace-stripped, each time the
        streamer emits more text. If the chat-templated conversation is longer
        than 512 entries (token ids with the tokenizer's default settings),
        the single string ``"chat history is too long"`` is yielded instead.

    Raises:
        Exception: Any error raised by the pipeline in the worker thread is
            re-raised here after the stream has been drained.
    """
    history = []
    for turn in chat_history:
        if isinstance(turn, dict):
            # Messages-style entry: keep only the fields the chat template needs.
            history.append({"role": turn["role"], "content": turn["content"]})
        else:
            sent, received = turn
            history.append({"role": "user", "content": sent})
            history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": message})

    # Guard clause: refuse over-long conversations before starting generation.
    if len(tokenizer.apply_chat_template(history)) > 512:
        yield "chat history is too long"
        return

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=300.0,
    )
    failures = []

    def _run_pipeline():
        """Run the pipeline; on failure, record the error and end the stream."""
        try:
            llama_am(
                text_inputs=history,
                max_new_tokens=max_new_tokens,
                repetition_penalty=1.1,
                streamer=streamer,
            )
        except Exception as exc:  # re-raised by the consumer below
            failures.append(exc)
            # Without this the consumer would block until the streamer timeout.
            streamer.end()

    # Generation runs in a worker thread so text can be yielded as it arrives.
    worker = Thread(target=_run_pipeline)
    worker.start()

    generated_text = ""
    for word in streamer:
        generated_text += word
        yield generated_text.strip()

    worker.join()
    if failures:
        raise failures[0]
# Gradio chat UI: intro text, a slider for the reply length, and a chat panel
# wired to `generate`.
intro_markdown = """
# Llama 3.2 180M Amharic Chatbot Demo
This chatbot was created using [Llama-3.2-180M-Amharic-Instruct](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic-Instruct), a finetuned version of my 180 million parameter [Llama 3.2 180M Amharic](https://huggingface.co/rasyosef/Llama-3.2-180M-Amharic) transformer model.
"""

# Sample prompts offered in the UI; each one becomes a single-item example row.
example_prompts = [
    "የኢትዮጵያ ዋና ከተማ ስም ምንድን ነው?",
    "የኢትዮጵያ የመጨረሻው ንጉስ ማን ነበሩ?",
    "የፈረንሳይ ዋና ከተማ ስም ምንድን ነው?",
    "አሁን የአሜሪካ ፕሬዚዳንት ማን ነው?",
    "የእስራኤል ጠቅላይ ሚንስትር ማን ነው?",
    "ሶስት የአፍሪካ ሀገራት ጥቀስልኝ",
    "3 የአሜሪካ መሪዎችን ስም ጥቀስ",
    "5 የአሜሪካ ከተማዎችን ጥቀስ",
    "አምስት የአውሮፓ ሀገራት ጥቀስልኝ",
    "የኢትዮጵያ ፕሬዝዳንት ማን ነው?",
    "በ ዓለም ላይ ያሉትን 7 አህጉራት ንገረኝ",
]

with gr.Blocks() as demo:
    gr.Markdown(intro_markdown)

    # The slider value is passed to `generate` as its `max_new_tokens` argument.
    tokens_slider = gr.Slider(
        minimum=8,
        maximum=256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` parameter value gives you longer text responses but at the cost of a slower response time.",
    )

    chatbot = gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(height=400),
        additional_inputs=[tokens_slider],
        examples=[[prompt] for prompt in example_prompts],
        cache_examples=False,
        stop_btn=None,
    )

# Enable the request queue, then start the app in debug mode.
queued_demo = demo.queue()
queued_demo.launch(debug=True)