import urllib
import xml.etree.ElementTree as ET
import os
import urlparse
import math
import cgi
import sys


# urllib identifies itself as "Python-urllib/x.y" by default, which NCBI
# refuses to serve; masquerade as a desktop Firefox browser instead.
class AppURLopener(urllib.FancyURLopener):
    # User-Agent header sent with every urllib.urlopen() request below.
    version = 'Mozilla/5.0 (Windows; U; Windows NT 5.1; it; rv:1.8.1.11) Gecko/20071127 Firefox/2.0.0.11'

# Install the opener globally so every plain urllib.urlopen() call in this
# script uses the custom User-Agent (Python 2 urllib module-level hook).
urllib._urlopener = AppURLopener()

def escapeToHTML(text, escapeQuotes=False):
    """Escape *text* for safe embedding in HTML/XHTML.

    Escapes &, < and > (plus double quotes when escapeQuotes is true) and
    converts non-ASCII characters to numeric character references such as
    &#947;.

    Uses xml.sax.saxutils.escape instead of cgi.escape: cgi.escape was
    deprecated and removed in Python 3.8, while saxutils.escape performs
    the identical substitutions for these characters.
    """
    from xml.sax.saxutils import escape  # stdlib replacement for cgi.escape
    if escapeQuotes:
        htmlEscapedText = escape(text, {'"': '&quot;'})
    else:
        htmlEscapedText = escape(text)
    # Encode non-ascii characters into XHTML entities, like &#947;.
    return htmlEscapedText.encode('ascii', 'xmlcharrefreplace')

def flatten(lst):
    """Yield the leaf (non-list) elements of an arbitrarily nested list,
    in left-to-right order."""
    for element in lst:
        if not isinstance(element, list):
            yield element
        else:
            # Recurse into sub-lists and re-yield their leaves.
            for leaf in flatten(element):
                yield leaf

def depth(lst):
    """Return the nesting depth of *lst*: 0 for a non-list, otherwise
    1 + the maximum depth of its items.

    Fixes a crash on the empty list: the original called max() on an
    empty generator, raising ValueError; an empty list now reports
    depth 1 (it is a list, with no deeper structure).
    """
    if not isinstance(lst, list):
        return 0
    if not lst:  # empty list: one level of nesting, nothing inside
        return 1
    return 1 + max(depth(item) for item in lst)

def checkXML(XML, path):
    """Return the text of the first element matching *path* under *XML*,
    or "" when the element is missing or has no text.

    The original evaluated the same XPath up to three times; a single
    find() gives identical results.
    """
    node = XML.find(path)
    if node is None or node.text is None:
        return ""
    return node.text

def digest_authors(authors):
    """Format each <Author> element as "Initials LastName" and return the
    resulting list of strings (empty parts become "" via checkXML)."""
    return [checkXML(author, "./Initials") + " " + checkXML(author, "./LastName")
            for author in authors]

def digest_issue(issue):
    """Format a <JournalIssue> element as "Volume(Issue)".

    The "(Issue)" part is omitted when the <Issue> element is missing or
    empty.  Previously an <Issue> element whose text was None raised a
    TypeError during string concatenation.
    """
    issue_string = checkXML(issue, "./Volume")
    issue_number = checkXML(issue, "./Issue")  # "" when absent or empty
    if issue_number:
        issue_string += "(" + issue_number + ")"
    return issue_string

def digest_year(XML):
    """Return the publication year (as a string) for a MedlineCitation.

    Prefers the explicit <Year> element, falls back to the first four
    characters of <MedlineDate>, and defaults to "1900" when neither
    element exists.
    """
    year_node = XML.find("./Article/Journal/JournalIssue/PubDate/Year")
    if year_node is not None:
        return year_node.text
    medline_node = XML.find("./Article/Journal/JournalIssue/PubDate/MedlineDate")
    if medline_node is not None:
        return medline_node.text[:4]
    return "1900"

