Smith–Waterman: DP-Filling Visualization

calico·2025년 10월 30일

Artificial Intelligence

목록 보기
92/143


# Smith–Waterman (Local) only — DP visualization + traceback
# Requirements: numpy, matplotlib (pip install numpy matplotlib)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# Notebook-friendly embed (safe if not in IPython).
# _IN_IPY is True only when an IPython shell is actually running. Merely
# being able to import IPython is not enough: a plain `python script.py`
# run has no active shell (get_ipython() returns None), and in that case
# display(HTML(...)) would print an object repr instead of showing a plot.
_IN_IPY = False
try:
    from IPython import get_ipython
    from IPython.display import HTML, display
except ImportError:
    pass  # IPython not installed: fall back to plt.show()/saving.
else:
    if get_ipython() is not None:
        # Make animations render as interactive JS/HTML in notebooks.
        plt.rcParams["animation.html"] = "jshtml"
        _IN_IPY = True

# -----------------------------
# Scoring (defaults)
# -----------------------------
# Default scores used when callers don't pass their own: +1 for identical
# characters, -1 for a substitution, and -2 added per gap step.
MATCH, MISMATCH, GAP = 1, -1, -2

# -----------------------------
# Helper (heatmap renderer)
# -----------------------------
def draw_heatmap(ax, S, seq1, seq2, highlight=None, path=None,
                 show_candidates=True, match=MATCH, mismatch=MISMATCH, gap=GAP):
    """Render a DP score matrix as an annotated heatmap on ``ax``.

    Rows are labelled ``'-'`` followed by the characters of ``seq1``, and
    columns ``'-'`` followed by the characters of ``seq2``. Every cell is
    annotated with its value.

    Args:
        ax: matplotlib Axes to draw on; it is cleared first.
        S: 2-D score matrix (indexed ``S[i, j]``, row i / column j).
        seq1: sequence labelling the rows (after the leading ``'-'``).
        seq2: sequence labelling the columns (after the leading ``'-'``).
        highlight: optional ``(i, j)`` cell outlined in red; ignored when
            it lies outside the matrix.
        path: optional sequence of ``(i, j)`` cells drawn as an orange
            polyline (e.g. a traceback path).
        show_candidates, match, mismatch, gap: accepted for interface
            compatibility; this renderer does not use them.

    Returns:
        The ``AxesImage`` created by ``imshow``.
    """
    ax.clear()
    im = ax.imshow(S, cmap='Blues')
    ax.set_title("Smith–Waterman (Local) — DP filling")
    ax.set_xlabel("Y (with leading '-')")
    ax.set_ylabel("X (with leading '-')")

    ylabels = ['-'] + list(seq1)
    xlabels = ['-'] + list(seq2)
    ax.set_xticks(np.arange(len(xlabels)))
    ax.set_yticks(np.arange(len(ylabels)))
    ax.set_xticklabels(xlabels)
    ax.set_yticklabels(ylabels)

    # Minor ticks on the cell borders produce a grid line between cells.
    ax.set_xticks(np.arange(-.5, len(xlabels), 1), minor=True)
    ax.set_yticks(np.arange(-.5, len(ylabels), 1), minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)
    ax.tick_params(which='minor', bottom=False, left=False)

    # Annotate every cell. Black text is unreadable on the darkest blues,
    # which are the high scores a local alignment cares about most, so
    # switch to white for values above the midpoint of the matrix range.
    lo, hi = (S.min(), S.max()) if S.size else (0, 0)
    mid = (lo + hi) / 2
    for i in range(S.shape[0]):
        for j in range(S.shape[1]):
            val = S[i, j]
            color = 'white' if hi > lo and val > mid else 'black'
            ax.text(j, i, str(val), ha='center', va='center',
                    fontsize=9, color=color)

    if path:
        # Cells are (row, col); plot wants x=col, y=row.
        ys = [p[0] for p in path]
        xs = [p[1] for p in path]
        ax.plot(xs, ys, marker='o', markersize=3, linewidth=1.5, color='orange')

    if highlight is not None:
        i, j = highlight
        if 0 <= i < S.shape[0] and 0 <= j < S.shape[1]:
            rect = plt.Rectangle((j-0.5, i-0.5), 1, 1, fill=False, edgecolor='red', linewidth=2)
            ax.add_patch(rect)
    return im

