sandz7 committed on
Commit
7f412e3
·
1 Parent(s): f1d3e92

start for chimera

Browse files
Files changed (3) hide show
  1. .gitignore +0 -0
  2. app.py +121 -0
  3. requirements.txt +6 -0
.gitignore ADDED
File without changes
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
3
+ from diffusers import DiffusionPipeline
4
+ import gradio as gr
5
+ import numpy as np
6
+ import accelerate
7
+ import spaces
8
+ from PIL import Image
9
+ import threading
10
+
11
# HTML banner shown at the top of the demo page (passed to gr.Markdown below).
# The heading emoji is U+1F54B; it had been corrupted by a wrong-codec decode.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''
17
# --- LLaVA: vision-language model used when the chat contains an image ---
# Loaded in half precision and moved to the first CUDA device.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

llava_model.to("cuda:0")

# Processor paired with the same checkpoint; it prepares the text + image inputs.
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")

# Stop generation on token id 128009.
# NOTE(review): presumably this is Llama-3's <|eot_id|> -- confirm against the tokenizer.
llava_model.generation_config.eos_token_id = 128009

# --- Stable Diffusion XL: used when the chat contains no image ---
# Base pipeline: runs the first part of the denoising schedule and hands
# latents over to the refiner.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to("cuda")

# Refiner pipeline: must be loaded from the *refiner* checkpoint. Loading the
# base checkpoint here yields a text-to-image pipeline that does not accept the
# `image=` / `denoising_start=` arguments the refinement step passes.
# The second text encoder and the VAE are shared with `base` to save memory.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
48
+
49
+ # Everything is loaded. The function below uses these models.
50
+
51
def _find_image_path(message, history):
    """Return the path of the image the user is referring to, or None.

    The most recent upload in ``message["files"]`` takes precedence. When the
    current message has no upload, ``history`` is scanned and the image from
    the latest image-bearing turn is used.
    """
    uploads = message.get("files")
    if uploads:
        latest = uploads[-1]
        # An upload arrives either as a dict carrying a "path" key or as the
        # path itself.
        return latest["path"] if isinstance(latest, dict) else latest

    found = None
    for turn in history:
        # Turns that carried a file hold a tuple in the user slot; its first
        # item is the file path.
        if isinstance(turn[0], tuple):
            found = turn[0][0]
    return found


def chimera(message, history):
    """Answer a multimodal chat message, streaming the result.

    Args:
        message: dict with a ``"text"`` entry (the user's prompt) and a
            ``"files"`` entry (list of uploads, possibly empty).
        history: previous chat turns; searched for an earlier image when the
            current message has no upload.

    Yields:
        When no image is available: a single image generated from the prompt
        by the SDXL base + refiner pipelines.
        When an image is available: the LLaVA answer so far, growing as new
        text is streamed.
    """
    print(f"Message:\n{message}\nType:\n{type(message)}")

    image_path = _find_image_path(message, history)

    if image_path is None:
        # No image to talk about: treat the text as an image-generation prompt.
        # The raw user text is used here; the chat template below is meant for
        # LLaVA only and would pollute the diffusion prompt.
        text = message["text"]
        latents = base(
            prompt=text,
            num_inference_steps=40,
            denoising_end=0.8,
            output_type="latent",
        ).images
        image = refiner(
            prompt=text,
            num_inference_steps=40,
            denoising_start=0.8,
            image=latents,
        ).images[0]
        # This function is a generator, so the image has to be yielded; a
        # plain `return image` would be discarded by the caller.
        # NOTE(review): confirm the installed gradio version accepts a PIL
        # image as a chat response (it may need wrapping in gr.Image).
        yield image
        return

    # An image is available: ask LLaVA about it using the Llama-3 chat layout.
    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    image = Image.open(image_path)
    # Keyword arguments keep the call independent of the processor's
    # positional order.
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(0, torch.float16)
    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer.
    worker = threading.Thread(target=llava_model.generate, kwargs=generation_kwargs)
    worker.start()

    answer = ""
    for new_text in streamer:
        # Special tokens are not skipped, so cut the end-of-turn marker off.
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        answer += new_text
        yield answer
106
+
107
+
108
# Chat widgets are built first so they can be handed to the ChatInterface below.
chatbot = gr.Chatbot(height=600, label="Chimera AI")
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["images"],
    placeholder="Enter your question or upload an image.",
    show_label=False,
)

# Page layout: the description banner on top, the multimodal chat underneath.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chimera,
        chatbot=chatbot,
        textbox=chat_input,
        multimodal=True,
        fill_height=True,
    )

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ numpy
5
+ accelerate
6
+ diffusers