def family(lst):
    """Tally the citation tree by generation.

    whocitedme() returns ids nested one list level deeper per citation
    generation, so nesting depth encodes the generation here.  Returns a
    4-element list [children, grandchildren, great-grandchildren,
    great-great-or-more]; the last bucket absorbs every deeper
    descendant via flatten().
    """
    num_children = 0
    num_gchildren = 0
    num_ggchildren = 0
    num_gggchildren = 0
    for item in lst:
        if isinstance(item, list):
            for subitem in item:
                if isinstance(subitem, list):
                    for subsubitem in subitem:
                        if isinstance(subsubitem, list):
                            for subsubsubitem in subsubitem:
                                if isinstance(subsubsubitem, list):
                                    # Depth 5 or more: count every remaining
                                    # leaf into the deepest bucket.
                                    num_gggchildren += len(list(flatten(subsubsubitem)))
                                else:
                                    num_gggchildren += 1
                        else:
                            num_ggchildren +=1
                else:
                    num_gchildren += 1
                    
        else:
            num_children += 1
    return [num_children, num_gchildren, num_ggchildren, num_gggchildren]

def whocitedme(pmid_in, cache):
    """Recursively scrape PubMed's "cited in" listing for *pmid_in*.

    Returns (pmids, years): lists of citing-paper PMIDs and their
    publication years, nested one list level deeper per citation
    generation, or (None, None) when nothing cites the paper.  *cache*
    accumulates every PMID already visited so citation cycles terminate.
    """
    if len(cache) > 100000: #if this has been going on for far too long
        # NOTE(review): bare return yields None, but callers unpack two
        # values -- this safety valve would crash the caller if ever hit.
        return
    print ".",  # stream a progress dot to the browser
    sys.stdout.flush() #To prevent timeout via web.
    
    url = "http://www.ncbi.nlm.nih.gov/pubmed?linkname=pubmed_pubmed_citedin&from_uid=" + pmid_in
    data = urllib.urlopen(url).read()
    tree = ET.fromstring(data)  # relies on the page being valid XHTML
    ps = tree.findall(".//{http://www.w3.org/1999/xhtml}p") #find all <p>s

    urls = []
    pmids = [] #People who cited this paper
    years = [] #Year of the paper
    for elm in ps:
        # <p class='title'> wraps the link to each citing paper.
        if elm.attrib == {'class': 'title'}:
            url_string = elm.find('{http://www.w3.org/1999/xhtml}a').get('href')
            urls.append("http://www.ncbi.nlm.nih.gov" + url_string )
            pmids.append(url_string[url_string.rfind('/')+1:])
        # <p class='details'> holds the journal line containing the year.
        if elm.attrib == {'class': 'details'}:
            for txt in elm.itertext():
                # assumes the year sits at characters 2-6 of the details
                # text -- TODO confirm against the live page markup
                if txt[2:6].isdigit():
                    years.append(txt[2:6])    
    new_ids = []
    new_years = []
    cache.append ( pmid_in )
    if len(pmids) > 0: #if we found something new
        for pmid in pmids:
            if pmid not in cache:
                # Recurse into each citing paper.  Appending the returned
                # lists (rather than extending) builds the nested
                # generation-per-level structure consumed by family().
                super_new, super_year = whocitedme(pmid, cache)
                if super_new is not None:
                    new_ids.append( super_new )
                    new_years.append( super_year )        
        return new_ids + pmids, new_years + years
    else:
        return None, None

def search_pubmed(term):
    """Fetch PubMed records matching *term* and return the efetch XML.

    Two-step NCBI E-utilities call: esearch (with usehistory=y) stores
    the match set server-side, then efetch retrieves the full records
    via the returned WebEnv/query_key handles.
    """
    params = {
        'db': 'pubmed',
        'tool': 'test',
        'email': 'test@test.com',
        'term': term,
        'usehistory': 'y',
        'retmax': 20,
    }
    search_url = 'http://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?' + urllib.urlencode(params)
    search_tree = ET.fromstring(urllib.urlopen(search_url).read())

    # Reuse the server-side result set stored by usehistory=y.
    params['query_key'] = search_tree.find("./QueryKey").text
    params['WebEnv'] = search_tree.find("./WebEnv").text
    params['retmode'] = 'xml'

    fetch_url = 'http://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?' + urllib.urlencode(params)
    return urllib.urlopen(fetch_url).read()

