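"""BlazeChat demo app: CPU-only chat and text-to-image generation.

Combines the tld latent diffusion transformer (small_ldt weights) for
image generation with the PhoneLM-1.5B-Instruct chat model, served
through a single Gradio Blocks interface. Type `@imagine <prompt>` to
generate an image; any other message gets a text reply.
"""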
import os
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

from tld.configs import DenoiserConfig, DenoiserLoad, LTDConfig
from tld.diffusion import DiffusionTransformer

# Image Generation Model Setup
denoiser_cfg = DenoiserConfig(
    image_size=32,         # spatial size of the generated (latent) image grid
    noise_embed_dims=256,  # width of the noise-level embedding
    patch_size=2,          # patch size for the transformer
    embed_dim=768,         # transformer hidden dimension
    dropout=0,             # no dropout at inference time
    n_layers=12,           # number of transformer blocks
    text_emb_size=768,     # width of the text-conditioning embedding
)

denoiser_load = DenoiserLoad(
    dtype=torch.float32,
    file_url='https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth',
    local_filename='state_dict_378000.pth',
)

cfg = LTDConfig(denoiser_cfg=denoiser_cfg, denoiser_load=denoiser_load)
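# Build the sampler; weights are fetched per the DenoiserLoad config above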
diffusion_transformer = DiffusionTransformer(cfg)

# Set PyTorch to use all available CPU cores
num_cores = os.cpu_count()
torch.set_num_threads(num_cores)
print(f"Using {num_cores} CPU cores.")

# Text Model Setup
model_name = 'mllmTeam/PhoneLM-1.5B-Instruct'
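# trust_remote_code is required because PhoneLM ships its model code on the Hub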
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cpu', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def generate_text_response(question):
    start_time = time.time()
    prompt = [{"role": "user", "content": question}]
    input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    inp = tokenizer(input_text, return_tensors="pt")
    inp = {k: v.to('cpu') for k, v in inp.items()}
    out = model.generate(**inp, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.7)
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = out[0][inp['input_ids'].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"Text generation time: {time.time() - start_time:.2f} seconds")
    return text

def generate_image(prompt, class_guidance=6, num_imgs=1, seed=11):
    start_time = time.time()
    try:
        # Run the diffusion sampler on the text prompt
        out = diffusion_transformer.generate_image_from_text(
            prompt=prompt,
            class_guidance=class_guidance,
            num_imgs=num_imgs,
            seed=seed
        )

        # Convert a CHW tensor to an HWC numpy array
        if isinstance(out, torch.Tensor):
            out = out.detach().cpu().squeeze().permute(1, 2, 0).numpy()

        # Rescale pixel values to 0-255 and wrap in a PIL Image for Gradio
        if isinstance(out, np.ndarray):
            lo, hi = out.min(), out.max()
            scale = 255.0 / (hi - lo) if hi > lo else 0.0
            out = ((out - lo) * scale).astype('uint8')
            out = Image.fromarray(out)

        print(f"Image generation time: {time.time() - start_time:.2f} seconds")
        return out
    except Exception as e:
        print(f"Image generation error: {e}")
        return None

def chat_with_ai(message, history):
    # Keep only the most recent exchange to bound memory use
    max_history_length = 1
    history = history[-max_history_length:]
    if message.startswith('@imagine'):
        # Everything after the '@imagine' command is the image prompt
        image_prompt = message.split('@imagine', 1)[1].strip()
        image = generate_image(image_prompt)

        if image is not None:
            return "", history + [[message, f"Generated an image for: {image_prompt}"]], image
        else:
            return "", history + [[message, "Failed to generate image."]], None
    else:
        response = generate_text_response(message)
        # Return an empty string as the first output so the textbox is cleared
        return "", history + [[message, response]], None



# Create Gradio interface
with gr.Blocks(title="BlazeChat Image Generator") as demo:
    gr.Markdown("# ⚡Fast CPU-Powered Chat & Image Generation")
    gr.Markdown("Generate text and images using advanced AI models on CPU. Use `@imagine [prompt]` to create images or chat naturally.")
    gr.Markdown("https://github.com/SanshruthR/CPU_BlazeChat")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit_button = gr.Button("Submit")
    clear = gr.Button("Clear")
    img_output = gr.Image(label="Generated Image")

    # Enter key and Submit button both route through chat_with_ai; outputs
    # clear the textbox, update the chat history, and show any generated image
    msg.submit(chat_with_ai, [msg, chatbot], [msg, chatbot, img_output])
    submit_button.click(chat_with_ai, [msg, chatbot], [msg, chatbot, img_output])

    # Reset the chat history
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch the demo
if __name__ == "__main__":
    demo.launch(debug=True, ssr_mode=False)