import os
import re
import logging
from threading import Thread
from typing import Iterator

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
from openai import OpenAI

# OpenAI client for DALL-E image generation; expects OPENAI_KEY in the environment.
openai_client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))

# Set up logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def generate_image(text):
    """Generate a 4-panel illustration of the story via DALL-E 3."""
    try:
        logging.debug("Generating image with prompt: %s", text)
        response = openai_client.images.generate(
            model="dall-e-3",
            prompt="Create a 4 panel pixar style illustration that accurately depicts the character and the setting of a story: " + text,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url
        logging.info("Image generated successfully: %s", image_url)
        return image_url
    except Exception as error:
        logging.error("Failed to generate image: %s", str(error))
        raise gr.Error("An error occurred while generating the image. Please check your API key and try again.")
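# Example usage (requires a valid OPENAI_KEY; returns a hosted image URL):
#   url = generate_image("A shy robot learns to paint in a city library.")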
rope_scaling = {
    'type': 'linear',  # Adjust the type to the appropriate scaling type for your model.
    'factor': 8.0      # Use the intended scaling factor.
}
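# Note: Llama 3.1 checkpoints already ship a rope_scaling entry in their config,
# so this override is only needed if the fine-tune was trained with a different
# scaling scheme.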
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
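# Example: export MAX_INPUT_TOKEN_LENGTH=8192 before launch to allow longer
# conversations before the input is trimmed.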
| LICENSE = """ | |
| --- | |
| As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta, | |
| this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md). | |
| """ | |
# GPU check; warn when running on CPU.
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model and Tokenizer Configuration
    model_id = "meta-llama/Llama-3.1-8B"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
        rope_scaling=rope_scaling  # Add this only if your model specifically requires it.
    )
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytellai-2.0")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
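# Note: Llama tokenizers define no pad token by default; reusing EOS above lets
# the tokenizer pad batched inputs without raising an error.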
def make_prompt(entry):
    return f"### Human: When asked to explain, use a story. Don't repeat the assessments; limit to 500 words. However, keep context in mind if edits to the content are required. {entry} ### Assistant:"
def process_text(text):
    # Surface [answer:] tags as a readable "Answer:" label, then strip
    # [assessment; ...] blocks and any remaining bracketed metadata.
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    text = re.sub(r'\[assessment;[^\]]*\]', '', text, flags=re.DOTALL)
    text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
    return text
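# Example (assumed tag format emitted by the fine-tune):
#   process_text("[answer:] HTTP is a protocol. [assessment; q1]")
#   -> "Answer: HTTP is a protocol. "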
custom_css = """
body, input, button, textarea, label {
    font-family: Arial, sans-serif;
    font-size: 24px;
}
.gr-chat-interface .gr-chat-message-container {
    font-size: 14px;
}
.gr-button {
    font-size: 14px;
    padding: 12px 24px;
}
.gr-input {
    font-size: 14px;
}
"""
@spaces.GPU  # Assumption: this Space runs on ZeroGPU hardware.
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.8,
    top_p: float = 0.7,
    top_k: int = 30,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    # Rebuild the running conversation in the fine-tune's "### Human / ### Assistant"
    # format so earlier turns stay in context.
    conversation = ""
    for user, assistant in chat_history:
        conversation += f"{make_prompt(user)} {assistant} "
    conversation += make_prompt(message)

    enc = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(process_text(text))
        yield "".join(outputs)

    # Once streaming finishes, illustrate the story and append the image to the reply.
    final_story = "".join(outputs)
    final_story_trimmed = remove_last_sentence(final_story)
    image_url = generate_image(final_story_trimmed)
    yield f"{final_story}\n\n![Story illustration]({image_url})"
def remove_last_sentence(text):
    sentences = re.split(r'(?<=\.)\s', text)
    return ' '.join(sentences[:-1]) if sentences else text
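# Example: remove_last_sentence("One. Two. Thr") -> "One. Two."
# (drops the possibly truncated trailing fragment before the image prompt)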
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Tell me about HTTP."],
        ["Can you briefly explain the Python programming language?"],
        ["Could you please provide an explanation about Data Science?"],
        ["Could you explain what a URL is?"]
    ],
    theme='shivi/calm_seafoam',
    autofocus=True,
)
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
# Gradio Web Interface
with gr.Blocks(css=custom_css, fill_height=True, theme="shivi/calm_seafoam", js=js_func) as demo:
    gr.Markdown(DESCRIPTION)  # Shows the CPU warning when no GPU is available.
    chat_interface.render()
    # gr.Markdown(LICENSE)

# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)