def xml_to_papers(data):
    """Parse PubMed efetch XML into a list of paper dicts.

    Each dict has the keys: journal_name, title, authors (list of
    strings), issue, year, page_num, pmid, doi.

    Robustness fix: the original used bare find(...).text for several
    fields, raising AttributeError whenever a record lacked one of them;
    checkXML is now used consistently so missing elements yield "".
    """
    tree = ET.fromstring(data)
    articles = tree.findall("./PubmedArticle/MedlineCitation")

    papers = []
    for article in articles:
        paper = dict()
        paper["journal_name"] = checkXML(article, "./Article/Journal/ISOAbbreviation")
        paper["title"] = checkXML(article, "./Article/ArticleTitle")
        paper["authors"] = digest_authors(article.findall("./Article/AuthorList/Author"))
        # Guard against a record with no <JournalIssue> at all.
        journal_issue = article.find("./Article/Journal/JournalIssue")
        paper["issue"] = digest_issue(journal_issue) if journal_issue is not None else ""
        paper["year"] = digest_year(article)
        paper["page_num"] = checkXML(article, "./Article/Pagination/MedlinePgn")
        paper["pmid"] = checkXML(article, "./PMID")
        paper["doi"] = checkXML(article, "./Article/ELocationID")

        papers.append(paper)

    return papers


def citelist_to_hist(deepyearlist):
    """Build a per-year citation histogram from a nested list of year
    strings: output[0] = number of papers in 1900, output[1] = 1901, etc.

    Bug fix: the histogram was a fixed 120 bins (1900-2019), so any
    citation year of 2020 or later raised IndexError.  The bin count now
    grows with the data, with 120 kept as the minimum size so existing
    behaviour for older data is unchanged.
    """
    flatlist = list(flatten(deepyearlist))

    size = 120
    for year in flatlist:
        size = max(size, int(year) - 1900 + 1)

    output = [0] * size
    for year in flatlist:
        output[int(year) - 1900] += 1

    return output

def sum_hist(hist_list):
    """Return the cumulative sum of *hist_list*:
    output[n] = hist_list[0] + ... + hist_list[n].

    Bug fix: the original indexed output[0] unconditionally and crashed
    with IndexError on an empty input; an empty list now returns [].
    """
    output = []
    running_total = 0
    for count in hist_list:
        running_total += count
        output.append(running_total)
    return output


def first_non_zero(lst):
    """Return the index of the first entry greater than zero, or None
    (implicitly) when no entry is positive."""
    for position, value in enumerate(lst):
        if value > 0:
            return position

def last_non_zero(lst):
    """Return the index of the last entry greater than zero, or None
    (implicitly) when no entry is positive.

    Bug fix: the original scanned range(len(lst)-1, 0, -1), which stops
    before index 0 -- a list whose only positive entry was the first
    element wrongly returned None.  The scan now includes index 0.
    """
    for index in range(len(lst) - 1, -1, -1):
        if lst[index] > 0:
            return index

def offset_hist(size_by_year, last_year_index): #reduce it is so the first year with a citation is element zero
    """Return size_by_year shifted so the first non-zero year is element
    zero, running through last_year_index inclusive.

    Robustness fix: first_non_zero() returns None when every entry is
    zero, which previously made range() raise TypeError; that case now
    returns an empty list.
    """
    start = first_non_zero(size_by_year)
    if start is None:  # no citations at all -> nothing to offset
        return []

    output = []
    for ind in range(start, last_year_index + 1):
        output.append(size_by_year[ind])

    return output

# --- CGI entry point (Python 2): emit the HTTP header, then the page ---
print "Content-Type: text/html" 
print 
print "<HTML>"


