using AutoMapper;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;

namespace Bit.Infrastructure.EntityFramework.Repositories;

/// <summary>
/// Generic Entity Framework repository that persists a domain object (<typeparamref name="T"/>)
/// through an EF entity (<typeparamref name="TEntity"/>), converting between the two with AutoMapper.
/// Every operation creates its own DI scope and resolves a <see cref="DatabaseContext"/> from it,
/// so no context is shared between calls.
/// </summary>
/// <typeparam name="T">Domain type exposed to callers.</typeparam>
/// <typeparam name="TEntity">EF entity type stored in the <see cref="DbSet{TEntity}"/>.</typeparam>
/// <typeparam name="TId">Key type shared by <typeparamref name="T"/> and <typeparamref name="TEntity"/>.</typeparam>
public abstract class Repository<T, TEntity, TId> : BaseEntityFrameworkRepository, IRepository<T, TId>
    where TId : IEquatable<TId>
    where T : class, ITableObject<TId>
    where TEntity : class, ITableObject<TId>
{
    /// <param name="serviceScopeFactory">Factory used to create one DI scope per operation.</param>
    /// <param name="mapper">AutoMapper instance used to convert between <typeparamref name="T"/> and <typeparamref name="TEntity"/>.</param>
    /// <param name="getDbSet">Selects the <see cref="DbSet{TEntity}"/> for <typeparamref name="TEntity"/> from a context.</param>
    public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper,
        Func<DatabaseContext, DbSet<TEntity>> getDbSet)
        : base(serviceScopeFactory, mapper)
    {
        GetDbSet = getDbSet;
    }

    /// <summary>Selector for this repository's <see cref="DbSet{TEntity}"/>; set once in the constructor.</summary>
    protected Func<DatabaseContext, DbSet<TEntity>> GetDbSet { get; }

    /// <summary>Looks up a row by primary key and maps it to <typeparamref name="T"/>.</summary>
    /// <param name="id">Primary key value.</param>
    /// <returns>The mapped object, or whatever AutoMapper produces for a missing (null) entity.</returns>
    public virtual async Task<T> GetByIdAsync(TId id)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);
        var entity = await GetDbSet(dbContext).FindAsync(id);
        return Mapper.Map<T>(entity);
    }

    /// <summary>
    /// Assigns a new id to <paramref name="obj"/>, inserts it, and copies the persisted key back onto it.
    /// </summary>
    /// <returns>The same <paramref name="obj"/> instance, with its <c>Id</c> updated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
    public virtual async Task<T> CreateAsync(T obj)
    {
        ArgumentNullException.ThrowIfNull(obj);

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);
        obj.SetNewId();
        var entity = Mapper.Map<TEntity>(obj);
        await dbContext.AddAsync(entity);
        await dbContext.SaveChangesAsync();
        // Copy the key back so store-generated values (if any) reach the caller.
        obj.Id = entity.Id;
        return obj;
    }

    /// <summary>
    /// Overwrites the stored row whose key matches <paramref name="obj"/>.Id with the mapped values of <paramref name="obj"/>.
    /// Does nothing when no such row exists.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
    public virtual async Task ReplaceAsync(T obj)
    {
        ArgumentNullException.ThrowIfNull(obj);

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);
        var entity = await GetDbSet(dbContext).FindAsync(obj.Id);
        if (entity is null)
        {
            // Nothing to replace; intentionally a silent no-op.
            return;
        }

        var mappedEntity = Mapper.Map<TEntity>(obj);
        // Copy scalar values onto the tracked entity so EF only updates changed columns.
        dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity);
        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Creates <paramref name="obj"/> when its id is the default value for <typeparamref name="TId"/>; otherwise replaces it.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
    public virtual async Task UpsertAsync(T obj)
    {
        ArgumentNullException.ThrowIfNull(obj);

        // EqualityComparer tolerates a null Id (reference-type keys) where obj.Id.Equals(...) would throw.
        if (EqualityComparer<TId>.Default.Equals(obj.Id, default(TId)))
        {
            await CreateAsync(obj);
        }
        else
        {
            await ReplaceAsync(obj);
        }
    }

    /// <summary>Deletes the row corresponding to <paramref name="obj"/> without loading it first.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
    public virtual async Task DeleteAsync(T obj)
    {
        ArgumentNullException.ThrowIfNull(obj);

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);
        // EF begins tracking the mapped (detached) entity in the Deleted state.
        var entity = Mapper.Map<TEntity>(obj);
        dbContext.Remove(entity);
        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Drops the database and recreates it from the current EF model, destroying all data.
    /// NOTE(review): presumably intended for tests/dev only — confirm it is never wired into production paths.
    /// </summary>
    public virtual async Task RefreshDb()
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var context = GetDatabaseContext(scope);
        await context.Database.EnsureDeletedAsync();
        await context.Database.EnsureCreatedAsync();
    }

    /// <summary>
    /// Assigns new ids to every object in <paramref name="objs"/> and inserts them in one <c>SaveChanges</c> call.
    /// Persisted keys are copied back onto the objects, matching <see cref="CreateAsync"/>.
    /// </summary>
    /// <returns>The same <paramref name="objs"/> list, with each object's <c>Id</c> updated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="objs"/> is null.</exception>
    public virtual async Task<List<T>> CreateMany(List<T> objs)
    {
        ArgumentNullException.ThrowIfNull(objs);

        using var scope = ServiceScopeFactory.CreateScope();
        var entities = new List<TEntity>(objs.Count);
        foreach (var o in objs)
        {
            o.SetNewId();
            entities.Add(Mapper.Map<TEntity>(o));
        }

        var dbContext = GetDatabaseContext(scope);
        await GetDbSet(dbContext).AddRangeAsync(entities);
        await dbContext.SaveChangesAsync();

        // entities[i] was mapped from objs[i]; propagate keys so store-generated ids are not lost.
        for (var i = 0; i < objs.Count; i++)
        {
            objs[i].Id = entities[i].Id;
        }

        return objs;
    }

    /// <summary>
    /// Executes <paramref name="query"/> against a fresh context and returns its results.
    /// </summary>
    /// <remarks>
    /// The context lives only for the duration of this call (its DI scope is disposed on return), so the
    /// results are materialized here. The returned <see cref="IQueryable{T}"/> is backed by an in-memory list:
    /// further composition runs as LINQ-to-Objects, not SQL.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is null.</exception>
    public IQueryable<Tout> Run<Tout>(IQuery<Tout> query)
    {
        ArgumentNullException.ThrowIfNull(query);

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);
        // Returning the deferred query would leave callers enumerating a disposed DbContext.
        return query.Run(dbContext).ToList().AsQueryable();
    }
}