"""Gradio web demo for DeepChopper adapter detection on FASTQ records."""

import multiprocessing
from functools import partial
from itertools import islice
from pathlib import Path

import gradio as gr
import lightning
import torch
from datasets import Dataset
from torch.utils.data import DataLoader

import deepchopper
from deepchopper.deepchopper import default, encode_qual, remove_intervals_and_keep_left, smooth_label_region
from deepchopper.models.llm import (
    tokenize_and_align_labels_and_quals,
)
from deepchopper.utils import (
    summary_predict,
)

# A FASTQ record is four lines: identifier, sequence, separator, quality string.
_LINES_PER_FASTQ_RECORD = 4


def parse_fq_record(text: str):
    """Parse FASTQ text and yield one dictionary per 4-line record.

    Args:
        text: FASTQ content; leading/trailing whitespace is ignored and both
            ``\\n`` and ``\\r\\n`` line endings are accepted.

    Yields:
        A dict with keys ``"id"`` (the identifier line as written), ``"seq"``,
        ``"qual"`` (the result of ``encode_qual(qual, default.KMER_SIZE)``) and
        ``"target"`` (always ``[0, 0]``; presumably a placeholder label for
        inference -- confirm against the tokenization helper).

    Raises:
        ValueError: If the text holds no record, a record has fewer than four
            lines, or a sequence and its quality string differ in length.
    """
    lines = text.strip().splitlines()
    if not lines:
        msg = "No FASTQ record found in the input text"
        raise ValueError(msg)

    for start in range(0, len(lines), _LINES_PER_FASTQ_RECORD):
        content = lines[start : start + _LINES_PER_FASTQ_RECORD]
        if len(content) != _LINES_PER_FASTQ_RECORD:
            msg = (
                f"Incomplete FASTQ record starting at line {start + 1}: "
                f"expected {_LINES_PER_FASTQ_RECORD} lines, got {len(content)}"
            )
            raise ValueError(msg)

        record_id, seq, _, qual = content
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(seq) != len(qual):
            msg = (
                f"Sequence and quality lengths differ for record {record_id!r}: "
                f"{len(seq)} != {len(qual)}"
            )
            raise ValueError(msg)

        yield {
            "id": record_id,
            "seq": seq,
            "qual": encode_qual(qual, default.KMER_SIZE),
            "target": [0, 0],
        }


def load_dataset(text: str, tokenizer):
    """Build the raw and tokenized datasets from FASTQ text.

    Args:
        text: FASTQ content forwarded to ``parse_fq_record``.
        tokenizer: Tokenizer passed to ``tokenize_and_align_labels_and_quals``;
            its ``max_len_single_sentence`` is used as the maximum length.

    Returns:
        A ``(dataset, tokenized_dataset)`` tuple. ``dataset`` keeps the parsed
        records (torch format); ``tokenized_dataset`` is the mapped dataset
        with the ``id``, ``seq``, ``qual`` and ``target`` columns removed.
    """
    dataset = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text}).with_format("torch")
    tokenized_dataset = dataset.map(
        partial(
            tokenize_and_align_labels_and_quals,
            tokenizer=tokenizer,
            max_length=tokenizer.max_len_single_sentence,
        ),
        num_proc=multiprocessing.cpu_count(),  # type: ignore
    ).remove_columns(["id", "seq", "qual", "target"])
    return dataset, tokenized_dataset


