# ProteinGenesis / app.py
# Hugging Face Space file page header (scraped): author aiqcamp,
# commit b122ddc "Update app.py" (verified), 37.9 kB.
import os,sys
from openai import OpenAI
import gradio as gr
# Install required packages at startup.
# Installing through `sys.executable -m pip` (instead of whichever `pip` is
# first on PATH) guarantees the packages land in the interpreter running this
# app. Failures are ignored, matching the original best-effort os.system calls.
import subprocess

_PIP_INSTALL_ARGS = (
    ['-q', 'plotly'],
    ['-q', 'matplotlib'],
    ['dgl==1.0.2+cu116', '-f', 'https://data.dgl.ai/wheels/cu116/repo.html'],
)
for _pip_args in _PIP_INSTALL_ARGS:
    subprocess.run([sys.executable, '-m', 'pip', 'install', *_pip_args], check=False)

# Selects DGL's PyTorch backend; set before the project imports below, which
# presumably import dgl (TODO confirm).
os.environ["DGLBACKEND"] = "pytorch"
print('Modules installed')
# Required library imports
from datasets import load_dataset
import plotly.graph_objects as go
import numpy as np
import py3Dmol
from io import StringIO
import json
import secrets
import copy
import matplotlib.pyplot as plt
from utils.sampler import HuggingFace_sampler
from utils.parsers_inference import parse_pdb
from model.util import writepdb
from utils.inpainting_util import *
# Hugging Face access token: used both for the inference endpoint and for
# downloading the dataset. The app refuses to start without it.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")
if not ACCESS_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")

# OpenAI-compatible client pointed at the Hugging Face inference endpoint.
client = OpenAI(
    api_key=ACCESS_TOKEN,
    base_url="https://api-inference.huggingface.co/v1/",
)

# Protein secondary-structure dataset searched by search_protein_data().
ds = load_dataset(
    "lamm-mit/protein_secondary_structure_from_PDB",
    token=ACCESS_TOKEN,
)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for the Gradio ChatInterface.

    Args:
        message: The latest user message.
        history: Prior (user, assistant) turns; empty halves are skipped.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The assistant reply accumulated so far, after each streamed chunk
        that carries text.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    stream = client.chat.completions.create(
        model="CohereForAI/c4ai-command-r-plus-08-2024",
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
        messages=messages,
    )
    # Named `chunk` so it no longer shadows the `message` parameter.
    for chunk in stream:
        # Streaming chunks may have no choices, or a delta whose content is
        # None (e.g. role-only or final chunks); `response += None` would
        # raise TypeError, so such chunks are skipped.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response
# Chatbot and protein-generation helper functions.
def process_chat(message, history):
    """Answer a chat message; for protein requests, append a generation report.

    `history` is accepted but not used.

    NOTE(review): `pipe` and `generate_protein` are not defined anywhere in the
    visible source, so calling this raises NameError as written — confirm
    where they are meant to come from.
    """
    reply = pipe([{"role": "user", "content": message}])[0]['generated_text']
    lowered = message.lower()
    triggers = ('protein', 'generate', '๋‹จ๋ฐฑ์งˆ', '์ƒ์„ฑ')
    if not any(word in lowered for word in triggers):
        return reply
    matches = search_protein_data(message)
    params = extract_parameters(reply, matches)
    outcome = generate_protein(params)
    report = generate_explanation(outcome, params)
    return reply + "\n\n" + report
def search_protein_data(query, entries=None):
    """Return dataset entries whose sequence contains any word of `query`.

    Matching is a case-insensitive substring test of each whitespace-separated
    query word against ``entry['sequence']``.

    Args:
        query: Free-text query string.
        entries: Iterable of rows with a ``'sequence'`` key. Defaults to
            ``ds['train']``, the module-level dataset.

    Returns:
        List of matching entries in iteration order; empty when `query`
        contains no words.
    """
    # Hoisted out of the loop: the keyword list is the same for every entry.
    keywords = query.lower().split()
    if not keywords:
        # No keyword can match anything, so skip scanning the whole dataset.
        return []
    if entries is None:
        entries = ds['train']
    matches = []
    for entry in entries:
        sequence = entry['sequence'].lower()
        if any(keyword in sequence for keyword in keywords):
            matches.append(entry)
    return matches
def extract_parameters(llm_response, dataset_info):
    """Return the default protein-generation parameters.

    Both arguments are currently ignored; every call returns a fresh dict
    holding the same fixed values.
    """
    return dict(
        sequence_length=100,
        helix_bias=0.02,
        strand_bias=0.02,
        loop_bias=0.1,
        hydrophobic_target_score=0,
    )
def generate_explanation(result, params):
    """Build the (Korean) text summary of a generated protein.

    Args:
        result: Generation result. If it has a ``get`` method, its
            ``'special_features'`` entry is shown; any other value (including
            None) falls back to the default "none" text.
        params: Mapping with ``sequence_length``, ``helix_bias``,
            ``strand_bias`` and ``loop_bias``; the biases are fractions that
            are displayed as percentages.

    Returns:
        The multi-line explanation string.
    """
    def as_percent(fraction):
        # Round to 2 decimals so float artifacts (0.07 * 100 ==
        # 7.000000000000001) don't leak into the UI; clean values such as
        # 2.0 or 10.0 render exactly as before.
        return round(fraction * 100, 2)

    default_feature = '์—†์Œ'
    if hasattr(result, 'get'):
        special = result.get('special_features', default_feature)
    else:
        special = default_feature

    explanation = f"""
์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ ๋ถ„์„:
- ๊ธธ์ด: {params['sequence_length']} ์•„๋ฏธ๋…ธ์‚ฐ
- ๊ตฌ์กฐ์  ํŠน์ง•:
* ์•ŒํŒŒ ๋‚˜์„  ๋น„์œจ: {as_percent(params['helix_bias'])}%
* ๋ฒ ํƒ€ ์‹œํŠธ ๋น„์œจ: {as_percent(params['strand_bias'])}%
* ๋ฃจํ”„ ๊ตฌ์กฐ ๋น„์œจ: {as_percent(params['loop_bias'])}%
- ํŠน์ˆ˜ ๊ธฐ๋Šฅ: {special}
"""
    return explanation
