# library(MASS)

# Map absolute (1-based) position(s) onto consecutive segments.
#
# `lengths` gives the length of each segment; the segments are laid end to end
# starting at absolute position 1. For a single position the result is a named
# numeric vector c(segment = <segment id>, pos = <1-based position within that
# segment>). For several positions the result is a numeric matrix with one row
# per position and columns "segment" and "pos".
# A position equal to the end of a segment belongs to that (earlier) segment,
# e.g. with lengths c(5, 10) position 5 is segment 1, pos 5 and position 6 is
# segment 2, pos 1. Positions outside 1..sum(lengths) are an error.
abspos2relpos <- function(pos, lengths) {
  stopifnot(length(pos) > 0, length(lengths) > 0, all(lengths >= 0))
  # ends[k] is the absolute position of the last residue of segment k
  ends <- cumsum(lengths)
  stopifnot(all(pos > 0), all(pos <= ends[length(ends)]))
  bounds <- c(0, ends)
  # left.open = TRUE: segment k covers (bounds[k], bounds[k + 1]], so a
  # position equal to a segment end stays in that segment; with zero-length
  # segments (tied bounds) the non-empty segment is chosen
  segment <- findInterval(pos, bounds, left.open = TRUE)
  relative <- pos - bounds[segment] # bounds[segment] = residues before it
  if (length(pos) == 1) {
    c(segment = segment, pos = relative)
  } else {
    cbind(segment = segment, pos = relative)
  }
}


# Map absolute (1-based) position(s) onto consecutive segments described by
# their absolute start positions.
#
# `absStarts` holds the 1-based absolute start of every segment followed by a
# final sentinel equal to (total length + 1); e.g. c(1, 11, 17) describes the
# segments 1..10 and 11..16. For a single position the result is a named
# numeric vector c(segment = <segment id>, pos = <1-based position within that
# segment>); for several positions it is a numeric matrix with one row per
# position and columns "segment" and "pos".
# A position equal to the sentinel is clamped to the last residue of the last
# segment. Positions before absStarts[1] or after the sentinel are an error.
abspos2relpos2 <- function(pos, absStarts) {
  nStarts <- length(absStarts)
  stopifnot(length(pos) > 0, nStarts >= 2)
  stopifnot(
    all(pos > 0),
    all(pos >= absStarts[1]),
    all(pos <= absStarts[nStarts])
  )
  # number of starts <= pos, i.e. the segment whose [start, next start)
  # interval contains pos (ties pick the last, non-empty segment)
  segment <- findInterval(pos, absStarts)
  relative <- pos - absStarts[segment] + 1
  # the sentinel is not a residue: report it as the last residue of the
  # last segment
  atEnd <- segment == nStarts
  segment[atEnd] <- nStarts - 1L
  relative[atEnd] <- absStarts[nStarts] - absStarts[nStarts - 1]
  if (length(pos) == 1) {
    c(segment = segment, pos = relative)
  } else {
    cbind(segment = segment, pos = relative)
  }
}


# Spot-checks abspos2relpos() on two segments of lengths 5 and 10: an interior
# position, both sides of the segment boundary, and the very last residue.
# Each result is printed before it is checked; stops at the first mismatch.
test.abspos2relpos <- function() {
  segLengths <- c(5, 10)

  result <- abspos2relpos(3, segLengths) # inside segment 1
  print(result)
  stopifnot(result[1] == 1, result[2] == 3)

  result2 <- abspos2relpos(6, segLengths) # first residue of segment 2
  print(result2)
  stopifnot(result2[1] == 2, result2[2] == 1)

  result3 <- abspos2relpos(5, segLengths) # last residue of segment 1
  print(result3)
  stopifnot(result3[1] == 1, result3[2] == 5)

  result4 <- abspos2relpos(15, segLengths) # last residue overall
  print(result4)
  stopifnot(result4[1] == 2, result4[2] == 10)
}

# Prints the segment / relative position of every residue for two segments of
# lengths 5 and 10. Only positions 1..sum(lengths) are valid inputs to
# abspos2relpos(), so the loop stops there instead of erroring past the end.
test.abspos2relpos2 <- function() {
  lengths <- c(5, 10)
  for (i in seq_len(sum(lengths))) {
    cat("Testing position", i, ": ")
    print(abspos2relpos(i, lengths))
  }
}

