Coverage for apps/outers/repositories/file_document_repository.py: 75%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1import io 

2from datetime import timedelta 

3from pathlib import Path 

4from typing import List, Optional, Dict, Any 

5from uuid import UUID 

6 

7import sqlalchemy 

8from minio.helpers import ObjectWriteResult 

9from sqlalchemy import exc 

10from sqlalchemy.engine import ScalarResult 

11from sqlmodel import select 

12from sqlmodel.ext.asyncio.session import AsyncSession 

13 

14from apps.inners.exceptions import repository_exception 

15from apps.inners.models.daos.document import Document 

16from apps.inners.models.daos.file_document import FileDocument 

17from apps.outers.datastores.temp_datastore import TempDatastore 

18from apps.outers.datastores.three_datastore import ThreeDatastore 

19 

20 

class FileDocumentRepository:
    """Storage gateway for file documents.

    Three back ends are touched from here:

    * a ``file_documents`` directory underneath the temp datastore path, used by
      ``save_file`` / ``read_file_data`` / ``remove_file``;
    * an object-store bucket reached through ``three_datastore.client``;
    * ``FileDocument`` rows, read and written through a caller-supplied
      ``AsyncSession``.

    None of the methods flush or commit the session; that is left to the caller.
    """

    # Every object-store call targets this one bucket.
    BUCKET_NAME: str = "research-assistant-backend.file-documents"

    def __init__(
            self,
            temp_datastore: TempDatastore,
            three_datastore: ThreeDatastore,
    ):
        """Keep the datastore handles and make sure the local directory exists."""
        self.temp_datastore: TempDatastore = temp_datastore
        self.three_datastore: ThreeDatastore = three_datastore
        self.file_path: Path = self.temp_datastore.temp_datastore_setting.TEMP_DATASTORE_PATH / "file_documents"
        # parents=True so a not-yet-created temp datastore root does not make
        # construction fail with FileNotFoundError.
        self.file_path.mkdir(parents=True, exist_ok=True)

    def save_file(self, relative_file_path: Path, file_data: bytes) -> Path:
        """Write ``file_data`` under the local directory and return the full path.

        :param relative_file_path: Location relative to ``self.file_path``.
        :param file_data: Raw bytes to write; an existing file is overwritten.
        :return: ``self.file_path / relative_file_path``.
        """
        target_path: Path = self.file_path / relative_file_path
        # Context manager guarantees the handle is closed even if write() raises.
        with open(target_path, "wb") as file_io:
            file_io.write(file_data)

        return target_path

    def read_file_data(self, relative_file_path: Path) -> bytes:
        """Return the bytes of a file stored under the local directory.

        :param relative_file_path: Location relative to ``self.file_path``.
        :raises FileNotFoundError: If the file does not exist.
        """
        target_path: Path = self.file_path / relative_file_path
        with open(target_path, "rb") as file_io:
            return file_io.read()

    def remove_file(self, relative_file_path: Path) -> None:
        """Delete a file stored under the local directory.

        :param relative_file_path: Location relative to ``self.file_path``.
        :raises FileNotFoundError: If the file does not exist.
        """
        target_path: Path = self.file_path / relative_file_path
        target_path.unlink()

    def put_object(self, object_name: str, data: bytes) -> ObjectWriteResult:
        """Upload ``data`` to the bucket under ``object_name``.

        :return: Whatever the object-store client returns for the upload.
        """
        return self.three_datastore.client.put_object(
            bucket_name=self.BUCKET_NAME,
            object_name=object_name,
            data=io.BytesIO(data),
            length=len(data)
        )

    def patch_object(self, old_object_name: str, new_object_name: str, new_data: bytes) -> None:
        """Replace the object ``old_object_name`` with ``new_data`` stored as ``new_object_name``.

        The upload happens first so that a failed upload leaves the previous
        object intact. The old key is only removed when it differs from the new
        one; with identical names the upload has already overwritten it.
        """
        self.put_object(
            object_name=new_object_name,
            data=new_data
        )
        if old_object_name != new_object_name:
            self.remove_object(
                object_name=old_object_name
            )

    def remove_object(self, object_name: str) -> None:
        """Delete ``object_name`` from the bucket."""
        self.three_datastore.client.remove_object(
            bucket_name=self.BUCKET_NAME,
            object_name=object_name
        )

    def get_object_url(self, object_name: str, response_headers: Optional[Dict[str, Any]] = None) -> str:
        """Return a presigned ``GET`` URL for ``object_name`` that expires after one day.

        :param response_headers: Passed through to the client unchanged; may be ``None``.
        """
        return self.three_datastore.client.get_presigned_url(
            bucket_name=self.BUCKET_NAME,
            object_name=object_name,
            response_headers=response_headers,
            method="GET",
            expires=timedelta(days=1)
        )

    def get_object_data(self, object_name: str) -> bytes:
        """Download ``object_name`` from the bucket and return its full content."""
        response = self.three_datastore.client.get_object(
            bucket_name=self.BUCKET_NAME,
            object_name=object_name,
        )
        try:
            file_data: bytes = response.read()
        finally:
            # Always give the connection back, including when read() fails.
            # The MinIO client documentation pairs close() with release_conn().
            response.close()
            response.release_conn()

        return file_data

    async def find_many_by_account_id_with_pagination(
            self,
            session: AsyncSession,
            account_id: UUID,
            page_position: int,
            page_size: int
    ) -> List[FileDocument]:
        """Return one page of the file documents whose document belongs to ``account_id``.

        :param page_position: 1-based page number; the offset is
            ``page_size * (page_position - 1)``.
        :param page_size: Maximum number of rows returned.
        """
        # NOTE(review): the statement has no ORDER BY, so page contents depend on
        # whatever order the database happens to return rows in.
        found_file_document_result: ScalarResult = await session.exec(
            select(FileDocument)
            .join(Document, Document.id == FileDocument.id)
            .where(Document.account_id == account_id)
            .limit(page_size)
            .offset(page_size * (page_position - 1))
        )
        found_file_documents: List[FileDocument] = list(found_file_document_result.all())

        return found_file_documents

    async def find_one_by_id_and_account_id(self, session: AsyncSession, id: UUID, account_id: UUID) -> FileDocument:
        """Return the file document with the given ``id`` owned by ``account_id``.

        :raises repository_exception.NotFound: If no row matches.
        """
        found_file_document_result: ScalarResult = await session.exec(
            select(FileDocument)
            .join(Document, Document.id == FileDocument.id)
            .where(FileDocument.id == id)
            .where(Document.account_id == account_id)
            .limit(1)
        )
        try:
            found_file_document: FileDocument = found_file_document_result.one()
        except sqlalchemy.exc.NoResultFound as error:
            raise repository_exception.NotFound() from error

        return found_file_document

    def create_one(
            self,
            session: AsyncSession,
            file_document_creator: FileDocument,
            file_data: Optional[bytes] = None
    ) -> FileDocument:
        """Stage ``file_document_creator`` in the session and optionally upload its content.

        :param file_data: When given, uploaded under ``file_document_creator.file_name``.
        :return: The same ``file_document_creator`` instance.
        :raises repository_exception.IntegrityError: If SQLAlchemy reports an
            integrity error while the record is being staged.
        """
        try:
            # NOTE(review): nothing is flushed here, so a constraint violation
            # will presumably surface later, when the caller flushes or commits.
            session.add(file_document_creator)
            if file_data is not None:
                self.put_object(
                    object_name=file_document_creator.file_name,
                    data=file_data
                )
        except sqlalchemy.exc.IntegrityError as error:
            raise repository_exception.IntegrityError() from error

        return file_document_creator

    async def patch_one_by_id_and_account_id(
            self,
            session: AsyncSession,
            id: UUID,
            account_id: UUID,
            file_document_patcher: FileDocument,
            file_data: Optional[bytes] = None
    ) -> FileDocument:
        """Apply the non-``None`` fields of ``file_document_patcher`` to a stored file document.

        The stored object is kept in step with the row: new ``file_data``
        replaces the content, and a changed ``file_name`` moves the object to
        the new key.

        :raises repository_exception.NotFound: If no matching row exists.
        :return: The patched row instance.
        """
        found_file_document: FileDocument = await self.find_one_by_id_and_account_id(
            session=session,
            id=id,
            account_id=account_id
        )
        # Remember the current key BEFORE patching: once patch_from has run the
        # row may already carry the new file name, and the old object could no
        # longer be located.
        old_object_name: str = found_file_document.file_name
        found_file_document.patch_from(file_document_patcher.dict(exclude_none=True))

        # A patcher without a file name means "keep the current name".
        new_object_name: str = file_document_patcher.file_name
        if new_object_name is None:
            new_object_name = old_object_name

        if file_data is not None:
            self.patch_object(
                old_object_name=old_object_name,
                new_object_name=new_object_name,
                new_data=file_data
            )
        elif new_object_name != old_object_name:
            # Rename only: carry the existing content over to the new key.
            self.patch_object(
                old_object_name=old_object_name,
                new_object_name=new_object_name,
                new_data=self.get_object_data(object_name=old_object_name)
            )

        return found_file_document

    async def delete_one_by_id_and_account_id(self, session: AsyncSession, id: UUID, account_id: UUID) -> FileDocument:
        """Mark the matching file document for deletion and remove its stored object.

        :raises repository_exception.NotFound: If no matching row exists.
        :return: The row instance that was handed to ``session.delete``.
        """
        found_file_document: FileDocument = await self.find_one_by_id_and_account_id(
            session=session,
            id=id,
            account_id=account_id
        )
        await session.delete(found_file_document)
        self.remove_object(object_name=found_file_document.file_name)

        return found_file_document

186 self.remove_object(object_name=found_file_document.file_name) 

187 

188 return found_file_document