def _fit_secondary_structure(secondary_structure, length):
    """Normalise a user secondary-structure string and fit it to `length`.

    'S' (strand) is rewritten to 'E'. The string is truncated when longer than
    `length` and padded by repeating its last character when shorter. Empty
    input ('' or None) yields None. When `length` is None (target length not
    known here, e.g. contig mode) the string is returned without trimming or
    padding.
    """
    if secondary_structure in ['', None]:
        return None
    ss = secondary_structure.replace('S', 'E')
    if length is None:
        return ss
    if len(ss) > length:
        return ss[:length]
    return ss + ss[-1] * (length - len(ss))


def _noise_gmm_params(noise):
    """Return (means, variances) for the noise distribution named by `noise`.

    Matched by substring in the order 'normal', 'gmm2', 'gmm3' (so UI labels
    such as 'gmm2 [-1,1]' resolve). Returns None for empty/unknown names, in
    which case the sampler's existing distribution settings are left alone.
    """
    if not noise:
        return None
    if 'normal' in noise:
        return [0], [1]
    if 'gmm2' in noise:
        return [-1, 1], [1, 1]
    if 'gmm3' in noise:
        return [-1, 0, 1], [1, 1, 1]
    return None


def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
                            secondary_structure, aa_bias, aa_bias_potential,
                            num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
                            contigs, pssm, seq_mask, str_mask, rewrite_pdb):
    """Run sequence diffusion with HuggingFace_sampler and stream the results.

    Empty inputs ('' or None) disable the corresponding option. Length comes
    from `sequence` when given, else from `contigs`, else from `seq_len`.

    Yields:
        (output_seq, output_pdb, viewer_html, plddt_figure) after every
        diffusion step, then once more with the sampler's final outputs
        (S.get_outputs()).
    """
    dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
    og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'

    # NOTE(review): `args` is not defined anywhere in the visible source; it is
    # expected to be a module-level sampler configuration — confirm.
    model_args = copy.deepcopy(args)
    S = HuggingFace_sampler(args=model_args)
    # Random output prefix so concurrent runs don't overwrite each other's files.
    S.out_prefix = './tmp/' + secrets.token_hex(nbytes=10).upper()

    # Baseline arguments; several are overridden from the user inputs below.
    S.args['checkpoint'] = None
    S.args['dump_trb'] = False
    S.args['dump_args'] = True
    S.args['save_best_plddt'] = True
    S.args['T'] = 20
    S.args['strand_bias'] = 0.0
    S.args['loop_bias'] = 0.0
    S.args['helix_bias'] = 0.0
    S.args['potentials'] = None
    S.args['potential_scale'] = None
    S.args['aa_composition'] = None

    # Ambiguous amino-acid codes are resolved to a random concrete residue.
    alt_aa_dict = {'B': ['D', 'N'], 'J': ['I', 'L'], 'U': ['C'], 'Z': ['E', 'Q'], 'O': ['K']}
    # Target length; stays None in contig mode, where it is not computed here.
    L = None
    if sequence not in ['', None]:
        L = len(sequence)
        S.args['sequence'] = [
            np.random.choice(alt_aa_dict[aa]) if aa in alt_aa_dict else aa
            for aa in sequence.upper()
        ]
    elif contigs not in ['', None]:
        S.args['contigs'] = [contigs]
    else:
        L = int(seq_len)
        # Use the integer length so a float slider value (e.g. 100.0) doesn't
        # produce a contig string like '100.0'.
        S.args['contigs'] = [f'{L}']

    print('DEBUG: ', rewrite_pdb)
    if rewrite_pdb not in ['', None]:
        S.args['pdb'] = rewrite_pdb.name
    if seq_mask not in ['', None]:
        S.args['inpaint_seq'] = [seq_mask]
    if str_mask not in ['', None]:
        S.args['inpaint_str'] = [str_mask]

    # Previously this truncated to len(sequence), which fails when only
    # seq_len is given, and referenced an unbound L in contig mode.
    secondary_structure = _fit_secondary_structure(secondary_structure, L)

    # Guiding potentials and their scales, kept in matching order.
    potential_list = []
    potential_bias_list = []
    if aa_bias not in ['', None]:
        potential_list.append('aa_bias')
        S.args['aa_composition'] = aa_bias
        if aa_bias_potential in ['', None]:
            aa_bias_potential = 3
        potential_bias_list.append(str(aa_bias_potential))
    # Charge / pH potentials are not exposed by this UI (that code was disabled).
    if hydrophobic_target_score not in ['', None]:
        potential_list.append('hydrophobic')
        S.args['hydrophobic_score'] = float(hydrophobic_target_score)
        if hydrophobic_potential in ['', None]:
            hydrophobic_potential = 3
        potential_bias_list.append(str(hydrophobic_potential))
    if pssm not in ['', None]:
        potential_list.append('PSSM')
        potential_bias_list.append('5')
        S.args['PSSM'] = pssm.name
    if potential_list:
        S.args['potentials'] = ','.join(potential_list)
        S.args['potential_scale'] = ','.join(potential_bias_list)

    # Secondary-structure string and per-class biases are passed through as given.
    S.args['secondary_structure'] = secondary_structure
    S.args['helix_bias'] = helix_bias
    S.args['strand_bias'] = strand_bias
    S.args['loop_bias'] = loop_bias

    S.args['T'] = 20 if num_steps in ['', None] else int(num_steps)

    gmm = _noise_gmm_params(noise)
    if gmm is not None:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = gmm[0]
        S.args['sample_distribution_gmm_variances'] = gmm[1]

    # Any secondary-structure conditioning needs the DSSP-conditioned checkpoint.
    if secondary_structure not in ['', None] or helix_bias + strand_bias + loop_bias > 0:
        S.args['checkpoint'] = dssp_checkpoint
        S.args['d_t1d'] = 29
        print('using dssp checkpoint')
    else:
        S.args['checkpoint'] = og_checkpoint
        S.args['d_t1d'] = 24
        print('using og checkpoint')

    for key, value in S.args.items():
        print(f"{key} --> {value}")

    S.model_init()
    S.diffuser_init()
    S.setup()

    plddt_data = []
    for step in range(S.max_t):
        print(f'on step {step}')
        output_seq, output_pdb, plddt = S.take_step_get_outputs(step)
        plddt_data.append(plddt)
        yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)

    # This used to be `return ...`; a generator's return value is dropped by a
    # plain for-loop consumer (as in combined_generation), so the final
    # outputs were never shown. Yield them as the last item instead.
    output_seq, output_pdb, plddt = S.get_outputs()
    yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
