Coverage for tests/publisher/endpoint_testing.py: 92%

160 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 22:43 +0000

1from unittest.mock import patch 

2 

3import requests 

4 

5import pymacaroons 

6import responses 

7from flask_testing import TestCase 

8from webapp.app import create_app 

9from webapp.authentication import get_authorization_header 

10 

11# Make sure tests fail on stray responses. 

12responses.mock.assert_all_requests_are_fired = True 

13 

14 

15class BaseTestCases: 

16 """ 

17 This class has a set of test classes that should be inherited by endpoint 

18 that have authentication. 

19 

20 It is also used to avoid unittest to run this tests file. 

21 """ 

22 

23 class BaseAppTesting(TestCase): 

24 def setUp(self, snap_name, api_url, endpoint_url): 

25 self.snap_name = snap_name 

26 self.api_url = api_url 

27 self.endpoint_url = endpoint_url 

28 

29 def tearDown(self): 

30 responses.reset() 

31 

32 def create_app(self): 

33 app = create_app(testing=True) 

34 app.secret_key = "secret_key" 

35 app.config["WTF_CSRF_METHODS"] = [] 

36 

37 return app 

38 

39 def _get_location(self): 

40 return "{}".format(self.endpoint_url) 

41 

42 def _log_in(self, client): 

43 """Emulates test client login in the store. 

44 

45 Fill current session with `openid`, `macaroon_root` and 

46 `macaroon_discharge`. 

47 

48 Return the expected `Authorization` header for further verification 

49 in API requests. 

50 """ 

51 # Basic root/discharge macaroons pair. 

52 root = pymacaroons.Macaroon("test", "testing", "a_key") 

53 root.add_third_party_caveat("3rd", "a_caveat-key", "a_ident") 

54 discharge = pymacaroons.Macaroon("3rd", "a_ident", "a_caveat_key") 

55 

56 with client.session_transaction() as s: 

57 s["publisher"] = { 

58 "image": None, 

59 "nickname": "Toto", 

60 "fullname": "El Toto", 

61 "email": "testing@testing.com", 

62 "stores": [], 

63 } 

64 s["macaroon_root"] = root.serialize() 

65 s["macaroon_discharge"] = discharge.serialize() 

66 

67 return get_authorization_header( 

68 root.serialize(), discharge.serialize() 

69 ) 

70 

71 def check_call_by_api_url(self, calls): 

72 found = False 

73 for called in calls: 

74 if self.api_url == called.request.url: 

75 found = True 

76 self.assertEqual( 

77 self.authorization, 

78 called.request.headers.get("Authorization"), 

79 ) 

80 

81 assert found 

82 

83 class EndpointLoggedOut(BaseAppTesting): 

84 def setUp(self, snap_name, endpoint_url, method_endpoint="GET"): 

85 self.method_endpoint = method_endpoint 

86 super().setUp(snap_name, None, endpoint_url) 

87 

88 def test_access_not_logged_in(self): 

89 if self.method_endpoint == "GET": 

90 response = self.client.get(self.endpoint_url) 

91 else: 

92 response = self.client.post(self.endpoint_url, data={}) 

93 

94 self.assertEqual(302, response.status_code) 

95 self.assertEqual( 

96 "/login?next={}".format(self.endpoint_url), 

97 response.location, 

98 ) 

99 

100 class EndpointLoggedIn(BaseAppTesting): 

101 def setUp( 

102 self, 

103 snap_name, 

104 endpoint_url, 

105 api_url, 

106 method_endpoint="GET", 

107 method_api="GET", 

108 data=None, 

109 json=None, 

110 ): 

111 super().setUp( 

112 snap_name=snap_name, api_url=api_url, endpoint_url=endpoint_url 

113 ) 

114 

115 # Stub the blueprint-level "has releases" gate so each test does 

116 # not need to mock the extra dashboard call. Tests that exercise 

117 # the gate directly can override this in their own setUp. 

118 self._release_history_patcher = patch( 

119 "webapp.decorators._dashboard.snap_release_history", 

120 return_value={"revisions": [{"revision": 1}]}, 

121 ) 

