import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 950
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
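# MAX_INPUT_TOKEN_LENGTH caps the prompt length fed to the model; it can be
# overridden via the environment, e.g. MAX_INPUT_TOKEN_LENGTH=2048.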
# Description and License Texts
DESCRIPTION = """
# ✨Storytell AI🧑🏽💻
Welcome to the **Storytell AI** space, crafted with care by Ranam & George. Dive into the world of educational storytelling with our model: a fine-tuned, 7-billion-parameter Llama 2 that generates stories designed to engage and educate. Enjoy a journey of discovery and creativity; your storytelling lesson begins here! You can prompt the model to explain any computer science concept. **Please check the examples below**.
"""
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU Check and add CPU warning
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Model and Tokenizer Configuration
model_id = "meta-llama/Llama-2-7b-hf"
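# Note: meta-llama/Llama-2-7b-hf is a gated checkpoint on the Hugging Face Hub,
# so the Space needs to run with a token that has been granted access.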
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
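# 4-bit NF4 quantization with bfloat16 compute shrinks the 7B base model enough
# to fit on a single modest GPU; the storytelling fine-tune is then applied on
# top as a PEFT adapter.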
base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
tokenizer = AutoTokenizer.from_pretrained(model_id)
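# Llama's tokenizer ships without a pad token; reusing EOS lets padded tokenizer calls work.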
tokenizer.pad_token = tokenizer.eos_token
# MongoDB Connection
PASSWORD = os.environ.get("MONGO_PASS")
connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
# MongoDB Document
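# story_id is an auto-incrementing integer primary key (mongoengine SequenceField).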
class Story(Document):
message = StringField()
content = StringField()
story_id = SequenceField(primary_key=True)
# Utility function for prompts
def make_prompt(entry):
    return f"### Human: Don't repeat the assessments, limit to 500 words. {entry} ### Assistant:"
    # Alternative prompt:
    # f'TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESSMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH "THE END.": {entry}'
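# process_text strips any bracketed span from a streamed chunk, e.g.
# process_text("Hello [INST] world") -> "Hello  world".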
def process_text(text):
text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
return text
# Gradio Function
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
temperature: float = 0.6,
top_p: float = 0.7,
top_k: int = 20,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
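    # Note: chat_history is gathered into `conversation` below, but as written only
    # the latest message (wrapped by make_prompt) is tokenized and sent to the model.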
conversation = []
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": make_prompt(message)})
enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
input_ids = enc.input_ids.to(model.device)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
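        # Left-truncate so the kept window ends with the most recent tokens.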
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # skip_special_tokens=True keeps markers like </s> out of the streamed story text.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
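    # model.generate blocks, so it runs on a background thread while the
    # TextIteratorStreamer yields decoded tokens to this generator as they arrive.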
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
end_phrase = "the end."
for text in streamer:
processed_text = process_text(text)
outputs.append(processed_text)
current_output = "".join(outputs)
yield current_output
# Check if 'the end.' is in the current output, case-insensitive
if end_phrase in current_output.lower():
break # Stop generating further if 'the end.' is found
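            # Caveat: breaking here only stops streaming to the UI; the generation
            # thread keeps running until EOS or max_new_tokens is reached.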
final_story = "".join(outputs)
try:
saved_story = Story(message=message, content=final_story).save()
        yield f"{final_story}\n\nStory saved with ID: {saved_story.story_id}"
except Exception as e:
yield f"Failed to save story: {str(e)}"
# Gradio Interface Setup
chat_interface = gr.ChatInterface(
fn=generate,
stop_btn=None,
examples=[
["Can you explain briefly to me what is the Python programming language?"],
["Could you please provide an explanation about the concept of recursion?"],
["Could you explain what a URL is?"]
],
)
# Gradio Web Interface
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
demo.queue(max_size=20)
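    # max_size bounds the pending-request queue; share=True also exposes a
    # temporary public *.gradio.live URL in addition to the local server.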
demo.launch(share=True) |