def get_plddt_plot(plddt_data, max_t):
    """Plot the per-step pLDDT confidence of the diffusion trajectory.

    Args:
        plddt_data: pLDDT values recorded so far, one per completed step.
        max_t: Total number of diffusion steps; sets the x-axis ticks 1..max_t.

    Returns:
        A matplotlib Figure, already detached from pyplot's figure manager.
    """
    steps = list(range(1, len(plddt_data) + 1))
    fig, ax = plt.subplots(figsize=(15, 6))
    ax.plot(steps, plddt_data, color='#661dbf', linewidth=3, marker='o')
    ax.set_xticks(list(range(1, max_t + 1)))
    ax.set_yticks([(i + 1) / 10 for i in range(10)])
    ax.set_ylim([0, 1])
    ax.set_ylabel('model confidence (plddt)')
    ax.set_xlabel('diffusion steps (t)')
    # Called once per diffusion step: without closing, pyplot keeps every
    # figure alive (unbounded memory growth). A closed figure can still be
    # rendered/saved by the caller.
    plt.close(fig)
    return fig
def display_pdb(path_to_pdb):
    """Render a PDB file as an interactive py3Dmol cartoon inside an <iframe>.

    Residues are coloured by the B-factor column with a 'roygb' gradient over
    [0, 1] (presumably the sampler stores pLDDT there — TODO confirm).

    Args:
        path_to_pdb: Path of the PDB file to display.

    Returns:
        HTML string of an iframe whose ``srcdoc`` holds the viewer page.
    """
    with open(path_to_pdb, "r") as handle:
        pdb = handle.read()
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'model': -1}, {"cartoon": {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0, 'max': 1}}})
    view.zoomTo()
    html = view._make_html()
    print(html)
    # srcdoc is delimited by single quotes below, so none may remain in the page.
    output = html.replace("'", '"')
    x = f"""<!DOCTYPE html><html></center> {output} </center></html>"""  # do not use ' in this input
    return (
        '<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;\n'
        'display-capture; encrypted-media;" sandbox="allow-modals allow-forms\n'
        'allow-scripts allow-same-origin allow-popups\n'
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""\n'
        f"allowpaymentrequest=\"\" frameborder=\"0\" srcdoc='{x}'></iframe>"
    )
def get_motif_preview(pdb_id, contigs):
    """Preview a PDB structure with the contig-selected motif residues highlighted.

    Args:
        pdb_id: PDB accession code (case-insensitive), downloaded via fetch_pdb().
        contigs: Contig string selecting the motif; '', 0 or None mean '0'.

    Returns:
        (iframe_html, pdb_path) on success; (gr.HTML message, None) when no ID
        is given or when anything fails.
    """
    try:
        input_pdb = fetch_pdb(pdb_id=pdb_id.lower() if pdb_id else None)
        if input_pdb is None:
            return gr.HTML("PDB ID๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”"), None
        parse = parse_pdb(input_pdb)
        output_name = input_pdb
        with open(output_name, "r") as handle:
            pdb = handle.read()
        view = py3Dmol.view(width=500, height=500)
        view.addModel(pdb, "pdb")
        # None (e.g. a cleared textbox) is treated like '' instead of
        # reaching ContigMap as [None].
        contigs = ['0'] if contigs in ['', 0, None] else [contigs]
        print('DEBUG: ', contigs)
        pdb_map = get_mappings(ContigMap(parse, contigs))
        print('DEBUG: ', pdb_map)
        print('DEBUG: ', pdb_map['con_ref_idx0'])
        # Motif residues as 0-based positions (residue number minus one).
        # NOTE(review): assumes residue numbering starts at 1 and is contiguous.
        roi = {x[1] - 1 for x in pdb_map['con_ref_pdb_idx']}
        colormap = {0: '#D3D3D3', 1: '#F74CFF'}
        colors = {i + 1: colormap[1] if i in roi else colormap[0] for i in range(parse['xyz'].shape[0])}
        view.setStyle({"cartoon": {"colorscheme": {"prop": "resi", "map": colors}}})
        view.zoomTo()
        html = view._make_html()
        print(html)
        # srcdoc is delimited by single quotes below, so none may remain in the page.
        output = html.replace("'", '"')
        x = f"""<!DOCTYPE html><html></center> {output} </center></html>"""  # do not use ' in this input
        iframe = (
            '<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;\n'
            'display-capture; encrypted-media;" sandbox="allow-modals allow-forms\n'
            'allow-scripts allow-same-origin allow-popups\n'
            'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""\n'
            f"allowpaymentrequest=\"\" frameborder=\"0\" srcdoc='{x}'></iframe>"
        )
        return iframe, output_name
    except Exception as e:
        return gr.HTML(f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
def fetch_pdb(pdb_id=None):
    """Download ``{pdb_id}.pdb`` from RCSB into the working directory.

    Args:
        pdb_id: PDB accession code. None or '' means "no structure".

    Returns:
        The local filename ``f"{pdb_id}.pdb"``, or None when no ID was given.
        As before, download failures are not reported here; they surface when
        the caller opens the file.

    Raises:
        ValueError: if `pdb_id` contains anything other than letters, digits
            or '_'. The ID comes straight from a UI textbox and was previously
            interpolated into a shell command (command injection).
    """
    import re
    import subprocess

    if pdb_id is None or pdb_id == "":
        return None
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_id):
        raise ValueError(f"Invalid PDB ID: {pdb_id!r}")
    url = f"https://files.rcsb.org/view/{pdb_id}.pdb"
    try:
        # Argument list, no shell: the ID can never be interpreted by a shell.
        subprocess.run(["wget", "-qnc", url], check=False)
    except OSError as err:
        # e.g. wget not installed; keep the original best-effort behaviour.
        print(f"fetch_pdb: could not run wget for {url}: {err}")
    return f"{pdb_id}.pdb"
