# Visualization of how chromatin conformation changes depending on the expression of one gene

In [1]:
import pandas as pd
import numpy as np
import hicstraw 
from multiprocessing import Pool
from functools import partial
import glob
import os
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import math
import matplotlib.pyplot as plt
from matplotlib import colors, cm
from pandarallel import pandarallel
import cooler
import cooltools
import pybedtools as pbed
pandarallel.initialize()
from scipy import stats, special
from statsmodels.stats import multitest
import statsmodels.api as sm
import statsmodels.formula.api as smf
import plotly.io as pio
import seaborn as sns
import numba as nb
import vcf
import bioframe
# Create the scratch directory (no error if it already exists) and point
# pybedtools at it for its temporary files.
os.makedirs("/mnt/iusers01/jw01/mdefscs4/scratch/temp_pybedtools/", exist_ok = True)
pbed.helpers.set_tempdir("/mnt/iusers01/jw01/mdefscs4/scratch/temp_pybedtools/")
# Path to a genome file; presumably hg38 chromosome sizes for bedtools,
# judging by the name -- not referenced in the code shown here.
bed_genome_file = "/mnt/iusers01/jw01/mdefscs4/hg38.genome"

# 'none' makes matplotlib write SVG text as text elements rather than paths.
plt.rcParams['svg.fonttype'] = 'none'

# Root folder under which all metadata, count tables and figures are resolved.
base_dir = "/mnt/jw01-aruk-home01/projects/psa_functional_genomics/PsA_cleaned_analysis"

from hic_corr_plot import get_region_for_all, retrieve_stats_scipy, myfloor, myceil, add_numbers

INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# Hi-C sample sheet; the first CSV column becomes the index.
metadata_hic = pd.read_csv(f"{base_dir}/metadata/cleaned_HiC_metadata.csv", index_col = 0)

# Pickled annotation table. NOTE: unpickling can execute arbitrary code, so
# only load this file from a trusted location.
gtf_annotation_df = pd.read_pickle(f"{base_dir}/gencode_gtf.pickle")

# Protein-coding transcripts; columns that are entirely empty for this subset
# are dropped.
gtf_transcripts = gtf_annotation_df[(gtf_annotation_df["feature"] == "transcript") & (gtf_annotation_df["transcript_type"] == "protein_coding")].dropna(axis=1, how='all')
# Keep only the part of each identifier before the first "." .
gtf_transcripts["gene_id"] = gtf_transcripts["gene_id"].str.split(".").str[0]
gtf_transcripts["transcript_id"] = gtf_transcripts["transcript_id"].str.split(".").str[0]
# TSS = "start" on the "+" strand, "end" otherwise. Vectorised with
# Series.where instead of a row-wise apply: same values, far fewer Python calls.
gtf_transcripts["TSS_start"] = gtf_transcripts["start"].astype(int).where(
    gtf_transcripts["strand"] == "+", gtf_transcripts["end"].astype(int))

# Protein-coding genes, processed the same way as the transcripts above.
gtf_genes = gtf_annotation_df[(gtf_annotation_df["feature"] == "gene") & (gtf_annotation_df["gene_type"] == "protein_coding")].dropna(axis=1, how='all')
gtf_genes["gene_id"] = gtf_genes["gene_id"].str.split(".").str[0]
gtf_genes["TSS_start"] = gtf_genes["start"].astype(int).where(
    gtf_genes["strand"] == "+", gtf_genes["end"].astype(int))

In [3]:
# RNA-seq sample sheet and the mapping from raw sample id to display name.
metadata_RNA = pd.read_csv(f"{base_dir}/metadata/cleaned_RNA_metadata.csv", index_col=0)
column_name_dict = dict(zip(metadata_RNA['sample'], metadata_RNA['proper_name']))

# Wide table of normalised counts with the sample columns renamed to their
# display names.
normalized_counts_new = (
    pd.read_csv(f"{base_dir}/RNA_seq_analysis/RNA_normalized_counts.csv")
    .rename(columns=column_name_dict)
)

# Long format (one row per gene and sample), annotated with the sample
# metadata by matching the melted "sample" column to "proper_name".
normalized_counts_melted = (
    normalized_counts_new
    .melt(
        id_vars=["ensembl","ENSG","symbol","genename","entrez"],
        value_vars=normalized_counts_new.columns.difference(["ensembl","ENSG","symbol","genename","entrez"]),
        var_name="sample",
        value_name="expression",
    )
    .merge(
        metadata_RNA[["patient","cell_type","condition","proper_name"]],
        left_on = "sample",
        right_on = "proper_name",
    )
)

