Coverage for gwcelery/tasks/gwskynet.py: 97%

78 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 18:24 +0000

1"""GWSkyNet annotation with GWSkyNet model""" 

2import json 

3import re 

4from functools import cache 

5 

6import numpy as np 

7 

8from .. import app 

9from ..util.tempfile import NamedTemporaryFile 

10from . import gracedb, igwn_alert, superevents 

11 

# Regular expression matching the log message written to a superevent when its
# preferred event is changed manually. The configured message is a
# str.format-style template whose '{}' placeholders stand for variable text:
# the whole template is escaped so any regex metacharacters in it are matched
# literally, and each escaped placeholder is then replaced by '.+'.
manual_pref_event_change_regexp = re.compile(
    re.escape(app.conf['views_manual_preferred_event_log_message'])
    .replace(re.escape('{}'), '.+')
)

16 

17 

@cache
def GWSkyNet_model():
    """Return the GWSkyNet model, loading it on first use.

    The function is wrapped in :func:`functools.cache`, so after a successful
    first call the same model object is returned without reloading it.
    """
    # FIXME Remove import from function scope once importing GWSkyNet is not a
    # slow operation
    from GWSkyNet import GWSkyNet as gwskynet_module

    model = gwskynet_module.load_GWSkyNet_model()
    return model

25 

26 

@app.task(queue='skynet', shared=False)
def gwskynet_annotation(input_list, SNRs, superevent_id):
    """Run GWSkyNet on a BAYESTAR sky map and return the annotation as JSON.

    Parameters
    ----------
    input_list : list
        The output of :func:`_download_and_return_file_name`, in the form
        ``[filecontents, skymap_filename]``: the downloaded sky map contents
        and the versioned file name of that sky map.
    SNRs : numpy array of floats
        Detector SNRs ordered as H1, L1, V1 (as returned by
        :func:`get_cbc_event_snr`).
    superevent_id : str
        Superevent ID.

    Returns
    -------
    str
        JSON-encoded dict with keys ``superevent_id``, ``file``,
        ``GWSkyNet_score``, ``GWSkyNet_FAP`` and ``GWSkyNet_FNP``.
    """
    # FIXME Remove import from function scope once importing GWSkyNet is not a
    # slow operation
    from GWSkyNet import GWSkyNet

    filecontents, skymap_filename = input_list
    with NamedTemporaryFile(content=filecontents) as fitsfile:
        GWSkyNet_input = GWSkyNet.prepare_data(fitsfile.name)
    # One of the inputs from BAYESTAR to GWSkyNet is the list of instruments,
    # i.e., metadata['instruments'], which is converted to a binary array with
    # three elements, i.e. GWSkyNet_input[2], for H1, L1 and V1.
    # GWSkyNet 2.4.0 uses this array to flag detectors whose SNR reaches
    # app.conf['gwskynet_snr_threshold'], so it is overwritten here.
    # np.asarray keeps the comparison working if SNRs arrives as a plain list.
    GWSkyNet_input[2][0] = np.where(
        np.asarray(SNRs) >= app.conf['gwskynet_snr_threshold'], 1, 0)
    gwskynet_score = GWSkyNet.predict(GWSkyNet_model(), GWSkyNet_input)
    FAP, FNP = GWSkyNet.get_rates(gwskynet_score)
    # float() converts NumPy scalars to Python floats; json.dumps rejects
    # NumPy scalar types other than float64 (e.g. float32).
    gwskynet_output = {'superevent_id': superevent_id,
                       'file': skymap_filename,
                       'GWSkyNet_score': float(gwskynet_score[0]),
                       'GWSkyNet_FAP': float(FAP[0]),
                       'GWSkyNet_FNP': float(FNP[0])}
    return json.dumps(gwskynet_output)

68 

69 

def get_cbc_event_snr(event):
    """Extract per-detector SNRs from a CBC event dictionary.

    Parameters
    ----------
    event : dict
        Event dictionary (e.g., the return value from
        :meth:`gwcelery.tasks.gracedb.get_event`, or
        ``preferred_event_data`` in an igwn-alert packet). Its
        ``event['extra_attributes']['SingleInspiral']`` entries are read.

    Returns
    -------
    snr : numpy array of floats
        Three SNRs in the order H1, L1, V1. A detector with no
        ``SingleInspiral`` entry keeps an SNR of 0; entries for any other
        detector are ignored.

    """
    # GWSkyNet 2.4.0 uses this SNR array to modify one of its inputs, so the
    # slots must be ordered H1, L1, V1.
    ifo_order = ('H1', 'L1', 'V1')
    snr = np.zeros(len(ifo_order))
    for single in event['extra_attributes']['SingleInspiral']:
        if single['ifo'] in ifo_order:
            snr[ifo_order.index(single['ifo'])] = single['snr']
    return snr

99 

100 

@gracedb.task(shared=False)
def _download_and_return_file_name(filename, graceid):
    """Download *filename* from *graceid* and pair the contents with the name.

    Returns
    -------
    list
        ``[filecontents, filename]``, the form that
        :func:`gwskynet_annotation` expects as its ``input_list``.
    """
    return [gracedb.download(filename, graceid), filename]

106 

107 