# MSA AND PSSM GUIDANCE
def save_pssm(file_upload):
    """Return a PSSM CSV path for an uploaded MSA or PSSM file.

    FASTA / A3M uploads (extension compared case-insensitively) are converted
    with msa_to_pssm(); any other upload is assumed to already be a PSSM CSV
    and its path is returned unchanged.

    Args:
        file_upload: Uploaded file object exposing a ``.name`` path. (The
            previously read-but-unused ``.orig_name`` is no longer accessed,
            since not every upload object is guaranteed to provide it.)
    """
    filename = file_upload.name
    extension = filename.rsplit('.', 1)[-1].lower()
    if extension in ('fasta', 'a3m'):
        return msa_to_pssm(file_upload)
    return filename
def msa_to_pssm(msa_file):
    """Convert an uploaded FASTA/A3M MSA into a per-position frequency CSV.

    The first record is the query. Every other record is pairwise-aligned to
    it (Biopython PairwiseAligner, gap open -0.7 / extend -0.3). Alignments
    whose aligned-record row is more than 40% gaps are dropped. Columns where
    the query has a gap are removed, so kept rows are in query coordinates.
    Counts are divided by the number of kept sequences (query included).

    Args:
        msa_file: Uploaded file object exposing a ``.name`` path.

    Returns:
        Path of the CSV written next to the input (same stem, ``.csv``): one
        row per query position and 21 columns (20 amino acids, then 'X',
        which is zeroed).

    Raises:
        ValueError: if the file contains no sequences.
    """
    # Row index of each residue symbol in the count matrix.
    aa_to_index = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10,
                   'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20, '-': 21}

    records = list(SeqIO.parse(msa_file.name, "fasta"))
    # Explicit check instead of `assert` (stripped under -O); the old message
    # ("more than one") also contradicted the >= 1 condition.
    if not records:
        raise ValueError("MSA must contain at least one protein sequence.")
    query = str(records[0].seq)
    aligned_seqs = [query]

    aligner = Align.PairwiseAligner()
    aligner.open_gap_score = -0.7
    aligner.extend_gap_score = -0.3
    for record in records[1:]:
        # NOTE(review): relies on format() putting the query row on line 0 and
        # the target row on line 2; newer Biopython versions format alignments
        # differently — confirm the installed version.
        lines = aligner.align(query, str(record.seq))[0].format().split("\n")
        query_row, target_row = lines[0], lines[2]
        if target_row.count('-') / len(target_row) > 0.4:
            continue
        # Keep only the columns where the query has a residue.
        aligned_seqs.append(''.join(t for q, t in zip(query_row, target_row) if q != '-'))

    length = len(query)
    matrix = np.zeros((22, length))
    for seq in aligned_seqs:
        for position, residue in enumerate(seq[:length]):
            row = aa_to_index.get(residue.upper())
            if row is not None:
                matrix[row, position] += 1

    # Counts -> frequencies.
    matrix /= len(aligned_seqs)
    print(len(aligned_seqs))
    # Zero the 'X' and gap rows; only the first 21 rows ('X' included) are saved.
    matrix[20:, ] = 0
    outdir = ".".join(msa_file.name.split('.')[:-1]) + ".csv"
    np.savetxt(outdir, matrix[:21, :].T, delimiter=",")
    return outdir
def get_pssm(fasta_msa, input_pssm):
    """Load (or build from an MSA) a PSSM and return a heatmap plus its CSV path.

    `input_pssm` (a ready CSV) takes precedence over `fasta_msa`, which is
    routed through save_pssm().

    Returns:
        (figure, csv_path) on success; (gr.Plot placeholder, None) when nothing
        was uploaded or when loading/plotting fails.
    """
    try:
        if input_pssm is not None:
            outdir = input_pssm.name
        elif fasta_msa is not None:
            outdir = save_pssm(fasta_msa)
        else:
            return gr.Plot(label="ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”"), None
        # atleast_2d: a single-row CSV loads as a 1-D array, which cannot be
        # transposed into an image.
        pssm = np.atleast_2d(np.loadtxt(outdir, delimiter=",", dtype=float))
        fig, ax = plt.subplots(figsize=(15, 6))
        # NumPy transpose: same image as torch.permute(..., (1, 0)) without
        # depending on `torch` being in scope.
        ax.imshow(pssm.T)
        # Detach from pyplot so repeated clicks don't accumulate open figures.
        plt.close(fig)
        return fig, outdir
    except Exception as e:
        return gr.Plot(label=f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
# Hero stat calculation.
def calculate_hero_stats(helix_bias, strand_bias, loop_bias, hydrophobic_score, defense_scale=10.0):
    """Map structure settings onto hero stats for the radar chart.

    The scales mirror the UI slider ranges so that every stat spans [0, 1],
    matching the chart's fixed radial axis: strand/helix bias (0-0.05) x 20,
    loop bias (0-0.2) x 5, and |hydrophobic_score| (slider -10..10) divided by
    `defense_scale`.

    Args:
        helix_bias: Alpha-helix bias; drives 'flexibility'.
        strand_bias: Beta-strand bias; drives 'strength'.
        loop_bias: Loop bias; drives 'speed'.
        hydrophobic_score: Number or numeric string; falsy means defense 0.
        defense_scale: Magnitude mapped to a full defense stat (default 10,
            the slider's limit).

    Returns:
        Dict with keys 'strength', 'flexibility', 'speed', 'defense'.
    """
    if hydrophobic_score:
        defense = abs(float(hydrophobic_score)) / defense_scale
    else:
        defense = 0
    return {
        'strength': strand_bias * 20,  # beta-sheet structure based
        'flexibility': helix_bias * 20,  # alpha-helix structure based
        'speed': loop_bias * 5,  # loop structure based
        'defense': defense,
    }
def toggle_seq_input(choice):
    """Switch between the length slider (auto design) and the sequence textbox.

    Returns:
        Visibility updates for (seq_len, sequence): the slider is visible only
        for the auto-design choice, the textbox for every other choice.
    """
    auto_design = choice == "์ž๋™ ์„ค๊ณ„"
    return gr.update(visible=auto_design), gr.update(visible=not auto_design)
def toggle_secondary_structure(choice):
    """Switch between the three bias sliders and the free-text structure box.

    Returns:
        Visibility updates for (helix_bias, strand_bias, loop_bias,
        secondary_structure): sliders shown for the slider choice, the
        textbox for every other choice.
    """
    use_sliders = choice == "์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •"
    # One fresh update per slider: helix_bias, strand_bias, loop_bias.
    slider_updates = tuple(gr.update(visible=use_sliders) for _ in range(3))
    return (*slider_updates, gr.update(visible=not use_sliders))
def create_radar_chart(stats):
    """Draw hero stats as a filled polar (radar) chart.

    Args:
        stats: Mapping of stat name -> value; names become the angular axis.

    Returns:
        A plotly Figure with a fixed [0, 1] radial axis and no legend.
    """
    trace = go.Scatterpolar(
        r=list(stats.values()),
        theta=list(stats.keys()),
        fill='toself',
    )
    fig = go.Figure(data=trace)
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        showlegend=False,
    )
    return fig
