johann22 commited on
Commit
9491afb
β€’
1 Parent(s): 3652d2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from huggingface_hub import InferenceClient
4
+ import gradio as gr
5
+ from utils import parse_action, parse_file_content, read_python_module_structure
6
+ from datetime import datetime
7
+ import agent
8
+ from models import models
9
+ import urllib.request
10
+ import uuid
11
+ base_url="https://johann22-chat-diffusion.hf.space/"
12
+ loaded_model=[]
13
+ for i,model in enumerate(models):
14
+ loaded_model.append(gr.load(f'models/{model}'))
15
+ print (loaded_model)
16
+
17
+ now = datetime.now()
18
+ date_time_str = now.strftime("%Y-%m-%d %H:%M:%S")
19
+
20
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
21
+ #model = gr.load("models/stabilityai/sdxl-turbo")
22
+ history = []
23
+
24
+ def infer(txt):
25
+ return (model(txt))
26
+
27
+ def format_prompt(message, history):
28
+ prompt = "<s>"
29
+ for user_prompt, bot_response in history:
30
+ prompt += f"[INST] {user_prompt} [/INST]"
31
+ prompt += f" {bot_response}</s> "
32
+ prompt += f"[INST] {message} [/INST]"
33
+ return prompt
34
+
35
+ def run_gpt(in_prompt,history):
36
+ prompt=format_prompt(in_prompt,history)
37
+ seed = random.randint(1,1111111111111111)
38
+ print (seed)
39
+ generate_kwargs = dict(
40
+ temperature=1.0,
41
+ max_new_tokens=1048,
42
+ top_p=0.99,
43
+ repetition_penalty=1.0,
44
+ do_sample=True,
45
+ seed=seed,
46
+ )
47
+ content = agent.GENERATE_PROMPT + prompt
48
+ #print(content)
49
+ stream = client.text_generation(content, **generate_kwargs, stream=True, details=True, return_full_text=False)
50
+ resp = ""
51
+ for response in stream:
52
+ resp += response.token.text
53
+ return resp
54
+
55
+
56
+ def run(purpose,history,model_drop):
57
+ if history:
58
+ history=str(history).strip("[]")
59
+ if not history:
60
+ history = ""
61
+ out_prompt = run_gpt(purpose,history)
62
+ yield ("",[(purpose,out_prompt)],None)
63
+ model=loaded_model[int(model_drop)]
64
+ out_img=model(out_prompt)
65
+ print(out_img)
66
+ image=f'{base_url}file={out_img}'
67
+ uid = uuid.uuid4()
68
+ urllib.request.urlretrieve(image, f'{uid}.png')
69
+ return ("",[(purpose,out_prompt)],f'{uid}.png')
70
+ #return ("", [(purpose,history)])
71
+
72
+
73
+
74
+ ################################################
75
+
76
+ with gr.Blocks() as iface:
77
+ gr.HTML("""<center><h1>Chat Diffusion</h1><br><h3>This chatbot will generate images</h3></center>""")
78
+ with gr.Row():
79
+ with gr.Column():
80
+ chatbot=gr.Chatbot()
81
+ msg = gr.Textbox()
82
+ model_drop=gr.Dropdown(label="Diffusion Models", type="index", choices=[m for m in models], value=models[0])
83
+ with gr.Row():
84
+ submit_b = gr.Button()
85
+ stop_b = gr.Button("Stop")
86
+ clear = gr.ClearButton([msg, chatbot])
87
+
88
+ sumbox=gr.Image(label="Image",type="filepath")
89
+
90
+
91
+ sub_b = submit_b.click(run, [msg,chatbot,model_drop],[msg,chatbot,sumbox])
92
+ sub_e = msg.submit(run, [msg, chatbot,model_drop], [msg, chatbot,sumbox])
93
+ stop_b.click(None,None,None, cancels=[sub_b,sub_e])
94
+ iface.launch()