In [4]:
# ATAC-seq sample sheet and the wide peak-by-sample count table.
metadata_ATAC = pd.read_csv(f"{base_dir}/metadata/cleaned_ATAC_metadata.csv", index_col=0)
ATAC_normalised_counts = pd.read_csv(f"{base_dir}/ATAC_seq_analysis/ATAC_DESeq2_quantile_normalized_counts.csv", index_col = 0)

# Long format (one row per peak and sample), annotated with the sample
# metadata by matching the melted "sample" column to the metadata "id".
ATAC_normalised_counts_melted = (
    ATAC_normalised_counts
    .melt(
        id_vars=["CHR","START","END"],
        value_vars=ATAC_normalised_counts.columns.difference(["CHR","START","END"]),
        var_name="sample",
        value_name="peak_height",
    )
    .merge(
        metadata_ATAC[["id","patient","cell_type","condition","proper_name"]],
        left_on = "sample",
        right_on = "id",
        suffixes = ["", "_y"],
    )
)

# Wide table again, with sample columns renamed from "id" to "proper_name".
column_name_dict = dict(zip(metadata_ATAC['id'], metadata_ATAC['proper_name']))
ATAC_normalised_counts_renamed = ATAC_normalised_counts.rename(columns=column_name_dict)

## main function to call to get the figure

In [7]:
import requests

def _correlate_rows_with_vector(table, vector):
    """Return Pearson's r between every row of ``table`` and ``vector``.

    Parameters
    ----------
    table : pandas.DataFrame
        One row per feature, one column per sample.
    vector : 1-D array-like
        One value per sample, in the same order as the columns of ``table``.

    Returns
    -------
    list
        One correlation coefficient per row of ``table``, in row order.
    """
    from scipy.stats import pearsonr

    # Rows are walked positionally as 1-D float arrays, so duplicated index
    # labels cannot turn a row into a 2-D frame, and pearsonr always receives
    # the 1-D inputs it documents.
    return [pearsonr(row_values, vector)[0] for row_values in table.to_numpy(dtype=float)]