def generate_hero_description(name, stats, abilities):
    """Build the hero profile text (Korean).

    Args:
        name: Hero name.
        stats: Mapping with 'strength', 'flexibility', 'speed' and 'defense',
            expected in [0, 1]; each is shown as 5 stars per 1.0.
        abilities: Iterable of selected special-ability labels; None means none.

    Returns:
        The multi-line description string.
    """
    abilities = abilities or []

    def stars(key):
        # Clamp to the 0-5 star scale so out-of-range stats can't print
        # dozens of stars.
        return 'โ˜…' * max(0, min(5, int(stats[key] * 5)))

    description = f"""
ํžˆ์–ด๋กœ ์ด๋ฆ„: {name}
์ฃผ์š” ๋Šฅ๋ ฅ:
- ๊ทผ๋ ฅ: {stars('strength')}
- ์œ ์—ฐ์„ฑ: {stars('flexibility')}
- ์Šคํ”ผ๋“œ: {stars('speed')}
- ๋ฐฉ์–ด๋ ฅ: {stars('defense')}
ํŠน์ˆ˜ ๋Šฅ๋ ฅ: {', '.join(abilities)}
"""
    return description
def combined_generation(name, strength, flexibility, speed, defense, size, abilities,
                        sequence, seq_len, helix_bias, strand_bias, loop_bias,
                        secondary_structure, aa_bias, aa_bias_potential,
                        num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
                        contigs, pssm, seq_mask, str_mask, rewrite_pdb):
    """Generate a protein from the hero sliders and build the hero profile.

    Only the hero inputs drive generation: size -> seq_len,
    flexibility -> helix_bias, strength -> strand_bias, speed -> loop_bias,
    -defense -> hydrophobic target score (25 steps, normal noise).

    NOTE(review): the advanced-design parameters (sequence, seq_len,
    helix_bias, ..., rewrite_pdb) are accepted but ignored, so every button
    wired to this function produces a design from the hero tab only —
    confirm that is intended.

    Returns:
        (radar_chart, description, output_seq, output_pdb, viewer_html,
        plddt_plot); on failure (None, error text, None, None,
        gr.HTML error, None).
    """
    try:
        generator = protein_diffusion_model(
            sequence=None,
            seq_len=size,  # hero size used as the sequence length
            helix_bias=flexibility,  # flexibility used as helix_bias
            strand_bias=strength,  # strength used as strand_bias
            loop_bias=speed,  # speed used as loop_bias
            secondary_structure=None,
            aa_bias=None,
            aa_bias_potential=None,
            num_steps="25",
            noise="normal",
            hydrophobic_target_score=str(-defense),  # defense used as the hydrophobic score
            hydrophobic_potential="2",
            contigs=None,
            pssm=None,
            seq_mask=None,
            str_mask=None,
            rewrite_pdb=None
        )

        # Drain the generator, keeping the last result. A generator's
        # `return value` is only visible through StopIteration.value (a plain
        # for-loop drops it), so prefer that value when one is provided.
        final_result = None
        while True:
            try:
                final_result = next(generator)
            except StopIteration as stop:
                if stop.value is not None:
                    final_result = stop.value
                break

        if final_result is None:
            raise Exception("์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")

        output_seq, output_pdb, structure_view, plddt_plot = final_result

        stats = calculate_hero_stats(flexibility, strength, speed, defense)

        return (
            create_radar_chart(stats),  # stat chart
            generate_hero_description(name, stats, abilities),  # hero description
            output_seq,  # protein sequence
            output_pdb,  # PDB file
            structure_view,  # 3D structure viewer
            plddt_plot  # confidence plot
        )
    except Exception as e:
        print(f"Error in combined_generation: {str(e)}")
        return (
            None,
            f"์—๋Ÿฌ: {str(e)}",
            None,
            None,
            gr.HTML("์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค"),
            None
        )