def predict(
    text: str,
    smooth_window_size: int = 21,
    min_interval_size: int = 13,
    approved_interval_number: int = 20,
    max_process_intervals: int = 8,  # default is 4
    batch_size: int = 1,
    num_workers: int = 1,
):
    """Run DeepChopper on FASTQ text and collect the predicted intervals.

    Args:
        text: FASTQ content to analyze.
        smooth_window_size: Forwarded to ``smooth_label_region``.
        min_interval_size: Forwarded to ``smooth_label_region``.
        approved_interval_number: Forwarded to ``smooth_label_region``.
        max_process_intervals: Records with more smoothed intervals than this
            are skipped.
        batch_size: Batch size of the prediction ``DataLoader``.
        num_workers: Worker processes of the prediction ``DataLoader``.

    Returns:
        A ``(smooth_interval_json, highlighted_text)`` tuple:
        ``smooth_interval_json`` is a list of ``{"start": ..., "end": ...}``
        dicts, one per smoothed interval; ``highlighted_text`` is a list of
        ``(subsequence, label)`` tuples where the label is ``"ada"`` for a
        smoothed interval and ``None`` for the remaining pieces. Both are empty
        when the record is skipped.

    Raises:
        AssertionError: If the trainer does not return exactly one batch.
    """
    tokenizer = deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name="hyenadna-small-32k-seqlen")
    dataset, tokenized_dataset = load_dataset(text, tokenizer)

    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
    model = deepchopper.DeepChopper.from_pretrained("yangliz5/deepchopper")

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = lightning.pytorch.trainer.Trainer(
        accelerator=accelerator,
        deterministic=False,
        logger=False,
    )
    predicts = trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)
    assert len(predicts) == 1  # noqa: S101

    smooth_interval_json: list[dict[str, int]] = []
    highlighted_text: list[tuple[str, str | None]] = []

    # NOTE(review): the batch index is used as the record index and only the
    # first prediction of each batch is read, which lines up only when
    # batch_size == 1 -- confirm before raising batch_size.
    for idx, preds in enumerate(predicts):
        true_prediction, _true_label = summary_predict(predictions=preds[0], labels=preds[1])

        record = dataset[idx]
        _id = record["id"]
        seq = record["seq"]

        smooth_predict_targets = smooth_label_region(
            true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
        )

        if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
            continue

        # Merge the kept intervals with the predicted ones so the whole read
        # is emitted in positional order.
        _selected_seqs, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
        total_intervals = sorted(selected_intervals + smooth_predict_targets)

        smooth_interval_json.extend({"start": i[0], "end": i[1]} for i in smooth_predict_targets)
        highlighted_text.extend(
            (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
            for interval in total_intervals
        )

    return smooth_interval_json, highlighted_text


def process_input(text: str | None, file: str | None):
    """Process the input and return the prediction.

    Args:
        text: FASTQ content pasted by the user, if any.
        file: Path of an uploaded FASTQ file, if any. When given it takes
            precedence over ``text`` and only its first record (four lines)
            is read.

    Returns:
        The ``(smooth_interval_json, highlighted_text)`` tuple from
        ``predict``, or two empty lists when neither input was provided.
    """
    if not text and not file:
        gr.Warning("Both text and file are empty")
        # Nothing to analyze: return empty outputs instead of passing None on.
        return [], []

    if file:
        with Path(file).open(encoding="utf-8") as f:
            text = "".join(islice(f, _LINES_PER_FASTQ_RECORD))

    return predict(text=text)


def create_gradio_app():
    """Create a Gradio app for DeepChopper.

    Returns:
        The ``gr.Blocks`` instance holding the text/file inputs, the analyze
        button wired to ``process_input``, and the JSON and highlighted-text
        outputs.
    """
    example = (
        "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
        "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
        "+\n"
        ".0==?SSSSSSSSSSSH2216<868;SSSSSSSSSQQSRSIIHEDDESSSSSSJIKMGEKISSJJICCBDQ?;;8:;,**(&$'+501)\"#$()+%&&0<5+*/('%'))))'''$##\"\"\"\"%&--$\"\"\"('%)1L3*'')'#\"#&+*$&\"\"#*(&'''+,,<;9=ANOSJDCADHGMOQSSD=7>BRRSPIEEEOQSSQ4->LIC7EE045///03IIJQSSSNGE6('.5??@A@=,,EGRSPKJ<==<556GFLLQRANSSSSSSSSG...*%%%(***(%'3@LOOSSSSM...7BCMMSSSSSSSSSSSSSSSDFIPSSSGGGGPOQLIHIL4103HMSILLNOSSSSSSSSSS22CBCGSHHHHSSSSSSSSD??@<<<:DDDSSSSSSSSSSA@6688OSSSSSROJJKLSNNNMSSSSQPOOSOOQSSSSSRRHIHISSRSSSSSSSSSSSJFF=??@SSQRK:424<444FFG///1S@@@ASNNNNPN:4JMDDLPSSSSSSBA?B?@@+'&'BD**8EDEFQPIMLE$$&',79CSJJPSGA+***DN;3-('&(;>6(()/-,,)%')1FRNNJ-:=>GC;&;CHNFFDCEEKJLFA22/27A.....HSQLHL))8<=?JSSSFGSKIHDDCCEFDAA@CFJKLNL>:9/1>>?OSLK@+HPSA;>>>K;;;;SSSSOQLPPMORSSSSSQSSSSSSS=:9**?D889SSRFFEDKJJJEEDKSSSNNOSSS.---,&*++SSSSQRSSSSQPGED<<89<@GJ999:SSKBBBAJHK=SSSJJKNMGHKKHQA<<>OPKFEAACDHJKMORB/)'((6**)15DA99;JSQSSS2())+J))EGMQOMMKJF>?<@4,3;B>87FSSSSSSSSSSSSSSS<<::5658@AHMMSSRECC448/=<<>SSCB:5546;=>BJHINJFDPPPPPPPPPPPPPP%'*%$%+-%'(-22&&%('''&&&#\"\"%&'+0,,0;:1&\"\"%'(+++8'**(\"$$#&$'**//.3497$\"3CFHLOSSSSR:887:;;FSSRPRSSS4433$#$%&$$-056>@:;>=@?AHEFEC;*EKMSSRSRRDB>=AFRSSSSBSOOPSMDAABHH976951-9DHPQO/---?@ELSSQSRJHKKBKKLSSLINSOSSQSRIMSSSSSS>?MKIINSSGSSSSSSSQQMK544MJKKNKHGGLFFGBDB?EHIKGD?@DHPPIIF555)&(+,ADSSSSRQSSSQSS=9/0JJMSQSOSSO/97=B@=:>"
    )

    custom_css = """ .header { text-align: center; margin-bottom: 30px; } .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; } """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.HTML(
            """

🧬 DeepChopper: DNA Sequence Analysis

Analyze DNA sequences and detect Adapter sequences

"""
        )

        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Input DNA Sequence", placeholder="Paste your DNA sequence here...", lines=10
                )
                file_input = gr.File(label="Or upload a FASTQ file")
                submit_btn = gr.Button("Analyze", variant="primary")

            with gr.Column(scale=1):
                json_output = gr.JSON(label="Detected Adapter Regions")
                highlighted_text = gr.HighlightedText(label="Highlighted Sequence")

        submit_btn.click(fn=process_input, inputs=[text_input, file_input], outputs=[json_output, highlighted_text])

        gr.Examples(
            examples=[[example]],
            inputs=[text_input],
        )

        gr.HTML(""" """)

    return demo


def main():
    """Launch the Gradio app."""
    app = create_gradio_app()
    app.launch()


if __name__ == "__main__":
    main()