def create_figure(gene, cell_type = ("CD8","CD4","CD8_SF","CD4_SF"), resolution = 5000, range_include = 500000, region = None):
    """Build the Hi-C / ATAC-seq vs. expression figure for one gene.

    The figure is a 3x2 plotly subplot grid: row 1 shows the correlation of
    ATAC peaks with the gene's expression, row 2 the genes and transcript TSSs
    in the window, row 3 a heatmap taken from ``retrieve_stats_scipy`` (left)
    and the ``matrix_full`` heatmap returned by ``get_region_for_all`` (right).

    Parameters
    ----------
    gene : str
        Gene symbol, matched against ``gtf_genes["gene_name"]`` and the
        ``symbol`` column of the RNA count tables.
    cell_type : list-like of str
        Values of the ``cell_type`` metadata column used to select Hi-C and
        ATAC samples. Only samples that also have RNA-seq are kept.
    resolution : int
        Bin size handed to ``myfloor``/``myceil``/``get_region_for_all`` and
        used to convert genomic coordinates into heatmap bin coordinates.
    range_include : int
        Distance added on each side of the gene's TSS to define the window.
        Ignored when ``region`` is given.
    region : tuple of (start, end), optional
        Explicit window; takes precedence over ``range_include``.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    IndexError
        If ``gene`` is not present in ``gtf_genes``.
    ValueError
        If a selected sample does not have exactly one expression value for
        ``gene`` (raised by ``Series.item``).
    """
    gene_rows = gtf_genes[gtf_genes["gene_name"] == gene]
    if gene_rows.empty:
        raise IndexError(f"gene {gene!r} not found in gtf_genes['gene_name']")
    lab = gene_rows.iloc[0]
    # Drop the leading "chr" from the sequence name.
    chrom = str(lab["seqname"])[3:]
    if region:
        region_start = myfloor(region[0], resolution)
        region_end = myceil(region[1], resolution)
    else:
        region_start = myfloor(lab["TSS_start"] - range_include, resolution)
        region_end = myceil(lab["TSS_start"] + range_include, resolution)

    def to_bin(position):
        # Genomic coordinate -> heatmap axis coordinate (bin centres sit on integers).
        return ((position - region_start) / resolution) - 0.5

    # ---- Hi-C: per-sample matrices for the window, restricted to samples with RNA-seq
    hic_samples = metadata_hic[metadata_hic["proper_name"].isin(metadata_RNA["proper_name"]) & metadata_hic["cell_type"].isin(cell_type)]
    required_samples = hic_samples["folder_name"].to_list()
    required_samples_proper = hic_samples["proper_name"].to_list()
    data, orig_shape, matrix_full = get_region_for_all(required_samples,chrom,region_start,region_end,resolution = resolution)

    # Expression of the gene per Hi-C sample, in the same order as `data`.
    # .item() requires exactly one matching row per sample.
    gene_expression = normalized_counts_melted[normalized_counts_melted["symbol"] == gene]
    y = [
        float(gene_expression.loc[gene_expression["sample"] == sample_name, "expression"].item())
        for sample_name in required_samples_proper
    ]
    X = np.stack(data)
    # NOTE(review): row 2 of this result is plotted below under the title
    # "Pearson's correlation" -- presumably the coefficient; confirm in hic_corr_plot.
    results_scipy = retrieve_stats_scipy(X, y)

    # ---- ATAC: correlate every peak in the window with the gene's expression
    valid_samples_ATAC = metadata_ATAC[metadata_ATAC["cell_type"].isin(cell_type) & metadata_ATAC["proper_name"].isin(metadata_RNA["proper_name"])]["proper_name"].to_list()
    valid_ATAC_peaks = ATAC_normalised_counts_renamed[(ATAC_normalised_counts_renamed["CHR"] == "chr" + chrom) & (ATAC_normalised_counts_renamed["START"].between(region_start, region_end))].copy()
    valid_ATAC_peaks = valid_ATAC_peaks.set_index("START")[valid_samples_ATAC]
    # 1-D expression vector in the same sample order as the peak columns.
    expression_vector = normalized_counts_new.loc[normalized_counts_new["symbol"] == gene, valid_samples_ATAC].iloc[0].to_numpy(dtype=float)
    corr_coefficients = _correlate_rows_with_vector(valid_ATAC_peaks, expression_vector)

    # One row per peak: start, correlation and mean peak height across samples.
    corr_peaks_df = pd.DataFrame({'STARTs': valid_ATAC_peaks.index.to_list(), 'corr_coefficient': corr_coefficients, "mean_expr": valid_ATAC_peaks.mean(axis = 1)})
    # Same conversion as for genes plus a small fixed 0.02 offset (kept from the original).
    corr_peaks_df["converted_start"] = to_bin(corr_peaks_df["STARTs"]) + 0.02

    # ---- figure skeleton
    fig = make_subplots(rows=3, cols=2, shared_xaxes="all", shared_yaxes="rows",
        vertical_spacing=0.02, row_width=[0.7, 0.12, 0.17],horizontal_spacing=0.03,
        subplot_titles=("ATAC peaks", "", "Genes", "", "Pearson's correlation","Hi-C"))

    # Row 1: ATAC correlation, drawn identically in both columns.
    for col in (1, 2):
        fig.add_trace(go.Scatter(x=corr_peaks_df["converted_start"],
            y = corr_peaks_df["corr_coefficient"], mode="markers",
            hovertext = corr_peaks_df["STARTs"],
            marker_color = corr_peaks_df["mean_expr"], marker=dict(cmin=0, cmax = 200, colorscale = "Reds")),
            row=1, col=col)

    # ---- genes overlapping the window, and transcripts fully inside it
    local_genes = gtf_genes[(gtf_genes["seqname"] == lab["seqname"]) &
        (gtf_genes["end"].astype(int) > region_start) &
        (gtf_genes["start"].astype(int) < region_end)].copy()
    local_genes["start"] = local_genes["start"].astype(float)
    local_genes["end"] = local_genes["end"].astype(float)
    local_genes["converted_start"] = to_bin(local_genes["start"])
    local_genes["converted_end"] = to_bin(local_genes["end"])
    local_genes["converted_TSS"] = to_bin(local_genes["TSS_start"])

    local_transcripts = gtf_transcripts[(gtf_transcripts["seqname"] == lab["seqname"]) &
        (gtf_transcripts["start"].astype(int) > region_start) &
        (gtf_transcripts["end"].astype(int) < region_end)].copy()
    local_transcripts["converted_TSS"] = to_bin(local_transcripts["TSS_start"])
    # add_numbers supplies the "position" column used as the y coordinate below.
    local_transcripts,local_genes = add_numbers(local_transcripts.sort_values("TSS_start"), local_genes)
    # Keep everything inside the plotted matrix extent.
    for column in ("converted_start", "converted_end", "converted_TSS"):
        local_genes[column] = local_genes[column].clip(0,orig_shape[0])
    local_transcripts["converted_TSS"] = local_transcripts["converted_TSS"].clip(0,orig_shape[0])

    # Blue vertical bar on the TSS of the target gene in both row-1 panels.
    target_tss = to_bin(lab["TSS_start"])
    for axis_number in (1, 2):
        fig.add_shape(type="rect",
            xref=f"x{axis_number}", yref=f"y{axis_number}",
            x0=target_tss - 0.05, y0=-1,
            x1=target_tss + 0.05, y1=1, layer="below",
            fillcolor="Blue",line_width=0, opacity = 0.6
        )

    # Row 2: transcript TSS markers (hover only), then gene TSS markers with labels.
    for col in (1, 2):
        fig.add_trace(go.Scatter(x=local_transcripts["converted_TSS"],
            y = local_transcripts["position"], mode="markers",textposition="top center", marker_color = "red",
            hovertext=local_transcripts["gene_name"]), row=2, col=col)
    for col in (1, 2):
        fig.add_trace(go.Scatter(x=local_genes["converted_TSS"],
            y = local_genes["position"], mode="markers+text",textposition="top center", marker_color = "red",
            text=local_genes["gene_name"]), row=2, col=col)

    # Pink rectangle spanning each gene body in both row-2 panels.
    for _, gene_row in local_genes.iterrows():
        for axis_number in (3, 4):
            fig.add_shape(type="rect",
                xref=f"x{axis_number}", yref=f"y{axis_number}",
                x0=gene_row["converted_start"], y0=gene_row["position"]-0.2,
                x1=gene_row["converted_end"], y1=gene_row["position"]+0.8, layer="below",
                fillcolor="pink",line_width=0, opacity = 0.6
            )

    # Row 3: heatmaps. The colour ceiling of the right-hand map depends on bin size.
    hic_zmax = 1500 if resolution == 5000 else 5000
    fig.add_trace(go.Heatmap(z = np.rot90(results_scipy[2,:].reshape(orig_shape)), colorscale = "RdBu_r", zmin = -0.8, zmax = 0.8), row=3, col=1)
    fig.add_trace(go.Heatmap(z = np.rot90(matrix_full), colorscale = "Reds", zmin = 0, zmax = hic_zmax), row=3, col=2)

    # ---- layout
    fig.update_layout(
        height=950, width=1250, title_text=f"Hi-C correlation with expression of {gene}",
        plot_bgcolor='rgba(0,0,0,0)',
        showlegend=False,
        yaxis5 = dict(scaleanchor = 'x5'),  # lock the y scale of the left heatmap to its x axis
        yaxis1 = dict(fixedrange = True, gridcolor = "grey", zeroline = False),
        yaxis2 = dict(gridcolor = "grey", zeroline = False),
        yaxis3 = dict(fixedrange = True, showgrid=False, showticklabels = False),
        yaxis4 = dict(showgrid=False),
        xaxis1 = dict(showgrid=False, zeroline = False),
        xaxis2 = dict(showgrid=False, zeroline = False),
        xaxis3 = dict(showgrid=False),
        xaxis4 = dict(showgrid=False),
    )
    fig.update_traces(row = 3, showscale=False)
    # The four non-empty subplot titles: left-align the first three, move the last right.
    for annotation_index, x_position in enumerate((0, 0, 0, 0.55)):
        fig.layout.annotations[annotation_index].update(x=x_position, xanchor = "left", font = dict(size = 12))
    fig.update_xaxes(showspikes=True, spikemode='across',spikesnap="cursor")
    fig.update_yaxes(showspikes=True, spikemode='across',spikesnap="cursor")

    return fig

In [12]:
# Export one SVG per gene and cell-type selection, all over the same window.
# "ALL" passes no cell_type, so create_figure's default selection is used.
for gene_symbol in ("ANKRD55", "IL6ST"):
    for label, selection in (
        ("ALL", {}),
        ("CD8", {"cell_type": ["CD8", "CD8_SF"]}),
        ("CD4", {"cell_type": ["CD4", "CD4_SF"]}),
    ):
        figure = create_figure(gene_symbol, range_include = 600000, region = (55788423,56382284), **selection)
        figure.write_image(f"{base_dir}/integration_analysis/figures/{gene_symbol}_HiC_{label}.svg")

In [8]:
# Plot of the IL7R region for figure 2: window of 300 kb on each side of the
# TSS at 5 kb resolution, shown inline and then saved as SVG.
fig = create_figure("IL7R", resolution = 5000, range_include = 300000)
fig.show()
fig.write_image(f"{base_dir}/integration_analysis/figures/IL7R_for_paper.svg")

In [14]:
# ANKRD55 with all default settings; as the last expression of the cell the
# returned figure is displayed by the notebook.
create_figure("ANKRD55")