From 7097e2a09e78d86beed8a92bd446c10ad4a4c11b Mon Sep 17 00:00:00 2001 From: Penny <2500338766@qq.com> Date: Thu, 4 May 2023 23:33:09 +0800 Subject: [PATCH] =?UTF-8?q?feature-1.0-tx=E6=96=87=E7=94=9F=E5=9B=BE?= =?UTF-8?q?=EF=BC=9A=E6=B7=BB=E5=8A=A0=E8=85=BE=E8=AE=AF=E6=96=87=E7=94=9F?= =?UTF-8?q?=E5=9B=BE=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bnyer-common/bnyer-common-core/pom.xml | 6 + .../core/constant/RedisKeyConstant.java | 5 + .../img/controller/TiktokMiniController.java | 7 + .../bnyer/img/enums/AiPaintButtonEnum.java | 20 ++ .../img/service/StableDiffusionService.java | 6 + .../impl/StableDiffusionServiceImpl.java | 184 ++++++++++++++---- 6 files changed, 186 insertions(+), 42 deletions(-) create mode 100644 bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java diff --git a/bnyer-common/bnyer-common-core/pom.xml b/bnyer-common/bnyer-common-core/pom.xml index 0c67d44..dfa52ae 100644 --- a/bnyer-common/bnyer-common-core/pom.xml +++ b/bnyer-common/bnyer-common-core/pom.xml @@ -146,6 +146,12 @@ tencentcloud-sdk-java-tmt 3.1.715 + + + com.tencentcloudapi + tencentcloud-sdk-java-aiart + 3.1.715 + diff --git a/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java b/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java index 4e1a47b..9d20096 100644 --- a/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java +++ b/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java @@ -102,6 +102,11 @@ public class RedisKeyConstant { * 平台用户ai绘画键 */ public static final String PLATFORM_USER_AI_PAINT_KEY = "bnyer.img.user.aiPaint:"; + + /** + * ai绘画采用sd或tx文生图开关 + */ + public static final String AI_PAINT_BUTTON = "bnyer.img.paint.button"; /** * 艺术家上传键 */ diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java index f98e478..d60772a 100644 --- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java +++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java @@ -459,4 +459,11 @@ public class TiktokMiniController extends BaseController { paintCdkService.useCdk(dto.getCdk(),dto.getSource(),dto.getUserId(),dto.getAppType()); return AjaxResult.success(); } + + @ApiOperation(value="设置Ai绘画开关") + @GetMapping(value = "/setAiPaint/{buttonValue}") + public AjaxResult setAiPaint(@PathVariable int buttonValue){ + stableDiffusionService.setPaintButton(buttonValue); + return AjaxResult.success(); + } } diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java new file mode 100644 index 0000000..dbc0764 --- /dev/null +++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java @@ -0,0 +1,20 @@ +package com.bnyer.img.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ai绘画采用sd或tx文生图枚举类 + * @author chengkun + * @date 2023/4/19 17:46 + */ +@Getter +@AllArgsConstructor +public enum AiPaintButtonEnum { + SD(1,"stable diffusion"), + TX(0,"tx文生图"); + + private int code; + + private String msg; +} diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java index 5b9cec4..d8d9025 100644 --- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java +++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java @@ -23,4 +23,10 @@ public interface StableDiffusionService { * @param paintNUm 绘画次数 */ void addPlatformUserAiPaintNum(String appType,String platform,Long userId,int paintNUm); + + /** + * 设置AI绘画button值 + * @param buttonValue + */ + void setPaintButton(int buttonValue); } diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java index fc2fbb3..6eca136 100644 --- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java +++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java @@ -6,6 +6,7 @@ import com.alibaba.fastjson.JSONObject; import com.bnyer.common.core.constant.RedisKeyConstant; import com.bnyer.common.core.domain.AiPaint; import com.bnyer.common.core.dto.TextToImgDto; +import com.bnyer.common.core.exception.ServiceException; import com.bnyer.common.core.utils.TranslateUtils; import com.bnyer.common.core.utils.file.Base64ToMultipartFileUtils; import com.bnyer.common.core.vo.TextToImgVo; @@ -13,8 +14,13 @@ import com.bnyer.common.redis.service.RedisService; import com.bnyer.file.api.RemoteFileService; import com.bnyer.img.config.StableDiffusionConfig; import com.bnyer.img.config.TencentTranslateConfig; +import com.bnyer.img.enums.AiPaintButtonEnum; import com.bnyer.img.service.AiPaintService; import com.bnyer.img.service.StableDiffusionService; +import com.tencentcloudapi.aiart.v20221229.AiartClient; +import com.tencentcloudapi.aiart.v20221229.models.ResultConfig; +import com.tencentcloudapi.aiart.v20221229.models.TextToImageRequest; +import com.tencentcloudapi.aiart.v20221229.models.TextToImageResponse; import com.tencentcloudapi.common.Credential; import com.tencentcloudapi.common.exception.TencentCloudSDKException; import com.tencentcloudapi.common.profile.ClientProfile; @@ -29,10 +35,7 @@ import org.springframework.web.client.RestTemplate; import org.springframework.web.multipart.MultipartFile; import java.text.SimpleDateFormat; -import java.util.Date; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; @Service @Slf4j @@ -85,43 +88,63 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { @Override public TextToImgVo textToImg(TextToImgDto param) { - try{ - String prompt = ""; - //判断prompt是否包含中文,中文则翻译,否则跳过 - if(TranslateUtils.isContainChinese(param.getPrompt())){ - //调用翻译api - prompt = translate(param.getPrompt()); - }else{ - prompt = param.getPrompt(); - } - System.out.println(prompt); - - //TODO 根据选择的风格来选择模型 - Map map = new HashMap<>(); - map.put("width",param.getWidth() == null ? 512 : param.getWidth()); - map.put("height",param.getHeight() == null ? 512 : param.getHeight()); - map.put("prompt", prompt); - //map.put("prompt", param.getPrompt()); - map.put("seed",-1); - map.put("batch_size",1); - map.put("cfg_scale",7); - map.put("restore_faces",false); - map.put("tiling",false); - map.put("eta",0); - map.put("sampler_index","DPM++ 2S a Karras"); - //map.put("sampler_index",param.getSamplerIndex()); - map.put("steps",25); - map.put("negative_prompt","easynegative"); - //log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map)); - JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class); - //log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject)); - TextToImgVo img = new TextToImgVo(); - if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){ - List images = jsonObject.getJSONArray("images").toJavaList(String.class); - img.setImages(images); + + //配置控制采用sd还是tx文生图 + int button = redisService.getCacheObject(RedisKeyConstant.AI_PAINT_BUTTON); + if(button == AiPaintButtonEnum.TX.getCode()){ + //采用腾讯文生图 + try{ + Credential cred = new Credential(tencentTranslateConfig.getSecretId(), tencentTranslateConfig.getSecretKey()); + // 实例化一个http选项,可选的,没有特殊需求可以跳过 + HttpProfile httpProfile = new HttpProfile(); + httpProfile.setEndpoint("aiart.tencentcloudapi.com"); + // 实例化一个client选项,可选的,没有特殊需求可以跳过 + ClientProfile clientProfile = new ClientProfile(); + clientProfile.setHttpProfile(httpProfile); + // 实例化要请求产品的client对象,clientProfile是可选的 + AiartClient client = new AiartClient(cred, "ap-guangzhou", clientProfile); + // 实例化一个请求对象,每个接口都会对应一个request对象 + TextToImageRequest req = new TextToImageRequest(); + req.setPrompt(param.getPrompt()); + //请求风格 + String[] styles1 = new String[0]; + switch (param.getStyleName()) { + case "细腻": + styles1 = new String[]{"110"}; + break; + case "卡通": + styles1 = new String[]{"201"}; + break; + case "科幻": + styles1 = new String[]{"114"}; + break; + case "中国风": + styles1 = new String[]{"101"}; + break; + } + req.setStyles(styles1); + + //画布大小 + ResultConfig resultConfig1 = new ResultConfig(); + if(param.getWidth() == 512 && param.getHeight() == 512){ + resultConfig1.setResolution("768:768"); + }else if(param.getWidth() == 512 && param.getHeight() == 1024){ + resultConfig1.setResolution("768:1024"); + }else{ + resultConfig1.setResolution("1024:768"); + } + req.setResultConfig(resultConfig1); + + // 返回的resp是一个TextToImageResponse的实例,与请求对象对应 + TextToImageResponse resp = client.TextToImage(req); + TextToImgVo img = new TextToImgVo(); + String images = resp.getResultImage(); + List list = new ArrayList<>(); + list.add(images); + img.setImages(list); String paintId = null; Date paintTime = null; - for (String image : images) { + for (String image : list) { //base64转file MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg"); //上传图片到七牛云/minio @@ -149,10 +172,82 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { } img.setPaintId(paintId); img.setPaintTime(paintTime); + return img; + } catch (TencentCloudSDKException e) { + log.error("腾讯文生图调用错误!"+e.getMessage()); + throw new ServiceException(e.getMessage(),500); + } + }else{ + //采用sd + try{ + String prompt = ""; + //判断prompt是否包含中文,中文则翻译,否则跳过 + if(TranslateUtils.isContainChinese(param.getPrompt())){ + //调用翻译api + prompt = translate(param.getPrompt()); + }else{ + prompt = param.getPrompt(); + } + System.out.println(prompt); + + //TODO 根据选择的风格来选择模型 + Map map = new HashMap<>(); + map.put("width",param.getWidth() == null ? 512 : param.getWidth()); + map.put("height",param.getHeight() == null ? 512 : param.getHeight()); + map.put("prompt", prompt); + //map.put("prompt", param.getPrompt()); + map.put("seed",-1); + map.put("batch_size",1); + map.put("cfg_scale",7); + map.put("restore_faces",false); + map.put("tiling",false); + map.put("eta",0); + map.put("sampler_index","DPM++ 2S a Karras"); + //map.put("sampler_index",param.getSamplerIndex()); + map.put("steps",25); + map.put("negative_prompt","easynegative,nsfw,naked"); + //log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map)); + JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class); + //log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject)); + TextToImgVo img = new TextToImgVo(); + if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){ + List images = jsonObject.getJSONArray("images").toJavaList(String.class); + img.setImages(images); + String paintId = null; + Date paintTime = null; + for (String image : images) { + //base64转file + MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg"); + //上传图片到七牛云/minio + String imgStr = remoteFileService.uploadBanner(file).getData(); + //保存生辰该图片到ai绘画表 + AiPaint paint = new AiPaint(); + //paint.setId(); 主键改成雪花算法后启用 + paintId = IdUtil.getSnowflakeNextIdStr(); + paintTime = new Date(); + paint.setPaintId(paintId); + paint.setCreateTime(paintTime); + paint.setImgUrl(imgStr); + paint.setPrompt(param.getPrompt()); + paint.setModel(param.getModelName()); + paint.setStyleName(param.getStyleName()); + paint.setHeight(param.getHeight() == null ? "512" : String.valueOf(param.getHeight())); + paint.setWidth(param.getWidth() == null ? "512" : String.valueOf(param.getWidth())); + paint.setIsShow("1"); + paint.setSource(param.getPlatform()); + paint.setPainterId(param.getPainterId()); + paint.setPainterName(param.getPainterName()); + aiPaintService.insert(paint); + //写入ai绘画次数 + writePlatformUserAiPaintNum(param.getAppType(),param.getPlatform(),param.getPainterId()); + } + img.setPaintId(paintId); + img.setPaintTime(paintTime); + } + return img; + }catch (Exception e){ + log.error("文本翻译错误!"+e); } - return img; - }catch (Exception e){ - log.error("文本翻译错误!"+e); } return null; } @@ -186,4 +281,9 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { String redisKey = RedisKeyConstant.PLATFORM_USER_AI_PAINT_KEY + date; redisService.hashIncr(redisKey, hashKey, -patinNum); } + + @Override + public void setPaintButton(int buttonValue) { + redisService.setCacheObject(RedisKeyConstant.AI_PAINT_BUTTON,buttonValue); + } }