# -----------------------------
# Smith–Waterman DP filling
# -----------------------------
def sw_dp_steps(seq1, seq2, match=MATCH, mismatch=MISMATCH, gap=GAP):
    """Yield snapshots of the Smith–Waterman score matrix while it is filled.

    Yields ``(matrix_copy, cell)`` tuples:

    * first the all-zero ``(len(seq1)+1, len(seq2)+1)`` integer matrix
      paired with ``(0, 0)``;
    * then, for each interior cell ``(i, j)`` in row-major order, a copy
      taken right after that cell was computed, paired with ``(i, j)``;
    * finally the completed matrix paired with ``None``.

    Row 0 and column 0 stay 0. Every interior cell is
    ``max(0, diag + s, up + gap, left + gap)``, where ``s`` is ``match`` for
    equal characters and ``mismatch`` otherwise.
    """
    n_rows = len(seq1) + 1
    n_cols = len(seq2) + 1
    grid = np.zeros((n_rows, n_cols), dtype=int)
    yield grid.copy(), (0, 0)

    for r in range(1, n_rows):
        row_char = seq1[r - 1]
        for c in range(1, n_cols):
            if row_char == seq2[c - 1]:
                pair_score = match
            else:
                pair_score = mismatch
            candidates = (
                0,                                   # local reset
                grid[r - 1, c - 1] + pair_score,     # diagonal
                grid[r - 1, c] + gap,                # from above
                grid[r, c - 1] + gap,                # from the left
            )
            grid[r, c] = max(candidates)
            yield grid.copy(), (r, c)

    yield grid.copy(), None