@gracedb.task(shared=False)
def _unpack_gwskynet_annotation_and_upload(gwskynet_output, graceid):
    """Upload a GWSkyNet annotation to GraceDB as ``gwskynet.json``.

    Parameters
    ----------
    gwskynet_output : str
        JSON string produced by :func:`gwskynet_annotation`.
    graceid : str
        GraceDB ID to upload to.

    Returns
    -------
    The return value of :func:`gracedb.upload`.

    The accompanying log message links to the annotated sky map and reports
    the score, FAP and FNP rounded to three decimals.
    """
    annotation = json.loads(gwskynet_output)
    skymap_filename = annotation['file']
    score, fap, fnp = (
        np.round(annotation[key], 3)
        for key in ('GWSkyNet_score', 'GWSkyNet_FAP', 'GWSkyNet_FNP')
    )
    message = (
        f'GWSkyNet annotation from <a href="/api/events/{graceid}/files/'
        f'{skymap_filename}">{skymap_filename}</a>.'
        f' GWSkyNet score: {score},'
        f' GWSkyNet FAP: {fap},'
        f' GWSkyNet FNP: {fnp}.'
    )
    return gracedb.upload(gwskynet_output, 'gwskynet.json', graceid,
                          message=message, tags=['em_follow', 'public'])

129 

130 

def _should_annotate(preferred_event, new_label, new_log_comment, labels,
                     alert_type):
    """Decide whether GWSkyNet should annotate in response to an IGWN alert.

    Parameters
    ----------
    preferred_event : dict
        Preferred event data of the superevent (``preferred_event_data`` of
        the igwn-alert packet).
    new_label : str
        Name of the label that was just added.
    new_log_comment : str
        Comment of the log message that was just added.
    labels : list of str
        Labels currently applied to the superevent.
    alert_type : str
        igwn-alert type; ``'label_added'`` is handled separately, any other
        value is treated as a log alert.

    Returns
    -------
    bool
        True if a GWSkyNet annotation should be produced.
    """
    prelim_labels = ('GCN_PRELIM_SENT', 'LOW_SIGNIF_PRELIM_SENT')

    # First check if the event passes all of GWSkyNet's annotation criteria:
    # AllSky search, FAR below the upper threshold, at least two detectors
    # above the single-detector SNR threshold, and a network SNR above the
    # network threshold.
    snrs = get_cbc_event_snr(preferred_event)
    passes_criteria = (
        preferred_event['search'].lower() == 'allsky'
        and preferred_event['far'] <= app.conf['gwskynet_upper_far_threshold']
        and (snrs >= app.conf['gwskynet_snr_threshold']).sum() >= 2
        and np.sqrt(np.sum(snrs**2))
        >= app.conf['gwskynet_network_snr_threshold']
    )
    if not passes_criteria:
        return False

    if alert_type == 'label_added':
        # If the superevent's FAR is higher than the preliminary alert
        # threshold, GWSkyNet annotates the superevent directly once the sky
        # map is ready. (`not` rather than `is False`: the check must not
        # depend on the exact type of the falsy return value.)
        if (not superevents.should_publish(preferred_event, significant=False)
                and new_label == 'SKYMAP_READY'):
            return True
        # If the FAR is lower than the preliminary alert threshold, GWSkyNet
        # annotates the superevent once the preliminary alert has been sent.
        return new_label in prelim_labels

    # Log alerts: GWSkyNet annotations are not applied until after the initial
    # preliminary alert has been sent when the FAR passes the alert threshold.
    if not any(label in labels for label in prelim_labels):
        return False
    # GWSkyNet also annotates the superevent if the sky map has been changed
    # (i.e. a sky map from a new g-event has been copied).
    if new_log_comment.startswith('Localization copied from '):
        return True
    # A manual change of the preferred event writes a different log comment.
    return bool(manual_pref_event_change_regexp.match(new_log_comment))

172 

173 

@igwn_alert.handler('superevent',
                    shared=False)
def handle_cbc_superevent(alert):
    """Annotate the CBC preferred events of superevents using GWSkyNet.

    Only ``label_added`` and ``log`` alerts for superevents whose preferred
    event belongs to the CBC group are considered. When
    :func:`_should_annotate` approves, a Celery chain is launched that:

    1. fetches the latest ``bayestar.multiorder.fits`` for the superevent,
    2. downloads it,
    3. runs GWSkyNet on it, and
    4. uploads the resulting annotation to GraceDB.
    """
    if alert['object']['preferred_event_data']['group'] != 'CBC':
        return

    alert_type = alert['alert_type']
    if alert_type not in ('label_added', 'log'):
        return

    superevent_id = alert['uid']
    preferred_event = alert['object']['preferred_event_data']
    new_label = alert['data'].get('name', '')
    new_log_comment = alert['data'].get('comment', '')
    labels = alert['object'].get('labels', [])

    if not _should_annotate(preferred_event, new_label, new_log_comment,
                            labels, alert_type):
        return

    # Only needed once the annotation is actually launched.
    SNRs = get_cbc_event_snr(preferred_event)
    (
        gracedb.get_latest_file.s(superevent_id, 'bayestar.multiorder.fits')
        |
        _download_and_return_file_name.s(superevent_id)
        |
        gwskynet_annotation.s(SNRs, superevent_id)
        |
        _unpack_gwskynet_annotation_and_upload.s(superevent_id)
    ).apply_async()