# Prints the segment / relative position of every residue for two segments of
# lengths 3 and 4. Only positions 1..sum(lengths) are valid inputs to
# abspos2relpos(), so the loop stops there instead of erroring past the end.
test.abspos2relpos3 <- function() {
  lengths <- c(3, 4)
  for (i in seq_len(sum(lengths))) {
    cat("Testing position", i, ": ")
    print(abspos2relpos(i, lengths))
  }
}

# Prints the segment / relative position of every residue for the segments
# described by starts c(1, 11, 17) (residues 1..10 and 11..16; 17 is the
# end sentinel). Positions past the sentinel are rejected by abspos2relpos2(),
# so only the real residues are visited.
test.abspos2relpos2_1 <- function() {
  starts <- c(1, 11, 17)
  for (i in seq(starts[1], starts[length(starts)] - 1)) {
    cat("Testing position", i, ": ")
    print(abspos2relpos2(i, starts))
  }
}


# Extract sequence start positions from `matchfold --diag1` output lines.
# The first line whose first space-separated word is "starts" contributes its
# words from the third one onward, converted to numbers (empty tokens from
# repeated spaces are skipped). Returns NULL when no such line exists.
.parse_matchfold_starts <- function(lines) {
  for (line in lines) {
    words <- strsplit(line, " ", fixed = TRUE)[[1]]
    # length check: an empty line splits to character(0), whose words[1] is NA
    if (length(words) > 0 && words[1] == "starts") {
      values <- words[-(1:2)]
      return(as.numeric(values[nzchar(values)]))
    }
  }
  NULL
}

# Run the external `matchfold` tool nIter times on `infile` and average the
# resulting matrices.
#
# Each run is asked (`-o`) to write its matrix to "<outfile>.matrix", which is
# read back with read.table(). The element-wise mean over all runs is returned
# as a numeric matrix without dimnames, and is also written back to
# "<outfile>.matrix". The output of a separate `matchfold --diag1` call is
# parsed for the sequence starts, which are attached as attr(, "starts")
# (absent, with a warning, if matchfold reports none).
#
# infile - input file, fed to matchfold on stdin
# outfile - prefix of the matrix file (".matrix" is appended)
# jitter - value passed to matchfold's -j option
# nIter - number of matchfold runs to average; must be >= 1
do_matchfold <- function(infile, outfile, jitter, nIter) {
  cat("Input file:", infile, "output:", outfile, "jitter:", jitter, "\n")
  stopifnot(nIter >= 1)

  mtxoutfile <- paste0(outfile, ".matrix")
  mtx <- NULL
  for (i in seq_len(nIter)) {
    # Remove any leftover file first so a failed run cannot silently re-read
    # stale output from a previous run.
    unlink(mtxoutfile)
    # shQuote: file names and jitter come from the command line
    command <- paste(
      "matchfold --of 3 -o", shQuote(mtxoutfile),
      "-j", shQuote(jitter),
      "<", shQuote(infile)
    )
    system(command)
    if (!file.exists(mtxoutfile)) {
      stop("Could not find tmp file: ", mtxoutfile, call. = FALSE)
    }
    mtxtmp <- as.matrix(read.table(mtxoutfile))
    mtx <- if (is.null(mtx)) mtxtmp else mtx + mtxtmp
  }
  mtx <- mtx / nIter
  dimnames(mtx) <- NULL

  # find out sequence starts
  resultLines <- system(paste("matchfold  --diag1 <", shQuote(infile)), intern = TRUE)
  starts <- .parse_matchfold_starts(resultLines)
  if (is.null(starts)) {
    warning("no 'starts' line in matchfold --diag1 output", call. = FALSE)
  }
  attr(mtx, "starts") <- starts

  write.table(mtx, file = mtxoutfile, col.names = FALSE, row.names = FALSE)
  mtx
}


# main part ----
# Usage: Rscript matchfoldprob.R infile outfile jitter [further args ignored]

nIter <- 10 # number of matchfold runs averaged by do_matchfold()

args <- commandArgs(trailingOnly = TRUE)
cat("Script called with arguments:\n")
for (arg in args) {
  cat(arg, " ")
}
cat("\n")
if (length(args) < 3) {
  cat("Usage: Rscript matchfoldprob.R infile outfile jitter\n")
  stop("expected at least 3 arguments, got ", length(args), call. = FALSE)
}

infile <- args[1]
outfile <- args[2]
jitter <- args[3]
# Extra arguments (currently unused). Negative indexing yields character(0)
# when exactly three arguments are given, unlike the reversed range 4:3.
remain <- args[-(1:3)]