122 self._release_history_patcher.start() 

123 self.addCleanup(self._release_history_patcher.stop) 

124 

125 self.method_endpoint = method_endpoint 

126 self.method_api = method_api 

127 self.data = data 

128 self.json = json 

129 self.authorization = self._log_in(self.client) 

130 

131 @responses.activate 

132 def test_timeout(self): 

133 responses.add( 

134 responses.Response( 

135 method=self.method_api, 

136 url=self.api_url, 

137 body=requests.exceptions.Timeout(), 

138 status=504, 

139 ) 

140 ) 

141 

142 if self.method_endpoint == "GET": 

143 response = self.client.get(self.endpoint_url) 

144 else: 

145 if self.data: 

146 response = self.client.post( 

147 self.endpoint_url, data=self.data 

148 ) 

149 else: 

150 response = self.client.post( 

151 self.endpoint_url, json=self.json 

152 ) 

153 

154 self.check_call_by_api_url(responses.calls) 

155 

156 assert response.status_code == 504 

157 

158 @responses.activate 

159 def test_connection_error(self): 

160 responses.add( 

161 responses.Response( 

162 method=self.method_api, 

163 url=self.api_url, 

164 body=requests.exceptions.ConnectionError(), 

165 status=500, 

166 ) 

167 ) 

168 

169 if self.method_endpoint == "GET": 

170 response = self.client.get(self.endpoint_url) 

171 else: 

172 if self.data: 

173 response = self.client.post( 

174 self.endpoint_url, data=self.data 

175 ) 

176 else: 

177 response = self.client.post( 

178 self.endpoint_url, json=self.json 

179 ) 

180 

181 self.check_call_by_api_url(responses.calls) 

182 

183 assert response.status_code == 502 

184 

185 @responses.activate 

186 def test_broken_json(self): 

187 # To test this I return no json from the server, this makes the 

188 # call to the function response.json() raise a ValueError exception 

189 responses.add( 

190 responses.Response( 

191 method=self.method_api, url=self.api_url, status=500 

192 ) 

193 ) 

194 

195 if self.method_endpoint == "GET": 

196 response = self.client.get(self.endpoint_url) 

197 else: 

198 if self.data: 

199 response = self.client.post( 

200 self.endpoint_url, data=self.data 

201 ) 

202 else: 

203 response = self.client.post( 

204 self.endpoint_url, json=self.json 

205 ) 

206 

207 self.check_call_by_api_url(responses.calls) 

208 

209 assert response.status_code == 502 

210 

211 @responses.activate 

212 def test_unknown_error(self): 

213 responses.add( 

214 responses.Response( 

215 method=self.method_api, 

216 url=self.api_url, 

217 json={}, 

218 status=500, 

219 ) 

220 ) 

221 

222 if self.method_endpoint == "GET": 

223 response = self.client.get(self.endpoint_url) 

224 else: 

225 if self.data: 

226 response = self.client.post( 

227 self.endpoint_url, data=self.data 

228 ) 

229 else: 

230 response = self.client.post( 

231 self.endpoint_url, json=self.json 

232 ) 

233 

234 self.check_call_by_api_url(responses.calls) 

235 

236 assert response.status_code == 502 

237 

238 @responses.activate 

239 def test_expired_macaroon(self): 

240 responses.add( 

241 responses.Response( 

242 method=self.method_api, 

243 url=self.api_url, 

244 json={}, 

245 status=401, 

246 headers={"WWW-Authenticate": "Macaroon needs_refresh=1"}, 

247 ) 

248 ) 

249 responses.add( 

250 responses.POST, 

251 "https://login.ubuntu.com/api/v2/tokens/refresh", 

252 json={"discharge_macaroon": "macaroon"}, 

253 status=200, 

254 ) 

255 

256 if self.method_endpoint == "GET": 

257 response = self.client.get(self.endpoint_url) 

258 else: 

259 if self.data: 

260 response = self.client.post( 

261 self.endpoint_url, data=self.data 

262 ) 

