from frb_ml_utils import * 
import frb_ml_utils
import numpy as np
import pandas as pd
import scipy
from matplotlib import pyplot as plt
from sklearn import svm
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from sklearn.neighbors import NearestCentroid,KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from joblib import Parallel, delayed
CHIME = load_chime()
# Feature columns handed to the classifiers.
columns_to_use = [
    'bc_width', 'flux', 'fluence', 'dm_exc_ne2001', 'peak_freq',
    'bright_temp', 'rest_width', 'freq_width', 'energy',
]

# log10-transform brightness temperature and energy.
CHIME['bright_temp'] = CHIME['bright_temp'].pipe(np.log10)
CHIME['energy'] = CHIME['energy'].pipe(np.log10)
# Rescale both width columns by 1000 (presumably s -> ms; confirm units in load_chime).
# Note: freq_width is NOT log-scaled in this cell.
CHIME['rest_width'] = CHIME['rest_width'].mul(1000)
CHIME['bc_width'] = CHIME['bc_width'].mul(1000)

# Number of bagging rounds, and a zero-filled result buffer of shape
# (rounds, 6 classifier slots, one column per FRB).
bagging_times = 1000
possible_repeaters = np.zeros(shape=(bagging_times, 6, CHIME.shape[0]))
2 78.8 0.00225301 FRB20180729A
12 101.5 0.00225301 FRB20180814A
38 101.0 0.00225301 FRB20180919A
49 94.7 0.00225301 FRB20180928A
75 101.3 0.00225301 FRB20181028A
76 101.3 0.00225301 FRB20181028A
77 101.3 0.00225301 FRB20181028A
78 101.3 0.00225301 FRB20181028A
79 101.3 0.00225301 FRB20181028A
81 62.3 0.00225301 FRB20181030A
82 62.5 0.00225301 FRB20181030B
158 83.6 0.00225301 FRB20181220A
174 92.6 0.00225301 FRB20181223C
221 96.1 0.00225301 FRB20190107B
399 100.8 0.00225301 FRB20190329A
459 79.4 0.00225301 FRB20190425A
571 100.7 0.00225301 FRB20190625E
572 100.7 0.00225301 FRB20190625E
573 100.7 0.00225301 FRB20190625E
576 101.5 0.00225301 FRB20190626A
CHIME = load_chime()
# Feature columns handed to the classifiers.
columns_to_use = [
    'bc_width', 'flux', 'fluence', 'dm_exc_ne2001', 'peak_freq',
    'bright_temp', 'rest_width', 'freq_width', 'energy',
]

# log10-transform brightness temperature and energy.
CHIME['bright_temp'] = CHIME['bright_temp'].pipe(np.log10)
CHIME['energy'] = CHIME['energy'].pipe(np.log10)
# Rescale both width columns by 1000 (presumably s -> ms; confirm units in load_chime).
CHIME['rest_width'] = CHIME['rest_width'].mul(1000)
CHIME['bc_width'] = CHIME['bc_width'].mul(1000)

# In this cell freq_width is log10-transformed as well.
CHIME['freq_width'] = CHIME['freq_width'].pipe(np.log10)