mtx <- do_matchfold(infile, outfile, jitter, nIter)
starts <- attr(mtx, "starts")
cat("Using starts:\n")
print(starts)
library(ggplot2)

# Long format of the averaged matrix: one row per cell with its row index (X1),
# column index (X2) and value, in column-major order. Built in base R so the
# script does not rely on a melt() from reshape/reshape2 being attached.
mtxm <- data.frame(
  X1 = as.vector(row(mtx)),
  X2 = as.vector(col(mtx)),
  value = as.vector(mtx)
)
mtxm$Relative_Frequency <- mtxm$value
nTot <- nrow(mtx) # total number of residues
print(head(mtxm, n = 20))

if (length(starts) == 0) {
  stop("matchfold --diag1 reported no sequence starts", call. = FALSE)
}
# All segment boundaries: every sequence start plus a sentinel one past the
# last residue.
startsTot <- c(starts, nTot + 1)
# Internal boundaries only (no separator line before the first sequence);
# starts[-1] is empty for a single sequence.
startsCore <- starts[-1]
# Axis tick positions: every boundary including both ends.
startsCore4 <- startsTot

# Tick labels are positions within a sequence, looked up via abspos2relpos2().
relPosLabel <- function(p) as.character(abspos2relpos2(p, startsTot)[2])
startsCore4Labels <- character(length(startsCore4))
for (i in seq_along(startsCore4)) {
  if (i == 1) {
    # first tick: first residue of the first sequence
    startsCore4Labels[i] <- relPosLabel(startsCore4[i])
  } else if (i == length(startsCore4)) {
    # last tick: last residue of the last sequence
    startsCore4Labels[i] <- relPosLabel(startsCore4[i] - 1)
  } else {
    # inner tick: "<last residue of previous> | <first residue of next>"
    startsCore4Labels[i] <- paste(
      relPosLabel(startsCore4[i] - 1), "|", relPosLabel(startsCore4[i])
    )
  }
}
cat("Labels:\n")
print(startsCore4Labels)
stopifnot(length(startsCore4Labels) == length(startsCore4))

# Convert absolute positions to segment ids (which sequence a residue belongs
# to) and 1-based positions within that segment. Positions not covered by any
# segment (before the first start) get NA.
segmentOf <- function(x) {
  s <- findInterval(x, startsTot)
  s[s < 1 | s >= length(startsTot)] <- NA
  s
}
mtxm$S1 <- segmentOf(mtxm$X1)
mtxm$RX1 <- mtxm$X1 - startsTot[mtxm$S1] + 1
mtxm$S2 <- segmentOf(mtxm$X2)
mtxm$RX2 <- mtxm$X2 - startsTot[mtxm$S2] + 1

# Draw the averaged matrix as a grey-scale heat map with separator lines
# between sequences. Axis ticks sit on the sequence boundaries and are
# labelled with within-sequence positions. Output is PostScript, written to
# "<outfile>.ps".
psFile <- paste0(outfile, ".ps")
base_size <- 18
axisText <- function(...) {
  element_text(size = base_size * 0.8, lineheight = 0.9, colour = "grey10", ...)
}
heatmap <- ggplot(data = mtxm, aes(x = X1, y = X2, fill = Relative_Frequency)) +
  geom_tile() +
  xlab("") +
  ylab("") +
  theme_bw() +
  scale_fill_gradient(low = "white", high = "black") +
  geom_hline(yintercept = startsCore - 0.5) +
  geom_vline(xintercept = startsCore - 0.5) +
  scale_x_continuous(breaks = startsCore4 - 0.5, labels = startsCore4Labels) +
  scale_y_continuous(breaks = startsCore4 - 0.5, labels = startsCore4Labels) +
  coord_equal() +
  # theme()/element_text() replace the defunct opts()/theme_text()
  theme(
    axis.text.x = axisText(hjust = 0.5, vjust = 1),
    axis.text.y = axisText(hjust = 0.5, vjust = -1, angle = 90),
    legend.position = "none"
  )

postscript(psFile)
print(heatmap)
# Alternative layout: one facet per pair of sequences on within-sequence axes
# print(ggplot(data=mtxm, aes(x=RX1,y=RX2, fill=Relative_Frequency)) + geom_tile() + xlab("Position(residues)") + ylab("Position(residues)") + theme_bw() + scale_fill_gradient(low="white", high="black") + facet_grid(S1 ~ S2, scales="free", space="free") )
dev.off()

# Note: written to the current working directory regardless of `outfile`.
write.table(mtxm, file = "mtxm.tab")