263 else: 

264 response = self.client.post( 

265 self.endpoint_url, json=self.json 

266 ) 

267 

268 called = responses.calls[len(responses.calls) - 1] 

269 self.assertEqual( 

270 "https://login.ubuntu.com/api/v2/tokens/refresh", 

271 called.request.url, 

272 ) 

273 

274 assert response.status_code == 302 

275 assert response.location == self._get_location() 

276 

277 class EndpointLoggedInErrorHandling(EndpointLoggedIn): 

278 @responses.activate 

279 def test_error_4xx(self): 

280 payload = {"error_list": []} 

281 responses.add( 

282 responses.Response( 

283 method=self.method_api, 

284 url=self.api_url, 

285 json=payload, 

286 status=400, 

287 ) 

288 ) 

289 

290 if self.method_endpoint == "GET": 

291 response = self.client.get(self.endpoint_url) 

292 else: 

293 if self.data: 

294 response = self.client.post( 

295 self.endpoint_url, data=self.data 

296 ) 

297 else: 

298 response = self.client.post( 

299 self.endpoint_url, json=self.json 

300 ) 

301 

302 self.check_call_by_api_url(responses.calls) 

303 

304 assert response.status_code == 502 

305 

306 @responses.activate 

307 def test_custom_error(self): 

308 payload = { 

309 "error_list": [ 

310 {"code": "error-code1"}, 

311 {"code": "error-code2"}, 

312 ] 

313 } 

314 responses.add( 

315 responses.Response( 

316 method=self.method_api, 

317 url=self.api_url, 

318 json=payload, 

319 status=400, 

320 ) 

321 ) 

322 

323 if self.method_endpoint == "GET": 

324 response = self.client.get(self.endpoint_url) 

325 else: 

326 if self.data: 

327 response = self.client.post( 

328 self.endpoint_url, data=self.data 

329 ) 

330 else: 

331 response = self.client.post( 

332 self.endpoint_url, json=self.json 

333 ) 

334 

335 self.check_call_by_api_url(responses.calls) 

336 

337 assert response.status_code == 502 

338 

339 @responses.activate 

340 def test_account_not_signed_agreement_logged_in(self): 

341 payload = { 

342 "error_list": [ 

343 { 

344 "code": "user-not-ready", 

345 "message": "has not signed agreement", 

346 } 

347 ] 

348 } 

349 responses.add( 

350 responses.Response( 

351 method=self.method_api, 

352 url=self.api_url, 

353 json=payload, 

354 status=403, 

355 ) 

356 ) 

357 

358 if self.method_endpoint == "GET": 

359 response = self.client.get(self.endpoint_url) 

360 else: 

361 if self.data: 

362 response = self.client.post( 

363 self.endpoint_url, data=self.data 

364 ) 

365 else: 

366 response = self.client.post( 

367 self.endpoint_url, json=self.json 

368 ) 

369 

370 self.check_call_by_api_url(responses.calls) 

371 

372 self.assertEqual(302, response.status_code) 

373 self.assertEqual("/account/agreement", response.location) 

374 

375 @responses.activate 

376 def test_account_no_username_logged_in(self): 

377 payload = { 

378 "error_list": [ 

379 { 

380 "code": "user-not-ready", 

381 "message": "missing store username", 

382 } 

383 ] 

384 } 

385 responses.add( 

386 responses.Response( 

387 method=self.method_api, 

388 url=self.api_url, 

389 json=payload, 

390 status=403, 

391 ) 

392 ) 

393 

394 if self.method_endpoint == "GET": 

395 response = self.client.get(self.endpoint_url) 

396 else: 

397 if self.data: 

398 response = self.client.post( 

399 self.endpoint_url, data=self.data 

400 ) 

401 else: 

402 response = self.client.post( 

403 self.endpoint_url, json=self.json 

404 ) 

405 

406 self.check_call_by_api_url(responses.calls) 

407 

408 self.assertEqual(302, response.status_code) 

409 self.assertEqual("/account/username", response.location)