# -----------------------------
# Animation (SW only)
# -----------------------------
def animate_sw(seq1, seq2, match=MATCH, mismatch=MISMATCH, gap=GAP,
               interval=120, save_path=None):
    """Animate the Smith–Waterman DP fill, one matrix cell per frame.

    Args:
        seq1: sequence labelling the matrix rows.
        seq2: sequence labelling the matrix columns.
        match, mismatch, gap: scoring passed to ``sw_dp_steps``.
        interval: delay between frames in milliseconds; when saving, the
            frame rate is derived from it (at least 1 fps).
        save_path: if given, save the animation there instead of showing
            it (``.gif`` uses the Pillow writer, anything else ffmpeg).

    Returns:
        ``(anim, final_matrix)``: the ``FuncAnimation`` (keep a reference
        while it is being displayed) and the fully filled score matrix.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    all_states = list(sw_dp_steps(seq1, seq2, match, mismatch, gap))

    def _render(k):
        # Redraw the whole heatmap for snapshot k (blit is off).
        S, hl = all_states[k]
        draw_heatmap(ax, S, seq1, seq2, highlight=hl,
                     show_candidates=False, match=match, mismatch=mismatch, gap=gap)
        return []

    anim = animation.FuncAnimation(fig, _render, init_func=lambda: _render(0),
                                   frames=len(all_states), interval=interval,
                                   blit=False, repeat=False)

    if save_path:
        # Frame rate implied by the per-frame delay. A non-positive
        # interval would otherwise divide by zero.
        fps = max(1, int(1000 / interval)) if interval > 0 else 1
        writer = 'pillow' if save_path.lower().endswith('.gif') else 'ffmpeg'
        try:
            anim.save(save_path, writer=writer, fps=fps)
        finally:
            # Release the figure even if the writer fails (e.g. no ffmpeg).
            plt.close(fig)
        print(f"Saved animation to: {save_path}")
    elif _IN_IPY:
        display(HTML(anim.to_jshtml()))
        # Without this, the inline backend also renders the still-open
        # figure as a static image at the end of the cell.
        plt.close(fig)
    else:
        plt.show()

    return anim, all_states[-1][0]

# -----------------------------
# Final view + Traceback (SW only)
# -----------------------------
def _sw_fill_matrix(seq1, seq2, match, mismatch, gap):
    """Return the completed Smith–Waterman score matrix (row/col 0 are 0)."""
    m, n = len(seq1), len(seq2)
    H = np.zeros((m+1, n+1), dtype=int)
    for i in range(1, m+1):
        for j in range(1, n+1):
            s = match if seq1[i-1] == seq2[j-1] else mismatch
            H[i, j] = max(0, H[i-1, j-1] + s, H[i-1, j] + gap, H[i, j-1] + gap)
    return H


def _sw_traceback(H, seq1, seq2, start, match, mismatch, gap):
    """Trace back from ``start`` until a zero cell.

    Returns ``(aligned1, aligned2, path)``. ``path`` lists the visited
    ``(i, j)`` cells from ``start`` back to the stopping cell. Ties prefer
    diagonal, then up, then left.
    """
    i, j = start
    a1, a2 = [], []
    path = [(i, j)]
    while i > 0 and j > 0 and H[i, j] > 0:
        s = match if seq1[i-1] == seq2[j-1] else mismatch
        if H[i, j] == H[i-1, j-1] + s:
            a1.append(seq1[i-1]); a2.append(seq2[j-1]); i -= 1; j -= 1
        elif H[i, j] == H[i-1, j] + gap:
            a1.append(seq1[i-1]); a2.append('-'); i -= 1
        elif H[i, j] == H[i, j-1] + gap:
            a1.append('-'); a2.append(seq2[j-1]); j -= 1
        else:
            break  # no predecessor explains this value; stop defensively
        path.append((i, j))
    # Characters were collected end-to-start.
    return ''.join(reversed(a1)), ''.join(reversed(a2)), path


def show_sw_with_traceback(seq1, seq2, match=MATCH, mismatch=MISMATCH, gap=GAP):
    """Plot the final SW matrix with its traceback path and return the alignment.

    The traceback starts at the matrix maximum (the first one in row-major
    order if there are ties) and stops when it reaches a cell with value 0.

    Returns:
        ``(aligned_seq1, aligned_seq2, score)``, where gaps are ``'-'`` and
        ``score`` is the maximum matrix value as an ``int``.
    """
    H = _sw_fill_matrix(seq1, seq2, match, mismatch, gap)
    start_i, start_j = np.unravel_index(np.argmax(H), H.shape)
    aligned1, aligned2, path = _sw_traceback(
        H, seq1, seq2, (start_i, start_j), match, mismatch, gap)
    score = int(H[start_i, start_j])

    # visualize
    _, ax = plt.subplots(figsize=(8, 6))
    draw_heatmap(ax, H, seq1, seq2, highlight=None, path=path,
                 show_candidates=False, match=match, mismatch=mismatch, gap=gap)
    # Monospace keeps the two alignment rows column-aligned in the title.
    ax.set_title(f"SW(Local)  |  Score={score}\n{aligned1}\n{aligned2}",
                 fontfamily='monospace')
    plt.show()

    return aligned1, aligned2, score

# -----------------------------
# Example run (SW)
# -----------------------------
if __name__ == "__main__":
    # Demo: animate the fill, then show the traceback for the same pair.
    seq_x, seq_y = "GGTTGACTA", "TGTTACGG"
    scoring = dict(match=3, mismatch=-3, gap=-2)
    # Keep the animation object referenced so it is not garbage-collected.
    anim_obj, final_H = animate_sw(seq_x, seq_y, interval=90, **scoring)
    top, bottom, best = show_sw_with_traceback(seq_x, seq_y, **scoring)
    print("[SW] Score:", best)
    print(top)
    print(bottom)

profile
https://velog.io/@corone_hi/posts

0개의 댓글