# Number of bagging rounds, and a zero-filled result buffer of shape
# (rounds, 6 classifier slots, one column per FRB).
bagging_times = 1000
possible_repeaters = np.zeros(shape=(bagging_times, 6, CHIME.shape[0]))
def find_repeater(i):
    """Run one bagging round and flag label-0 FRBs that classifiers predict as 1.

    A fresh stratified 70/30 split of ``CHIME[columns_to_use]`` is drawn.
    The 70 % training part is standardized and SMOTE-balanced, and each
    classifier below is trained on it. Every classifier then predicts on the
    whole (standardized) catalogue. The 30 % held-out part is not used.

    Label 1 means ``CHIME['repeater_name'] != '-9999'`` (presumably a known
    repeater -- confirm against load_chime).

    Parameters
    ----------
    i : int
        Bagging round index. Unused; it exists so the function can be mapped
        over ``range(bagging_times)`` with joblib.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(6, len(CHIME))``. Row ``k`` belongs to the
        k-th classifier (SVC, NearestCentroid, RandomForest, AdaBoost,
        LightGBM, XGBoost, in that order). An entry is 1.0 where that
        classifier predicts 1 for an FRB labelled 0, else 0.0.
    """
    chime_data = CHIME[columns_to_use]
    chime_target = (CHIME['repeater_name'] != '-9999').to_numpy().astype('int')
    # Random, stratified split each round; only the training part is kept.
    X, _, y, _ = train_test_split(chime_data, chime_target,
                                  test_size=0.3, stratify=chime_target)

    # Scale using training-set statistics only, then apply that scaling to
    # the full catalogue that will be scored.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    chime_data = scaler.transform(chime_data)

    # Oversample the minority class of the training set.
    X, y = SMOTE().fit_resample(X, y)

    # Order matters: it fixes which row of the result belongs to which model.
    classifiers = (
        svm.SVC(),
        NearestCentroid(),
        RandomForestClassifier(),
        AdaBoostClassifier(),
        LGBMClassifier(),
        XGBClassifier(use_label_encoder=False, eval_metric='logloss'),
    )

    temp_repeaters = np.zeros((len(classifiers), len(CHIME)))
    for row, clf in enumerate(classifiers):
        clf.fit(X, y)
        predictions = clf.predict(chime_data)
        temp_repeaters[row] = np.logical_and(predictions == 1, chime_target == 0)
    return temp_repeaters