# Build the Gradio UI: a left column (chatbot + design tabs), a right column
# (hero profile and protein results), then the event wiring.
# NOTE(review): indentation was reconstructed from a whitespace-stripped copy;
# confirm the nesting of tab/accordion contents against the deployed file.
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row():
        # Left column: chatbot and control panel.
        with gr.Column(scale=1):
            # Chatbot interface.
            gr.Markdown("# ๐Ÿค– AI ๋‹จ๋ฐฑ์งˆ ์„ค๊ณ„ ๋„์šฐ๋ฏธ")
            chatbot = gr.Chatbot(height=600)
            # Chat settings (passed to respond() as additional inputs below).
            with gr.Accordion("์ฑ„ํŒ… ์„ค์ •", open=False):
                system_message = gr.Textbox(
                    value="๋‹น์‹ ์€ ๋‹จ๋ฐฑ์งˆ ์„ค๊ณ„๋ฅผ ๋„์™€์ฃผ๋Š” ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค.",
                    label="์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€"
                )
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=2048,
                    value=512,
                    step=1,
                    label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=4.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-P"
                )
            # Tab interface.
            with gr.Tabs():
                # Tab 1: hero design (the inputs combined_generation actually uses).
                with gr.TabItem("๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ๋””์ž์ธ"):
                    gr.Markdown("""
### โœจ ๋‹น์‹ ๋งŒ์˜ ํŠน๋ณ„ํ•œ ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ค์–ด๋ณด์„ธ์š”!
๊ฐ ๋Šฅ๋ ฅ์น˜๋ฅผ ์กฐ์ ˆํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA๊ฐ€ ์ž๋™์œผ๋กœ ์„ค๊ณ„๋ฉ๋‹ˆ๋‹ค.
""")
                    # Basic hero info.
                    hero_name = gr.Textbox(
                        label="ํžˆ์–ด๋กœ ์ด๋ฆ„",
                        placeholder="๋‹น์‹ ์˜ ํžˆ์–ด๋กœ ์ด๋ฆ„์„ ์ง€์–ด์ฃผ์„ธ์š”!",
                        info="ํžˆ์–ด๋กœ์˜ ์ •์ฒด์„ฑ์„ ๋‚˜ํƒ€๋‚ด๋Š” ์ด๋ฆ„์„ ์ž…๋ ฅํ•˜์„ธ์š”"
                    )
                    # Stat sliders.
                    gr.Markdown("### ๐Ÿ’ช ํžˆ์–ด๋กœ ๋Šฅ๋ ฅ์น˜ ์„ค์ •")
                    with gr.Row():
                        strength = gr.Slider(
                            minimum=0.0, maximum=0.05,
                            label="๐Ÿ’ช ์ดˆ๊ฐ•๋ ฅ(๊ทผ๋ ฅ)",
                            value=0.02,
                            info="๋‹จ๋‹จํ•œ ๋ฒ ํƒ€์‹œํŠธ ๊ตฌ์กฐ๋กœ ๊ฐ•๋ ฅํ•œ ํž˜์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค"
                        )
                        flexibility = gr.Slider(
                            minimum=0.0, maximum=0.05,
                            label="๐Ÿคธโ€โ™‚๏ธ ์œ ์—ฐ์„ฑ",
                            value=0.02,
                            info="๋‚˜์„ ํ˜• ์•ŒํŒŒํ—ฌ๋ฆญ์Šค ๊ตฌ์กฐ๋กœ ์œ ์—ฐํ•œ ์›€์ง์ž„์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค"
                        )
                    with gr.Row():
                        speed = gr.Slider(
                            minimum=0.0, maximum=0.20,
                            label="โšก ์Šคํ”ผ๋“œ",
                            value=0.1,
                            info="๋ฃจํ”„ ๊ตฌ์กฐ๋กœ ๋น ๋ฅธ ์›€์ง์ž„์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค"
                        )
                        defense = gr.Slider(
                            minimum=-10, maximum=10,
                            label="๐Ÿ›ก๏ธ ๋ฐฉ์–ด๋ ฅ",
                            value=0,
                            info="์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”"
                        )
                    # Hero size (used as the sequence length by combined_generation).
                    hero_size = gr.Slider(
                        minimum=50, maximum=200,
                        label="๐Ÿ“ ํžˆ์–ด๋กœ ํฌ๊ธฐ",
                        value=100,
                        info="ํžˆ์–ด๋กœ์˜ ์ „์ฒด์ ์ธ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค"
                    )
                    # Special abilities (only echoed in the hero description).
                    with gr.Accordion("๐ŸŒŸ ํŠน์ˆ˜ ๋Šฅ๋ ฅ", open=False):
                        gr.Markdown("""
ํŠน์ˆ˜ ๋Šฅ๋ ฅ์„ ์„ ํƒํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA์— ํŠน๋ณ„ํ•œ ๊ตฌ์กฐ๊ฐ€ ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค.
- ์ž๊ฐ€ ํšŒ๋ณต: ๋‹จ๋ฐฑ์งˆ ๊ตฌ์กฐ ๋ณต๊ตฌ ๋Šฅ๋ ฅ ๊ฐ•ํ™”
- ์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ: ํŠน์ˆ˜ํ•œ ๊ตฌ์กฐ์  ๋Œ์ถœ๋ถ€ ํ˜•์„ฑ
- ๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ: ์•ˆ์ •์ ์ธ ๋ณดํ˜ธ์ธต ๊ตฌ์กฐ ์ƒ์„ฑ
""")
                        special_ability = gr.CheckboxGroup(
                            choices=["์ž๊ฐ€ ํšŒ๋ณต", "์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ", "๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ"],
                            label="ํŠน์ˆ˜ ๋Šฅ๋ ฅ ์„ ํƒ"
                        )
                    # Generate button.
                    create_btn = gr.Button("๐Ÿงฌ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)
                # Tab 2: advanced sequence / structure design parameters.
                with gr.TabItem("๐Ÿงฌ ํžˆ์–ด๋กœ DNA ์„ค๊ณ„"):
                    gr.Markdown("""
### ๐Ÿงช ํžˆ์–ด๋กœ DNA ๊ณ ๊ธ‰ ์„ค์ •
ํžˆ์–ด๋กœ์˜ ์œ ์ „์ž ๊ตฌ์กฐ๋ฅผ ๋” ์„ธ๋ฐ€ํ•˜๊ฒŒ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
""")
                    # Toggles between seq_len (auto) and sequence (manual) via toggle_seq_input.
                    seq_opt = gr.Radio(
                        ["์ž๋™ ์„ค๊ณ„", "์ง์ ‘ ์ž…๋ ฅ"],
                        label="DNA ์„ค๊ณ„ ๋ฐฉ์‹",
                        value="์ž๋™ ์„ค๊ณ„"
                    )
                    sequence = gr.Textbox(
                        label="DNA ์‹œํ€€์Šค",
                        lines=1,
                        placeholder='์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์•„๋ฏธ๋…ธ์‚ฐ: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y (X๋Š” ๋ฌด์ž‘์œ„)',
                        visible=False
                    )
                    seq_len = gr.Slider(
                        minimum=5.0, maximum=250.0,
                        label="DNA ๊ธธ์ด",
                        value=100,
                        visible=True
                    )
                    # Secondary-structure settings.
                    with gr.Accordion(label='๐Ÿฆด ๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ •', open=True):
                        gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ๊ธฐ๋ณธ ๊ณจ๊ฒฉ ๊ตฌ์กฐ๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
- ๋‚˜์„ ํ˜• ๊ตฌ์กฐ: ์œ ์—ฐํ•˜๊ณ  ํƒ„๋ ฅ์žˆ๋Š” ์›€์ง์ž„
- ๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ: ๋‹จ๋‹จํ•˜๊ณ  ๊ฐ•๋ ฅํ•œ ํž˜
- ๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ: ๋น ๋ฅด๊ณ  ๋ฏผ์ฒฉํ•œ ์›€์ง์ž„
""")
                        # Toggles between the bias sliders and the text box via toggle_secondary_structure.
                        sec_str_opt = gr.Radio(
                            ["์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •", "์ง์ ‘ ์ž…๋ ฅ"],
                            label="๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ • ๋ฐฉ์‹",
                            value="์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •"
                        )
                        secondary_structure = gr.Textbox(
                            label="๊ณจ๊ฒฉ ๊ตฌ์กฐ",
                            lines=1,
                            placeholder='H:๋‚˜์„ ํ˜•, S:๋ณ‘ํ’ํ˜•, L:๊ณ ๋ฆฌํ˜•, X:์ž๋™์„ค์ •',
                            visible=False
                        )
                        with gr.Column():
                            helix_bias = gr.Slider(
                                minimum=0.0, maximum=0.05,
                                label="๋‚˜์„ ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                            strand_bias = gr.Slider(
                                minimum=0.0, maximum=0.05,
                                label="๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                            loop_bias = gr.Slider(
                                minimum=0.0, maximum=0.20,
                                label="๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                    # Amino-acid composition bias.
                    with gr.Accordion(label='๐Ÿงฌ DNA ๊ตฌ์„ฑ ์„ค์ •', open=False):
                        gr.Markdown("""
ํŠน์ • ์•„๋ฏธ๋…ธ์‚ฐ์˜ ๋น„์œจ์„ ์กฐ์ ˆํ•˜์—ฌ ํžˆ์–ด๋กœ์˜ ํŠน์„ฑ์„ ๊ฐ•ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์˜ˆ์‹œ: W0.2,E0.1 (ํŠธ๋ฆฝํ† ํŒ 20%, ๊ธ€๋ฃจํƒ์‚ฐ 10%)
""")
                        with gr.Row():
                            aa_bias = gr.Textbox(
                                label="์•„๋ฏธ๋…ธ์‚ฐ ๋น„์œจ",
                                lines=1,
                                placeholder='์˜ˆ์‹œ: W0.2,E0.1'
                            )
                            aa_bias_potential = gr.Textbox(
                                label="๊ฐ•ํ™” ์ •๋„",
                                lines=1,
                                placeholder='1.0-5.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
                            )
                    # Hydrophobicity target.
                    with gr.Accordion(label='๐ŸŒ ํ™˜๊ฒฝ ์ ์‘๋ ฅ ์„ค์ •', open=False):
                        gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ํ™˜๊ฒฝ ์ ์‘๋ ฅ์„ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.
์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”
""")
                        with gr.Row():
                            hydrophobic_target_score = gr.Textbox(
                                label="ํ™˜๊ฒฝ ์ ์‘ ์ ์ˆ˜",
                                lines=1,
                                placeholder='์˜ˆ์‹œ: -5 (์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”)'
                            )
                            hydrophobic_potential = gr.Textbox(
                                label="์ ์‘๋ ฅ ๊ฐ•ํ™” ์ •๋„",
                                lines=1,
                                placeholder='1.0-2.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
                            )
                    # Diffusion steps and noise type.
                    with gr.Accordion(label='โš™๏ธ ๊ณ ๊ธ‰ ์„ค์ •', open=False):
                        gr.Markdown("""
DNA ์ƒ์„ฑ ๊ณผ์ •์˜ ์„ธ๋ถ€ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
""")
                        with gr.Row():
                            num_steps = gr.Textbox(
                                label="์ƒ์„ฑ ๋‹จ๊ณ„",
                                lines=1,
                                placeholder='25 ์ดํ•˜ ๊ถŒ์žฅ'
                            )
                            noise = gr.Dropdown(
                                ['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'],
                                label='๋…ธ์ด์ฆˆ ํƒ€์ž…',
                                value='normal'
                            )
                    design_btn = gr.Button("๐Ÿงฌ DNA ์„ค๊ณ„ ์ƒ์„ฑ!", variant="primary", scale=2)
                # Tab 3: motif scaffolding from an existing PDB entry.
                with gr.TabItem("๐Ÿงช ํžˆ์–ด๋กœ ์œ ์ „์ž ๊ฐ•ํ™”"):
                    gr.Markdown("""
### โšก ๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ํ™œ์šฉ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ์˜ DNA ์ผ๋ถ€๋ฅผ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ์—๊ฒŒ ์ด์‹ํ•ฉ๋‹ˆ๋‹ค.
""")
                    gr.Markdown("๊ณต๊ฐœ๋œ ํžˆ์–ด๋กœ DNA ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์ฝ”๋“œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
                    pdb_id_code = gr.Textbox(
                        label="ํžˆ์–ด๋กœ DNA ์ฝ”๋“œ",
                        lines=1,
                        placeholder='๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ์ฝ”๋“œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š” (์˜ˆ: 1DPX)'
                    )
                    gr.Markdown("์ด์‹ํ•˜๊ณ  ์‹ถ์€ DNA ์˜์—ญ์„ ์„ ํƒํ•˜๊ณ  ์ƒˆ๋กœ์šด DNA๋ฅผ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
                    contigs = gr.Textbox(
                        label="์ด์‹ํ•  DNA ์˜์—ญ",
                        lines=1,
                        placeholder='์˜ˆ์‹œ: 15,A3-10,20-30'
                    )
                    with gr.Row():
                        seq_mask = gr.Textbox(
                            label='๋Šฅ๋ ฅ ์žฌ์„ค๊ณ„',
                            lines=1,
                            placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๋Šฅ๋ ฅ์„ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
                        )
                        str_mask = gr.Textbox(
                            label='๊ตฌ์กฐ ์žฌ์„ค๊ณ„',
                            lines=1,
                            placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๊ตฌ์กฐ๋ฅผ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
                        )
                    preview_viewer = gr.HTML()
                    rewrite_pdb = gr.File(label='ํžˆ์–ด๋กœ DNA ํŒŒ์ผ')
                    preview_btn = gr.Button("๐Ÿ” ๋ฏธ๋ฆฌ๋ณด๊ธฐ", variant="secondary")
                    enhance_btn = gr.Button("โšก ๊ฐ•ํ™”๋œ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)
                # Tab 4: MSA / PSSM guidance.
                with gr.TabItem("๐Ÿ‘‘ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ"):
                    gr.Markdown("""
### ๐Ÿฐ ์œ„๋Œ€ํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ์œ ์‚ฐ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ํŠน์„ฑ์„ ๊ณ„์Šนํ•˜์—ฌ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
""")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ DNA ์ •๋ณด๊ฐ€ ๋‹ด๊ธด ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜์„ธ์š”")
                            fasta_msa = gr.File(label='๊ฐ€๋ฌธ DNA ๋ฐ์ดํ„ฐ')
                        with gr.Column():
                            gr.Markdown("์ด๋ฏธ ๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๋‹ค๋ฉด ์—…๋กœ๋“œํ•˜์„ธ์š”")
                            input_pssm = gr.File(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ')
                    pssm = gr.File(label='๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ')
                    pssm_view = gr.Plot(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„ ๊ฒฐ๊ณผ')
                    pssm_gen_btn = gr.Button("โœจ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„", variant="secondary")
                    inherit_btn = gr.Button("๐Ÿ‘‘ ๊ฐ€๋ฌธ์˜ ํž˜ ๊ณ„์Šน!", variant="primary", scale=2)
        # Right column: results display.
        with gr.Column(scale=1):
            gr.Markdown("## ๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ํ”„๋กœํ•„")
            hero_stats = gr.Plot(label="๋Šฅ๋ ฅ์น˜ ๋ถ„์„")
            hero_description = gr.Textbox(label="ํžˆ์–ด๋กœ ํŠน์„ฑ", lines=3)
            gr.Markdown("## ๐Ÿงฌ ํžˆ์–ด๋กœ DNA ๋ถ„์„ ๊ฒฐ๊ณผ")
            gr.Markdown("#### โšก DNA ์•ˆ์ •์„ฑ ์ ์ˆ˜")
            plddt_plot = gr.Plot(label='์•ˆ์ •์„ฑ ๋ถ„์„')
            gr.Markdown("#### ๐Ÿ“ DNA ์‹œํ€€์Šค")
            output_seq = gr.Textbox(label="DNA ์„œ์—ด")
            gr.Markdown("#### ๐Ÿ’พ DNA ๋ฐ์ดํ„ฐ")
            output_pdb = gr.File(label="DNA ํŒŒ์ผ")
            gr.Markdown("#### ๐Ÿ”ฌ DNA ๊ตฌ์กฐ")
            output_viewer = gr.HTML()
    # Event wiring.
    # Chatbot events.
    # NOTE(review): `msg` and `clear` are not defined anywhere in the visible
    # source, so these two lines raise NameError while the UI is being built —
    # confirm whether they should be created above or removed in favour of
    # the ChatInterface below.
    msg.submit(process_chat, [msg, chatbot], [chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
    # UI control events: show/hide alternative inputs.
    seq_opt.change(
        fn=toggle_seq_input,
        inputs=[seq_opt],
        outputs=[seq_len, sequence],
        queue=False
    )
    sec_str_opt.change(
        fn=toggle_secondary_structure,
        inputs=[sec_str_opt],
        outputs=[helix_bias, strand_bias, loop_bias, secondary_structure],
        queue=False
    )
    preview_btn.click(
        get_motif_preview,
        inputs=[pdb_id_code, contigs],
        outputs=[preview_viewer, rewrite_pdb]
    )
    pssm_gen_btn.click(
        get_pssm,
        inputs=[fasta_msa, input_pssm],
        outputs=[pssm_view, pssm]
    )
    # Update the result panels from a chatbot-driven protein generation.
    # NOTE(review): extract_parameters_from_chat and generate_protein are not
    # defined in the visible source, and calculate_hero_stats takes four
    # arguments but is called with one here — this handler cannot succeed as
    # written.
    def update_protein_display(chat_response):
        if "์ƒ์„ฑ๋œ ๋‹จ๋ฐฑ์งˆ ๋ถ„์„" in chat_response:
            params = extract_parameters_from_chat(chat_response)
            result = generate_protein(params)
            return {
                hero_stats: create_radar_chart(calculate_hero_stats(params)),
                hero_description: chat_response,
                output_seq: result[0],
                output_pdb: result[1],
                output_viewer: display_pdb(result[1]),
                plddt_plot: result[3]
            }
        return None
    # Wire every generate button. All four run combined_generation with the
    # same inputs, so they currently produce the same kind of result.
    for btn in [create_btn, design_btn, enhance_btn, inherit_btn]:
        btn.click(
            combined_generation,
            inputs=[
                hero_name, strength, flexibility, speed, defense, hero_size, special_ability,
                sequence, seq_len, helix_bias, strand_bias, loop_bias,
                secondary_structure, aa_bias, aa_bias_potential,
                num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
                contigs, pssm, seq_mask, str_mask, rewrite_pdb
            ],
            outputs=[
                hero_stats,
                hero_description,
                output_seq,
                output_pdb,
                output_viewer,
                plddt_plot
            ]
        )
    # Update the result panels after a chatbot reply (same undefined `msg` as above).
    msg.submit(
        update_protein_display,
        inputs=[chatbot],
        outputs=[hero_stats, hero_description, output_seq, output_pdb, output_viewer, plddt_plot]
    )
    # Streaming chat via respond(), reusing the chatbot and settings widgets above.
    # NOTE(review): these components are already rendered in this Blocks;
    # confirm the installed Gradio version accepts reusing them here.
    chat_interface = gr.ChatInterface(
        respond,
        additional_inputs=[
            system_message,
            max_tokens,
            temperature,
            top_p,
        ],
        chatbot=chatbot,
    )
# Run the app only when executed as a script, so importing this module
# (e.g. for tooling or tests) does not start a server.
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)