# Parse the "pmid" query parameter out of the raw request URL ourselves.
url = os.environ["REQUEST_URI"]
prased_url = urlparse.urlparse(url) 
params = urlparse.parse_qs(prased_url.query)
pmid = False
# Accept only an 8-digit, all-numeric PMID; anything else leaves pmid False.
if len(params["pmid"][0]) == 8 and str(params["pmid"][0]).isdigit():
    pmid = str(params["pmid"][0])
    #Else PMID IS ILLEGALLY FORMATTED

if not (pmid==False):

  # Look up the submitted paper itself so its details can be displayed.
  searched_paper_xml = search_pubmed(pmid)
  paper = xml_to_papers(searched_paper_xml)
  paper = paper[0]
  
  # Crawl the recursive "cited by" tree (one HTTP request per paper);
  # progress dots stream into dotdiv, which is blanked once finished.
  cache = []
  print "<div id='dotdiv'>"
  citers, years = whocitedme(pmid, cache) #Here is the meat
  print "</div>"
  print "<script type='text/javascript'>"
  print "  document.getElementById('dotdiv').innerHTML = '';"
  print "</script>"
  family_tree = family(citers)
  cites_per_year = citelist_to_hist(years)
  size_by_year = sum_hist(cites_per_year)
  last_year_index = last_non_zero(cites_per_year)  #Last year paper was cited (offset by 1900)
  first_year_index = first_non_zero(cites_per_year) #First Year Paper was cited (offset by 1900)

  size_by_year = sum_hist(cites_per_year) #Total Family Tree Size
  offset_size = offset_hist(size_by_year, last_year_index)


  # Least-squares straight-line fit of log(cumulative tree size) against
  # year index -- i.e. fit an exponential growth model to the citations.
  y = [] #log transform
  for size in offset_size:
      y.append(math.log(size))

  n = len(y)
  x = range(len(y))

  sumx = 0
  sumx2 = 0
  sumxy = 0
  sumy = 0 
  sumy2 = 0


  for i in range(len(y)):
      sumx  += x[i]
      sumx2 += x[i] * x[i]
      sumxy += x[i] * y[i]
      sumy  += y[i]
      sumy2 += y[i] * y[i] 

  # NOTE(review): with a single data point (n == 1) both denominators are
  # zero and this raises ZeroDivisionError.
  slope = (n * sumxy  -  sumx * sumy) /  (n * sumx2 - sumx * sumx)
  offset = (sumy * sumx2  -  sumx * sumxy) /  (n * sumx2  -  sumx * sumx);

  print "<HEAD>"
  print "<link rel='stylesheet' id='adaption-style-css'  href='http://www.billconnelly.net/wp-content/themes/adaption/style.css?ver=4.0' type='text/css' media='all' />"
  print "<TITLE>CGI script output</TITLE>"
  #GRAPH CODE
  # Emit a Google Charts line chart: measured cumulative citations per
  # year alongside the fitted exponential curve.
  print ""
  print "    <script type='text/javascript' src='https://www.google.com/jsapi'></script>"
  print "    <script type='text/javascript'>"
  print "      google.load('visualization', '1', {packages:['corechart']});"
  print "      google.setOnLoadCallback(drawChart);"
  print "      function drawChart() {"
  print "        var data = google.visualization.arrayToDataTable(["
  print "          ['Year', 'Total Citations', 'Data Fit'],"
  for index in range(first_non_zero(size_by_year), last_non_zero(cites_per_year) +1):
      # The final row must not end with a comma (JS array literal).
      if index < last_non_zero(cites_per_year):
          print "['" + str(index + 1900) + "',  " + str(size_by_year[index]) + ", " + str(  math.exp( (index -first_non_zero(size_by_year)) * slope + offset  )) + "],"
      else:
          print "['" + str(index + 1900) + "',  " + str(size_by_year[index]) + ", " + str(  math.exp( (index - first_non_zero(size_by_year)) * slope + offset  )) +"]"
  print "        ]);"
  print " "
  print "        var options = {"
  # NOTE(review): the title is interpolated unescaped into a JS string
  # literal -- a paper title containing an apostrophe breaks the script.
  print "          title: 'Citation tree for "+  paper["title"] +"',"   
  print "          vAxis: {"
  print "            logScale: true,"
  print "            title: 'Total Size of Family Tree'"
  print "          },"
  print "          hAxis: {"
  print "            title: 'Year'"       
  print "          },"
  print "          legend: {"
  print "            position: 'in'"
  print "          },"
  print "          pointSize: 4"
  print "        };"
  print " "
  print "        var chart = new google.visualization.LineChart(document.getElementById('chart_div'));"
  print " "
  print "        chart.draw(data, options);"
  print "      }"
  print "    </script>"
  print "</HEAD>"
  print "<BODY style='background:#FFFFFF';>"
  ##PRINT SUBMITED PAPER
  print "<div class='rslt' style='margin-bottom: 20px'>"
  print "For your paper: <br>"
  print "   <p class='title' style='margin-bottom: 0px'>"
  print "   <a href='http://www.ncbi.nlm.nih.gov/pubmed/" + paper["pmid"] + "'>"+ escapeToHTML(paper["title"]) +"</a>"
  print "   </p>"
  print "   <div class='supp'>"
  # Comma-separate the author names, with a full stop after the last.
  authorlist = ""
  num_authors = len(paper["authors"])
  for a in range(num_authors):
    authorlist += escapeToHTML(paper["authors"][a])
    if a < num_authors-1:
      authorlist += ", "
    else:
      authorlist += "."
  
  print "     <p class='desc' style='margin-bottom: 0px'>" + authorlist + "</p>"
  # NOTE(review): "class'details'" below is missing the "=" in the
  # emitted HTML attribute.
  print "     <p class'details' style='margin-bottom: 0px'>"

  # Drop a trailing full stop from the journal abbreviation, if present.
  if paper["journal_name"][-1] == ".":
    j_title = paper["journal_name"][0:-1]
  else:
    j_title = paper["journal_name"]

  if len(paper["doi"]) > 0:
    print "       <span>"+ j_title +"</span>. "+paper["year"]+" "+  paper["issue"] + ":" + paper["page_num"] + " doi: <a href='http://doi.org/" +paper["doi"]+ "'>"+ paper["doi"] + " </a></p>"
  else:
    print "       <span>"+ j_title +"</span>. "+paper["year"]+" "+  paper["issue"] + ":" + paper["page_num"] + " </p>"
  print "   </div>"
  print "</div>"

  ##PRINT CITATION DETAILS
  # Build an English sentence describing each generation of the tree,
  # e.g. "5 children, 12 grandchildren and 3 great-grandchildren."
  print "<div class='post'>"
  deepest = last_non_zero(family_tree)
  famly_str = ""
  for depth in range(0, deepest+1):  # NB: loop variable shadows depth() above
      if depth == deepest-1:
          linkage = " and "
      elif depth == deepest:
          linkage = "."
      else:
          linkage = ", "
      if depth == 0:
          famly_str += str(family_tree[0]) + " children" + linkage
      elif depth == 1:
          famly_str += str(family_tree[1]) + " grandchildren" + linkage
      elif depth == 2:
          famly_str += str(family_tree[2]) + " great-grandchildren" + linkage
      elif depth == 3:
          famly_str += str(family_tree[3]) + " great-great (or more) grandchildren"         

  print "Your paper has a family tree with a total of " + str(sum(family_tree)) + " papers in it."
  print "<br>"
  print "Your family tree is made up of " + famly_str + "."
  # exp(offset) is the fitted tree size at year zero of the fit window;
  # slope is the exponential growth rate per year.
  offset_str = str(math.exp(offset))
  slope_str = str(slope)
  unbiased_str = str(math.exp(offset)/slope) 
  print "According to the citation tree, your paper's original citability was", offset_str[0:4], "and your field citability was", slope_str[0:4] + ". This means your paper had a non-biased impact of", unbiased_str[0:4]
  print "</div>"
  print "<div class='output' style='margin-bottom: 0px'>"
  print "    <div id='chart_div' style='width: 600px; height: 400px;'></div>"
  print "</div>"
else:
  print "That is an improperly formated query. "
print "</BODY>"
print "</HTML>"