# Run every bagging round in parallel on 6 workers. find_repeater returns a
# (6, len(CHIME)) array per round, so the stacked result has shape
# (bagging_times, 6, len(CHIME)).
possible_repeaters = np.array(
    Parallel(n_jobs=6, verbose=10)(
        delayed(find_repeater)(run) for run in range(bagging_times)
    )
)
2 78.8 0.00225301 FRB20180729A
12 101.5 0.00225301 FRB20180814A
38 101.0 0.00225301 FRB20180919A
49 94.7 0.00225301 FRB20180928A
75 101.3 0.00225301 FRB20181028A
76 101.3 0.00225301 FRB20181028A
77 101.3 0.00225301 FRB20181028A
78 101.3 0.00225301 FRB20181028A
79 101.3 0.00225301 FRB20181028A
81 62.3 0.00225301 FRB20181030A
82 62.5 0.00225301 FRB20181030B
158 83.6 0.00225301 FRB20181220A
174 92.6 0.00225301 FRB20181223C
221 96.1 0.00225301 FRB20190107B
399 100.8 0.00225301 FRB20190329A
459 79.4 0.00225301 FRB20190425A
571 100.7 0.00225301 FRB20190625E
572 100.7 0.00225301 FRB20190625E
573 100.7 0.00225301 FRB20190625E
576 101.5 0.00225301 FRB20190626A
[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   1 tasks      | elapsed:    3.6s
[Parallel(n_jobs=6)]: Done   6 tasks      | elapsed:    3.8s
[Parallel(n_jobs=6)]: Done  13 tasks      | elapsed:    4.7s
[Parallel(n_jobs=6)]: Done  20 tasks      | elapsed:    5.2s
[Parallel(n_jobs=6)]: Done  29 tasks      | elapsed:    5.9s
[Parallel(n_jobs=6)]: Done  38 tasks      | elapsed:    6.9s
[Parallel(n_jobs=6)]: Done  49 tasks      | elapsed:    8.0s
[Parallel(n_jobs=6)]: Done  60 tasks      | elapsed:    8.8s
[Parallel(n_jobs=6)]: Done  73 tasks      | elapsed:   10.0s
[Parallel(n_jobs=6)]: Done  86 tasks      | elapsed:   10.9s
[Parallel(n_jobs=6)]: Done 101 tasks      | elapsed:   12.1s
[Parallel(n_jobs=6)]: Done 116 tasks      | elapsed:   13.2s
[Parallel(n_jobs=6)]: Done 133 tasks      | elapsed:   14.8s
[Parallel(n_jobs=6)]: Done 150 tasks      | elapsed:   16.5s
[Parallel(n_jobs=6)]: Done 169 tasks      | elapsed:   18.2s
[Parallel(n_jobs=6)]: Done 188 tasks      | elapsed:   20.0s
[Parallel(n_jobs=6)]: Done 209 tasks      | elapsed:   22.2s
[Parallel(n_jobs=6)]: Done 230 tasks      | elapsed:   24.2s
[Parallel(n_jobs=6)]: Done 253 tasks      | elapsed:   26.4s
[Parallel(n_jobs=6)]: Done 276 tasks      | elapsed:   28.7s
[Parallel(n_jobs=6)]: Done 301 tasks      | elapsed:   31.1s
[Parallel(n_jobs=6)]: Done 326 tasks      | elapsed:   33.6s
[Parallel(n_jobs=6)]: Done 353 tasks      | elapsed:   36.3s
[Parallel(n_jobs=6)]: Done 380 tasks      | elapsed:   38.9s
[Parallel(n_jobs=6)]: Done 409 tasks      | elapsed:   41.8s
[Parallel(n_jobs=6)]: Done 438 tasks      | elapsed:   44.9s
[Parallel(n_jobs=6)]: Done 469 tasks      | elapsed:   48.2s
[Parallel(n_jobs=6)]: Done 500 tasks      | elapsed:   51.6s
[Parallel(n_jobs=6)]: Done 533 tasks      | elapsed:   54.9s
[Parallel(n_jobs=6)]: Done 566 tasks      | elapsed:   58.6s
[Parallel(n_jobs=6)]: Done 601 tasks      | elapsed:  1.0min
[Parallel(n_jobs=6)]: Done 636 tasks      | elapsed:  1.1min
[Parallel(n_jobs=6)]: Done 673 tasks      | elapsed:  1.2min
[Parallel(n_jobs=6)]: Done 710 tasks      | elapsed:  1.2min
[Parallel(n_jobs=6)]: Done 749 tasks      | elapsed:  1.3min
[Parallel(n_jobs=6)]: Done 788 tasks      | elapsed:  1.4min
[Parallel(n_jobs=6)]: Done 829 tasks      | elapsed:  1.5min
[Parallel(n_jobs=6)]: Done 870 tasks      | elapsed:  1.6min
[Parallel(n_jobs=6)]: Done 913 tasks      | elapsed:  1.6min
[Parallel(n_jobs=6)]: Done 956 tasks      | elapsed:  1.7min
[Parallel(n_jobs=6)]: Done 1000 out of 1000 | elapsed:  1.8min finished
# An FRB earns one vote in a bagging round when at least
# `repeater_threshold` classifiers flagged it in that round.
repeater_threshold = 4
# Vectorized replacement for the per-FRB / per-round Python double loop.
# possible_repeaters is (rounds, classifiers, FRBs): sum over classifiers,
# threshold, then count votes over rounds. Kept as float like the original
# np.zeros accumulator.
candidate_list = np.count_nonzero(
    possible_repeaters[:bagging_times].sum(axis=1) >= repeater_threshold,
    axis=0,
).astype(float)

# Candidates are FRBs voted for in at least `min_votes` rounds.
min_votes = 100
greater_than_100 = np.sum(candidate_list >= min_votes)  # count of candidates (>= min_votes)
# flatnonzero already yields ascending positions, which is exactly what the
# former argsort / reverse-slice / sort sequence produced.
candi_list = np.flatnonzero(candidate_list >= min_votes)
for name in CHIME['tns_name'].iloc[candi_list]:
    print(name)
FRB20181017B
FRB20181030E
FRB20181128C
FRB20181218C
FRB20181221A
FRB20181229B
FRB20181231B
FRB20190106A
FRB20190109B
FRB20190112A
FRB20190125A
FRB20190125B
FRB20190128C
FRB20190129A
FRB20190206B
FRB20190206A
FRB20190218B
FRB20190329A
FRB20190409B
FRB20190410A
FRB20190412B
FRB20190422A
FRB20190423B
FRB20190423B
FRB20190429B
FRB20190527A
FRB20190609A
# leave one out knn
# For each FRB: train a KNN on every other FRB (standardized with the
# training rows' statistics) and record its position if it is labelled 0
# but predicted 1.
loo_candidate = []
CHIME = load_chime()
columns_to_use = ['bc_width','flux','fluence','dm_exc_ne2001',
                  'peak_freq',
                  'bright_temp','rest_width','freq_width','energy']
CHIME['bright_temp'] = np.log10(CHIME['bright_temp'])
CHIME['energy'] = np.log10(CHIME['energy'])
CHIME['rest_width'] = CHIME['rest_width'] * 1000
CHIME['bc_width'] = CHIME['bc_width'] * 1000
CHIME['freq_width'] = np.log10(CHIME['freq_width'])

# Features and labels are identical for every fold, so build them once.
chime_data = CHIME[columns_to_use]
chime_target = (CHIME['repeater_name'] != '-9999').to_numpy().astype('int')

for i in range(len(CHIME)):
    # Hold out row i *by position* for features and labels alike.
    # (DataFrame.drop(i) removes by index label, which only coincides with
    # position i when CHIME has a default RangeIndex.)
    mask = np.ones(len(chime_target), dtype=bool)
    mask[i] = False
    X = chime_data.iloc[mask]
    y = chime_target[mask]
    test_X = chime_data.iloc[[i]]
    test_y = chime_target[i]

    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    test_X = scaler.transform(test_X)
    clf = KNeighborsClassifier()
    clf.fit(X, y)
    prediction = clf.predict(test_X)[0]
    if test_y == 0 and prediction == 1:
        print(i)
        loo_candidate.append(i)
# Integer dtype so the array stays a valid positional indexer even when empty.
loo_candidate = np.array(loo_candidate, dtype=int)
2 78.8 0.00225301 FRB20180729A
12 101.5 0.00225301 FRB20180814A
38 101.0 0.00225301 FRB20180919A
49 94.7 0.00225301 FRB20180928A
75 101.3 0.00225301 FRB20181028A
76 101.3 0.00225301 FRB20181028A
77 101.3 0.00225301 FRB20181028A
78 101.3 0.00225301 FRB20181028A
79 101.3 0.00225301 FRB20181028A
81 62.3 0.00225301 FRB20181030A
82 62.5 0.00225301 FRB20181030B
158 83.6 0.00225301 FRB20181220A
174 92.6 0.00225301 FRB20181223C
221 96.1 0.00225301 FRB20190107B
399 100.8 0.00225301 FRB20190329A
459 79.4 0.00225301 FRB20190425A
571 100.7 0.00225301 FRB20190625E
572 100.7 0.00225301 FRB20190625E
573 100.7 0.00225301 FRB20190625E
576 101.5 0.00225301 FRB20190626A
5
47
49
60
101
124
203
220
224
262
265
270
287
292
399
428
454
455
465
553
# Print the TNS name of every leave-one-out candidate (positional rows).
# Series.iteritems() no longer exists in pandas 2.0, and the index was unused,
# so iterate the Series values directly.
for name in CHIME['tns_name'].iloc[np.asarray(loo_candidate, dtype=int)]:
    print(name)
FRB20180801A
FRB20180925A
FRB20180928A
FRB20181017B
FRB20181119B
FRB20181128C
FRB20181231B
FRB20190107A
FRB20190109B
FRB20190124E
FRB20190125B
FRB20190128C
FRB20190204A
FRB20190206A
FRB20190329A
FRB20190412B
FRB20190423B
FRB20190423B
FRB20190429B
FRB20190617B