Coverage for tests/publisher/endpoint_testing.py: 92%

156 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-28 22:05 +0000

1import requests 

2 

3import pymacaroons 

4import responses 

5from flask_testing import TestCase 

6from webapp.app import create_app 

7from webapp.authentication import get_authorization_header 

8 

9# Make sure tests fail on stray responses. 

10responses.mock.assert_all_requests_are_fired = True 

11 

12 

class BaseTestCases:
    """
    Namespace for reusable endpoint test cases.

    The nested classes are meant to be inherited by the tests of endpoints
    that require authentication. They are wrapped in this plain class so
    that unittest discovery does not collect and run the base classes
    themselves.
    """

    class BaseAppTesting(TestCase):
        """Common fixture: app creation, fake login and API call checks."""

        def setUp(self, snap_name, api_url, endpoint_url):
            """Store the snap name, mocked API URL and endpoint under test."""
            self.snap_name = snap_name
            self.api_url = api_url
            self.endpoint_url = endpoint_url

        def tearDown(self):
            """Drop every response registered on the default mock."""
            responses.reset()

        def create_app(self):
            """Build the Flask app used by the test client."""
            app = create_app(testing=True)
            app.secret_key = "secret_key"
            # An empty method list means no request method is CSRF-checked.
            app.config["WTF_CSRF_METHODS"] = []

            return app

        def _get_location(self):
            """Return the location expected after a macaroon refresh."""
            return str(self.endpoint_url)

        def _log_in(self, client):
            """Emulates test client login in the store.

            Fill current session with `publisher`, `macaroon_root` and
            `macaroon_discharge`.

            Return the expected `Authorization` header for further
            verification in API requests.
            """
            # Basic root/discharge macaroons pair.
            root = pymacaroons.Macaroon("test", "testing", "a_key")
            root.add_third_party_caveat("3rd", "a_caveat-key", "a_ident")
            discharge = pymacaroons.Macaroon("3rd", "a_ident", "a_caveat_key")

            with client.session_transaction() as s:
                s["publisher"] = {
                    "image": None,
                    "nickname": "Toto",
                    "fullname": "El Toto",
                    "email": "testing@testing.com",
                    "stores": [],
                }
                s["macaroon_root"] = root.serialize()
                s["macaroon_discharge"] = discharge.serialize()

            return get_authorization_header(
                root.serialize(), discharge.serialize()
            )

        def check_call_by_api_url(self, calls):
            """Assert `self.api_url` was requested with the login header.

            Every recorded call whose URL equals `self.api_url` must carry
            `self.authorization` as its `Authorization` header, and at
            least one such call must exist.
            """
            found = False
            for called in calls:
                if self.api_url == called.request.url:
                    found = True
                    self.assertEqual(
                        self.authorization,
                        called.request.headers.get("Authorization"),
                    )

            self.assertTrue(
                found, "no request was made to {}".format(self.api_url)
            )

    class EndpointLoggedOut(BaseAppTesting):
        """Checks that an anonymous request is redirected to the login."""

        def setUp(self, snap_name, endpoint_url, method_endpoint="GET"):
            """Record the HTTP method; no API URL is needed when logged out."""
            self.method_endpoint = method_endpoint
            super().setUp(snap_name, None, endpoint_url)

        def test_access_not_logged_in(self):
            if self.method_endpoint == "GET":
                response = self.client.get(self.endpoint_url)
            else:
                response = self.client.post(self.endpoint_url, data={})

            self.assertEqual(302, response.status_code)
            self.assertEqual(
                "/login?next={}".format(self.endpoint_url),
                response.location,
            )

    class EndpointLoggedIn(BaseAppTesting):
        """Checks how a logged-in endpoint reacts to API failures."""

        def setUp(
            self,
            snap_name,
            endpoint_url,
            api_url,
            method_endpoint="GET",
            method_api="GET",
            data=None,
            json=None,
        ):
            """Store the request description and log the test client in.

            `method_endpoint` is the method used against the endpoint and
            `method_api` the one the mocked API is registered with. `data`
            (form payload) or `json` is sent when the endpoint is not
            requested with GET.
            """
            super().setUp(
                snap_name=snap_name, api_url=api_url, endpoint_url=endpoint_url
            )

            self.method_endpoint = method_endpoint
            self.method_api = method_api
            self.data = data
            self.json = json
            self.authorization = self._log_in(self.client)

        def _mock_api_response(self, **response_kwargs):
            """Register a mocked response for `self.api_url`.

            `response_kwargs` are forwarded to `responses.Response`
            (e.g. `status`, `json`, `body`, `headers`).
            """
            responses.add(
                responses.Response(
                    method=self.method_api,
                    url=self.api_url,
                    **response_kwargs,
                )
            )

        def _request_endpoint(self):
            """Call the endpoint under test and return the response.

            GET endpoints are requested without a payload. Any other
            method is sent as a POST, with the form `data` when it is
            truthy and the `json` payload otherwise.
            """
            if self.method_endpoint == "GET":
                return self.client.get(self.endpoint_url)
            if self.data:
                return self.client.post(self.endpoint_url, data=self.data)
            return self.client.post(self.endpoint_url, json=self.json)

        @responses.activate
        def test_timeout(self):
            self._mock_api_response(
                body=requests.exceptions.Timeout(), status=504
            )

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(504, response.status_code)

        @responses.activate
        def test_connection_error(self):
            self._mock_api_response(
                body=requests.exceptions.ConnectionError(), status=500
            )

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_broken_json(self):
            # To test this I return no json from the server, this makes the
            # call to the function response.json() raise a ValueError
            # exception
            self._mock_api_response(status=500)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_unknown_error(self):
            self._mock_api_response(json={}, status=500)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_expired_macaroon(self):
            self._mock_api_response(
                json={},
                status=401,
                headers={"WWW-Authenticate": "Macaroon needs_refresh=1"},
            )
            responses.add(
                responses.POST,
                "https://login.ubuntu.com/api/v2/tokens/refresh",
                json={"discharge_macaroon": "macaroon"},
                status=200,
            )

            response = self._request_endpoint()

            # The last recorded call must be the macaroon refresh.
            called = responses.calls[-1]
            self.assertEqual(
                "https://login.ubuntu.com/api/v2/tokens/refresh",
                called.request.url,
            )

            self.assertEqual(302, response.status_code)
            self.assertEqual(self._get_location(), response.location)

    class EndpointLoggedInErrorHandling(EndpointLoggedIn):
        """Adds checks for error payloads returned by the API."""

        @responses.activate
        def test_error_4xx(self):
            payload = {"error_list": []}
            self._mock_api_response(json=payload, status=400)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_custom_error(self):
            payload = {
                "error_list": [
                    {"code": "error-code1"},
                    {"code": "error-code2"},
                ]
            }
            self._mock_api_response(json=payload, status=400)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(502, response.status_code)

        @responses.activate
        def test_account_not_signed_agreement_logged_in(self):
            payload = {
                "error_list": [
                    {
                        "code": "user-not-ready",
                        "message": "has not signed agreement",
                    }
                ]
            }
            self._mock_api_response(json=payload, status=403)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(302, response.status_code)
            self.assertEqual("/account/agreement", response.location)

        @responses.activate
        def test_account_no_username_logged_in(self):
            payload = {
                "error_list": [
                    {
                        "code": "user-not-ready",
                        "message": "missing store username",
                    }
                ]
            }
            self._mock_api_response(json=payload, status=403)

            response = self._request_endpoint()

            self.check_call_by_api_url(responses.calls)

            self.assertEqual(302, response.status_code)
            self.assertEqual("/account/username", response.location)

397 self.assertEqual("/account/username", response.location)